diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index 46a75ac5ea..16ff482a1d 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -128,7 +128,7 @@ void us_internal_socket_context_link_listen_socket(struct us_socket_context_t *c context->head_listen_sockets->s.prev = &ls->s; } context->head_listen_sockets = ls; - context->ref_count++; + us_socket_context_ref(0, context); } /* We always add in the top, so we don't modify any s.next */ @@ -140,7 +140,7 @@ void us_internal_socket_context_link_socket(struct us_socket_context_t *context, context->head_sockets->prev = s; } context->head_sockets = s; - context->ref_count++; + us_socket_context_ref(0, context); } struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context) { @@ -277,9 +277,8 @@ struct us_bun_verify_error_t us_socket_verify_error(int ssl, struct us_socket_t return (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL }; } - - void us_internal_socket_context_free(int ssl, struct us_socket_context_t *context) { + #ifndef LIBUS_NO_SSL if (ssl) { /* This function will call us again with SSL=false */ @@ -300,9 +299,10 @@ void us_internal_socket_context_free(int ssl, struct us_socket_context_t *contex void us_socket_context_ref(int ssl, struct us_socket_context_t *context) { context->ref_count++; } - void us_socket_context_unref(int ssl, struct us_socket_context_t *context) { - if (--context->ref_count == 0) { + uint32_t ref_count = context->ref_count; + context->ref_count--; + if (ref_count == 1) { us_internal_socket_context_free(ssl, context); } } @@ -481,6 +481,7 @@ void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, co struct us_connecting_socket_t *c = us_calloc(1, sizeof(struct us_connecting_socket_t) + socket_ext_size); c->socket_ext_size = socket_ext_size; c->context = context; + us_socket_context_ref(ssl, context); c->options = options; c->ssl = ssl > 0; c->timeout = 255; @@ -548,7 +549,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) { c->pending_resolve_callback = 0; // if the socket was closed while we were resolving the address, free it if (c->closed) { - us_connecting_socket_free(c); + us_connecting_socket_free(c->ssl, c); return; } struct addrinfo_result *result = Bun__addrinfo_getRequestResult(c->addrinfo_req); @@ -556,7 +557,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) { c->error = result->error; c->context->on_connect_error(c, result->error); Bun__addrinfo_freeRequest(c->addrinfo_req, 0); - us_connecting_socket_close(0, c); + us_connecting_socket_close(c->ssl, c); return; } @@ -567,7 +568,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) { c->error = ECONNREFUSED; c->context->on_connect_error(c, ECONNREFUSED); Bun__addrinfo_freeRequest(c->addrinfo_req, 1); - us_connecting_socket_close(0, c); + us_connecting_socket_close(c->ssl, c); return; } } @@ -638,7 +639,7 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) { c->error = ECONNREFUSED; c->context->on_connect_error(c, error); Bun__addrinfo_freeRequest(c->addrinfo_req, ECONNREFUSED); - us_connecting_socket_close(0, c); + us_connecting_socket_close(c->ssl, c); } } } else { @@ -667,7 +668,7 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) { } // now that the socket is open, we can release the associated us_connecting_socket_t if it exists Bun__addrinfo_freeRequest(c->addrinfo_req, 0); - us_connecting_socket_free(c); + us_connecting_socket_free(c->ssl, c); s->connect_state = NULL; } diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index ea8de154be..435f8658f9 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -1820,6 +1820,7 @@ ssl_wrapped_context_on_close(struct us_internal_ssl_socket_t *s, int code, wrapped_context->old_events.on_close((struct us_socket_t *)s, code, reason); } + us_socket_context_unref(0, wrapped_context->tcp_context); return s; } @@ -1976,6 +1977,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( } struct us_socket_context_t *old_context = us_socket_context(0, s); + us_socket_context_ref(0,old_context); struct us_socket_context_t *context = us_create_bun_socket_context( 1, old_context->loop, sizeof(struct us_wrapped_socket_context_t), @@ -1998,6 +2000,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( }; wrapped_context->old_events = old_events; wrapped_context->events = events; + wrapped_context->tcp_context = old_context; // no need to wrap open because socket is already open (only new context will // be called so we can configure hostname and ssl stuff normally here before diff --git a/packages/bun-usockets/src/internal/internal.h b/packages/bun-usockets/src/internal/internal.h index 5c4b6245e2..6807e6fc90 100644 --- a/packages/bun-usockets/src/internal/internal.h +++ b/packages/bun-usockets/src/internal/internal.h @@ -189,6 +189,7 @@ struct us_connecting_socket_t { }; struct us_wrapped_socket_context_t { + struct us_socket_context_t* tcp_context; struct us_socket_events_t events; struct us_socket_events_t old_events; }; diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index 08f5647348..b6d9733f3e 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -314,7 +314,7 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, us_socket_context_r * Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */ int us_socket_is_established(int ssl, us_socket_r s) nonnull_fn_decl; -void us_connecting_socket_free(struct us_connecting_socket_t *c) nonnull_fn_decl; +void us_connecting_socket_free(int ssl, struct us_connecting_socket_t *c) nonnull_fn_decl; /* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times. * Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index bee0431d13..c53645dcc8 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -129,19 +129,20 @@ int us_socket_is_established(int ssl, struct us_socket_t *s) { return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET; } -void us_connecting_socket_free(struct us_connecting_socket_t *c) { +void us_connecting_socket_free(int ssl, struct us_connecting_socket_t *c) { // we can't just free c immediately, as it may be enqueued in the dns_ready_head list // instead, we move it to a close list and free it after the iteration c->next = c->context->loop->data.closed_connecting_head; c->context->loop->data.closed_connecting_head = c; + us_socket_context_unref(ssl, c->context); } void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { if (c->closed) return; c->closed = 1; - for (struct us_socket_t *s = c->connecting_head; s; s = s->connect_next) { us_internal_socket_context_unlink_socket(ssl, s->context, s); + us_poll_stop((struct us_poll_t *) s, s->context->loop); bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); @@ -156,7 +157,7 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { // we can only schedule the socket to be freed if there is no pending callback // otherwise, the callback will see that the socket is closed and will free it if (!c->pending_resolve_callback) { - us_connecting_socket_free(c); + us_connecting_socket_free(ssl, c); } } diff --git a/packages/bun-uws/src/AsyncSocket.h b/packages/bun-uws/src/AsyncSocket.h index 58df7608d8..753212cc12 100644 --- a/packages/bun-uws/src/AsyncSocket.h +++ b/packages/bun-uws/src/AsyncSocket.h @@ -361,4 +361,4 @@ public: } -#endif // UWS_ASYNCSOCKET_H +#endif // UWS_ASYNCSOCKET_H \ No newline at end of file diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 70fe82f3e5..94c3529b2f 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -70,7 +70,7 @@ noinline fn getSSLException(globalThis: *JSC.JSGlobalObject, defaultMessage: []c if (written > 0) { const message = output_buf[0..written]; - zig_str = ZigString.init(std.fmt.allocPrint(bun.default_allocator, "OpenSSL {s}", .{message}) catch unreachable); + zig_str = ZigString.init(std.fmt.allocPrint(bun.default_allocator, "OpenSSL {s}", .{message}) catch bun.outOfMemory()); var encoded_str = zig_str.withEncoding(); encoded_str.mark(); @@ -136,21 +136,19 @@ const Handlers = struct { pub const Scope = struct { handlers: *Handlers, - socket_context: ?*uws.SocketContext, - pub fn exit(this: *Scope, ssl: bool, wrapped: WrappedType) void { + pub fn exit(this: *Scope) void { var vm = this.handlers.vm; defer vm.eventLoop().exit(); - this.handlers.markInactive(ssl, this.socket_context, wrapped); + this.handlers.markInactive(); } }; - pub fn enter(this: *Handlers, context: ?*uws.SocketContext) Scope { + pub fn enter(this: *Handlers) Scope { this.markActive(); this.vm.eventLoop().enter(); return .{ .handlers = this, - .socket_context = context, }; } @@ -179,7 +177,7 @@ const Handlers = struct { return true; } - pub fn markInactive(this: *Handlers, ssl: bool, ctx: ?*uws.SocketContext, wrapped: WrappedType) void { + pub fn markInactive(this: *Handlers) void { Listener.log("markInactive", .{}); this.active_connections -= 1; if (this.active_connections == 0) { @@ -191,12 +189,6 @@ const Handlers = struct { } } else { this.unprotect(); - // will deinit when is not wrapped or when is the TCP wrapped connection - if (wrapped != .tls) { - if (ctx) |ctx_| { - ctx_.deinit(ssl); - } - } bun.default_allocator.destroy(this); } } @@ -501,13 +493,13 @@ pub const Listener = struct { switch (this) { .unix => |u| { return .{ - .unix = (bun.default_allocator.dupe(u8, u) catch unreachable), + .unix = (bun.default_allocator.dupe(u8, u) catch bun.outOfMemory()), }; }, .host => |h| { return .{ .host = .{ - .host = (bun.default_allocator.dupe(u8, h.host) catch unreachable), + .host = (bun.default_allocator.dupe(u8, h.host) catch bun.outOfMemory()), .port = this.host.port, }, }; @@ -665,15 +657,15 @@ pub const Listener = struct { } var connection: Listener.UnixOrHost = if (port) |port_| .{ - .host = .{ .host = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch unreachable).slice(), .port = port_ }, + .host = .{ .host = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch bun.outOfMemory()).slice(), .port = port_ }, } else .{ - .unix = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch unreachable).slice(), + .unix = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch bun.outOfMemory()).slice(), }; const listen_socket: *uws.ListenSocket = brk: { switch (connection) { .host => |c| { - const host = bun.default_allocator.dupeZ(u8, c.host) catch unreachable; + const host = bun.default_allocator.dupeZ(u8, c.host) catch bun.outOfMemory(); defer bun.default_allocator.free(host); const socket = uws.us_socket_context_listen( @@ -691,7 +683,7 @@ pub const Listener = struct { break :brk socket; }, .unix => |u| { - const host = bun.default_allocator.dupeZ(u8, u) catch unreachable; + const host = bun.default_allocator.dupeZ(u8, u) catch bun.outOfMemory(); defer bun.default_allocator.free(host); break :brk uws.us_socket_context_listen_unix(@intFromBool(ssl_enabled), socket_context, host, host.len, socket_flags, 8); }, @@ -730,7 +722,7 @@ pub const Listener = struct { .ssl = ssl_enabled, .socket_context = socket_context, .listener = listen_socket, - .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch bun.outOfMemory()) else null, }; socket.handlers.protect(); @@ -788,12 +780,16 @@ pub const Listener = struct { .socket = socket, .protos = listener.protos, .flags = .{ .owned_protos = false }, + .socket_context = null, // dont own the socket context }); + this_socket.ref(); if (listener.strong_data.get()) |default_data| { const globalObject = listener.handlers.globalObject; Socket.dataSetCached(this_socket.getThisValue(globalObject), globalObject, default_data); } - socket.ext(**anyopaque).* = bun.cast(**anyopaque, this_socket); + if (socket.ext(**anyopaque)) |ctx| { + ctx.* = bun.cast(**anyopaque, this_socket); + } socket.setTimeout(120000); } @@ -1004,9 +1000,9 @@ pub const Listener = struct { } } if (port) |_| { - break :blk .{ .host = .{ .host = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch unreachable).slice(), .port = port.? } }; + break :blk .{ .host = .{ .host = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch bun.outOfMemory()).slice(), .port = port.? } }; } - break :blk .{ .unix = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch unreachable).slice() }; + break :blk .{ .unix = (hostname_or_unix.cloneIfNeeded(bun.default_allocator) catch bun.outOfMemory()).slice() }; }; if (ssl_enabled) { @@ -1014,7 +1010,7 @@ pub const Listener = struct { protos = p[0..ssl.?.protos_len]; } if (ssl.?.server_name) |s| { - server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch unreachable; + server_name = bun.default_allocator.dupe(u8, s[0..bun.len(s)]) catch bun.outOfMemory(); } uws.NewSocketHandler(true).configure( socket_context, @@ -1063,18 +1059,20 @@ pub const Listener = struct { var tls = TLSSocket.new(.{ .handlers = handlers_ptr, .this_value = .zero, - .socket = undefined, + .socket = TLSSocket.Socket.detached, .connection = connection, - .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch bun.outOfMemory()) else null, .server_name = server_name, + .socket_context = socket_context, // owns the socket context }); TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); - tls.doConnect(connection, socket_context) catch { - tls.handleConnectError(socket_context, @intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); + tls.doConnect(connection) catch { + tls.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); return promise_value; }; + tls.poll_ref.ref(handlers.vm); return promise_value; @@ -1082,16 +1080,16 @@ pub const Listener = struct { var tcp = TCPSocket.new(.{ .handlers = handlers_ptr, .this_value = .zero, - .socket = undefined, + .socket = TCPSocket.Socket.detached, .connection = null, .protos = null, .server_name = null, + .socket_context = socket_context, // owns the socket context }); TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); - - tcp.doConnect(connection, socket_context) catch { - tcp.handleConnectError(socket_context, @intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); + tcp.doConnect(connection) catch { + tcp.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); return promise_value; }; tcp.poll_ref.ref(handlers.vm); @@ -1142,6 +1140,8 @@ fn NewSocket(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); socket: Socket, + // if the socket owns a context it will be here + socket_context: ?*uws.SocketContext, flags: Flags = .{}, ref_count: u32 = 1, @@ -1174,7 +1174,6 @@ fn NewSocket(comptime ssl: bool) type { is_active: bool = false, /// Prevent onClose from calling into JavaScript while we are finalizing finalizing: bool = false, - detached: bool = true, authorized: bool = false, owned_protos: bool = true, }; @@ -1189,29 +1188,28 @@ fn NewSocket(comptime ssl: bool) type { return this.has_pending_activity.load(.acquire); } - pub fn doConnect(this: *This, connection: Listener.UnixOrHost, socket_ctx: *uws.SocketContext) !void { + pub fn doConnect(this: *This, connection: Listener.UnixOrHost) !void { + bun.assert(this.socket_context != null); switch (connection) { .host => |c| { - _ = try This.Socket.connectPtr( + this.ref(); + this.socket = try This.Socket.connectAnon( normalizeHost(c.host), c.port, - socket_ctx, - This, + this.socket_context.?, this, - "socket", ); }, .unix => |u| { - _ = try This.Socket.connectUnixPtr( + this.ref(); + this.socket = try This.Socket.connectUnixAnon( u, - socket_ctx, - This, + this.socket_context.?, this, - "socket", ); }, .fd => |f| { - const socket = This.Socket.fromFd(socket_ctx, f, This, this, "socket") orelse return error.ConnectionFailed; + const socket = This.Socket.fromFd(this.socket_context.?, f, This, this, null) orelse return error.ConnectionFailed; this.onOpen(socket); }, } @@ -1228,7 +1226,7 @@ fn NewSocket(comptime ssl: bool) type { ) void { JSC.markBinding(@src()); log("onWritable", .{}); - if (this.flags.detached) return; + if (this.socket.isDetached()) return; const handlers = this.handlers; const callback = handlers.onWritable; if (callback == .zero) return; @@ -1252,11 +1250,11 @@ fn NewSocket(comptime ssl: bool) type { } pub fn onTimeout( this: *This, - socket: Socket, + _: Socket, ) void { JSC.markBinding(@src()); log("onTimeout", .{}); - if (this.flags.detached) return; + if (this.socket.isDetached()) return; const handlers = this.handlers; const callback = handlers.onTimeout; @@ -1267,8 +1265,8 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can - var scope = handlers.enter(socket.context()); - defer scope.exit(ssl, this.wrapped); + var scope = handlers.enter(); + defer scope.exit(); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1280,13 +1278,12 @@ fn NewSocket(comptime ssl: bool) type { _ = handlers.callErrorHandler(this_value, &[_]JSC.JSValue{ this_value, err_value }); } } - fn handleConnectError(this: *This, socket_ctx: ?*uws.SocketContext, errno: c_int) void { - log("onConnectError({d})", .{errno}); - const needs_deref = !this.flags.detached; - this.flags.detached = true; + fn handleConnectError(this: *This, errno: c_int) void { + log("onConnectError({d}, {})", .{ errno, this.ref_count }); + const needs_deref = !this.socket.isDetached(); + this.socket = Socket.detached; defer if (needs_deref) this.deref(); - - defer this.markInactive(socket_ctx); + defer this.markInactive(); const handlers = this.handlers; const vm = handlers.vm; @@ -1341,9 +1338,9 @@ fn NewSocket(comptime ssl: bool) type { this.has_pending_activity.store(false, .release); } } - pub fn onConnectError(this: *This, socket: Socket, errno: c_int) void { + pub fn onConnectError(this: *This, _: Socket, errno: c_int) void { JSC.markBinding(@src()); - this.handleConnectError(socket.context(), errno); + this.handleConnectError(errno); } pub fn markActive(this: *This) void { @@ -1354,70 +1351,75 @@ fn NewSocket(comptime ssl: bool) type { } } - pub fn markInactive(this: *This, socket_ctx: ?*uws.SocketContext) void { + pub fn closeAndDetach(this: *This, code: uws.CloseCode) void { + const socket = this.socket; + this.socket.detach(); + socket.close(code); + } + + pub fn markInactive(this: *This) void { if (this.flags.is_active) { - if (!this.flags.detached) { - // we have to close the socket before the socket context is closed - // otherwise we will get a segfault - // uSockets will defer freeing the TCP socket until the next tick - if (!this.socket.isClosed()) { - this.socket.close(.normal); - // onClose will call markInactive again - return; - } + // we have to close the socket before the socket context is closed + // otherwise we will get a segfault + // uSockets will defer freeing the TCP socket until the next tick + if (!this.socket.isClosed()) { + this.closeAndDetach(.normal); + // onClose will call markInactive again + return; } + this.flags.is_active = false; const vm = this.handlers.vm; - this.handlers.markInactive(ssl, socket_ctx, this.wrapped); + this.handlers.markInactive(); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .release); } } pub fn onOpen(this: *This, socket: Socket) void { + log("onOpen {} {}", .{ this.socket.isDetached(), this.ref_count }); // update the internal socket instance to the one that was just connected + // This socket must be replaced because the previous one is a connecting socket not a uSockets socket this.socket = socket; JSC.markBinding(@src()); log("onOpen ssl: {}", .{comptime ssl}); // Add SNI support for TLS (mongodb and others requires this) if (comptime ssl) { - var ssl_ptr = this.socket.ssl(); - - if (!ssl_ptr.isInitFinished()) { - if (this.server_name) |server_name| { - const host = normalizeHost(server_name); - if (host.len > 0) { - const host__ = default_allocator.dupeZ(u8, host) catch unreachable; - defer default_allocator.free(host__); - ssl_ptr.setHostname(host__); - } - } else if (this.connection) |connection| { - if (connection == .host) { - const host = normalizeHost(connection.host.host); + if (this.socket.ssl()) |ssl_ptr| { + if (!ssl_ptr.isInitFinished()) { + if (this.server_name) |server_name| { + const host = normalizeHost(server_name); if (host.len > 0) { - const host__ = default_allocator.dupeZ(u8, host) catch unreachable; + const host__ = default_allocator.dupeZ(u8, host) catch bun.outOfMemory(); defer default_allocator.free(host__); ssl_ptr.setHostname(host__); } + } else if (this.connection) |connection| { + if (connection == .host) { + const host = normalizeHost(connection.host.host); + if (host.len > 0) { + const host__ = default_allocator.dupeZ(u8, host) catch bun.outOfMemory(); + defer default_allocator.free(host__); + ssl_ptr.setHostname(host__); + } + } } - } - if (this.protos) |protos| { - if (this.handlers.is_server) { - BoringSSL.SSL_CTX_set_alpn_select_cb(BoringSSL.SSL_get_SSL_CTX(ssl_ptr), selectALPNCallback, bun.cast(*anyopaque, this)); - } else { - _ = BoringSSL.SSL_set_alpn_protos(ssl_ptr, protos.ptr, @as(c_uint, @intCast(protos.len))); + if (this.protos) |protos| { + if (this.handlers.is_server) { + BoringSSL.SSL_CTX_set_alpn_select_cb(BoringSSL.SSL_get_SSL_CTX(ssl_ptr), selectALPNCallback, bun.cast(*anyopaque, this)); + } else { + _ = BoringSSL.SSL_set_alpn_protos(ssl_ptr, protos.ptr, @as(c_uint, @intCast(protos.len))); + } } } } } - this.flags.detached = false; - this.socket = socket; - this.ref(); - if (this.wrapped == .none) { - socket.ext(**anyopaque).* = bun.cast(**anyopaque, this); + if (socket.ext(**anyopaque)) |ctx| { + ctx.* = bun.cast(**anyopaque, this); + } } const handlers = this.handlers; @@ -1446,7 +1448,7 @@ fn NewSocket(comptime ssl: bool) type { }); if (result.toError()) |err| { - defer this.markInactive(socket.context()); + defer this.markInactive(); if (!this.socket.isClosed()) { log("Closing due to error", .{}); } else { @@ -1469,10 +1471,10 @@ fn NewSocket(comptime ssl: bool) type { return this.this_value; } - pub fn onEnd(this: *This, socket: Socket) void { + pub fn onEnd(this: *This, _: Socket) void { JSC.markBinding(@src()); log("onEnd", .{}); - if (this.flags.detached) return; + if (this.socket.isDetached()) return; const handlers = this.handlers; @@ -1481,14 +1483,14 @@ fn NewSocket(comptime ssl: bool) type { this.poll_ref.unref(handlers.vm); // If you don't handle TCP fin, we assume you're done. - this.markInactive(this.socket.context()); + this.markInactive(); return; } // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can - var scope = handlers.enter(socket.context()); - defer scope.exit(ssl, this.wrapped); + var scope = handlers.enter(); + defer scope.exit(); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1501,10 +1503,10 @@ fn NewSocket(comptime ssl: bool) type { } } - pub fn onHandshake(this: *This, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + pub fn onHandshake(this: *This, _: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { log("onHandshake({d})", .{success}); JSC.markBinding(@src()); - if (this.flags.detached) return; + if (this.socket.isDetached()) return; const authorized = if (success == 1) true else false; this.flags.authorized = authorized; @@ -1528,8 +1530,8 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can - var scope = handlers.enter(socket.context()); - defer scope.exit(ssl, this.wrapped); + var scope = handlers.enter(); + defer scope.exit(); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1577,12 +1579,12 @@ fn NewSocket(comptime ssl: bool) type { } } - pub fn onClose(this: *This, socket: Socket, err: c_int, _: ?*anyopaque) void { + pub fn onClose(this: *This, _: Socket, err: c_int, _: ?*anyopaque) void { JSC.markBinding(@src()); log("onClose", .{}); + this.socket.detach(); defer this.deref(); - this.flags.detached = true; - defer this.markInactive(socket.context()); + defer this.markInactive(); if (this.flags.finalizing) { return; @@ -1603,8 +1605,8 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can - var scope = handlers.enter(socket.context()); - defer scope.exit(ssl, this.wrapped); + var scope = handlers.enter(); + defer scope.exit(); const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -1618,10 +1620,10 @@ fn NewSocket(comptime ssl: bool) type { } } - pub fn onData(this: *This, socket: Socket, data: []const u8) void { + pub fn onData(this: *This, _: Socket, data: []const u8) void { JSC.markBinding(@src()); log("onData({d})", .{data.len}); - if (this.flags.detached) return; + if (this.socket.isDetached()) return; const handlers = this.handlers; const callback = handlers.onData; @@ -1636,8 +1638,8 @@ fn NewSocket(comptime ssl: bool) type { // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can - var scope = handlers.enter(socket.context()); - defer scope.exit(ssl, this.wrapped); + var scope = handlers.enter(); + defer scope.exit(); // const encoding = handlers.encoding; const result = callback.call(globalObject, this_value, &[_]JSValue{ @@ -1672,7 +1674,7 @@ fn NewSocket(comptime ssl: bool) type { this: *This, _: *JSC.JSGlobalObject, ) JSValue { - if (!this.handlers.is_server or this.flags.detached) { + if (!this.handlers.is_server or this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -1686,7 +1688,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { log("getReadyState()", .{}); - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsNumber(@as(i32, -1)); } else if (this.socket.isClosed()) { return JSValue.jsNumber(@as(i32, 0)); @@ -1713,7 +1715,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); const args = callframe.arguments(1); - if (this.flags.detached) return JSValue.jsUndefined(); + if (this.socket.isDetached()) return JSValue.jsUndefined(); if (args.len == 0) { globalObject.throw("Expected 1 argument, got 0", .{}); return .zero; @@ -1736,7 +1738,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsNull(); } @@ -1766,7 +1768,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsNumber(@as(i32, -1)); } @@ -1787,7 +1789,7 @@ fn NewSocket(comptime ssl: bool) type { this: *This, _: *JSC.JSGlobalObject, ) JSValue { - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -1798,7 +1800,7 @@ fn NewSocket(comptime ssl: bool) type { this: *This, globalThis: *JSC.JSGlobalObject, ) JSValue { - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -1819,7 +1821,7 @@ fn NewSocket(comptime ssl: bool) type { } fn writeMaybeCorked(this: *This, buffer: []const u8, is_end: bool) i32 { - if (this.flags.detached or this.socket.isShutdown() or this.socket.isClosed()) { + if (this.socket.isShutdown() or this.socket.isClosed()) { return -1; } // we don't cork yet but we might later @@ -1990,8 +1992,7 @@ fn NewSocket(comptime ssl: bool) type { _: *JSC.CallFrame, ) JSValue { JSC.markBinding(@src()); - if (!this.flags.detached) - this.socket.flush(); + this.socket.flush(); return JSValue.jsUndefined(); } @@ -2002,10 +2003,7 @@ fn NewSocket(comptime ssl: bool) type { _: *JSC.CallFrame, ) JSValue { JSC.markBinding(@src()); - if (!this.flags.detached) { - this.socket.close(.failure); - } - + this.closeAndDetach(.failure); return JSValue.jsUndefined(); } @@ -2016,12 +2014,10 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); const args = callframe.arguments(1); - if (!this.flags.detached) { - if (args.len > 0 and args.ptr[0].toBoolean()) { - this.socket.shutdownRead(); - } else { - this.socket.shutdown(); - } + if (args.len > 0 and args.ptr[0].toBoolean()) { + this.socket.shutdownRead(); + } else { + this.socket.shutdown(); } return JSValue.jsUndefined(); @@ -2038,7 +2034,7 @@ fn NewSocket(comptime ssl: bool) type { log("end({d} args)", .{args.len}); - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsNumber(@as(i32, -1)); } @@ -2048,7 +2044,7 @@ fn NewSocket(comptime ssl: bool) type { if (result.wrote == result.total) { this.socket.flush(); // markInactive does .detached = true - this.markInactive(this.socket.context()); + this.markInactive(); } break :brk JSValue.jsNumber(result.wrote); }, @@ -2057,7 +2053,7 @@ fn NewSocket(comptime ssl: bool) type { pub fn jsRef(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) JSValue { JSC.markBinding(@src()); - if (this.flags.detached) return JSValue.jsUndefined(); + if (this.socket.isDetached()) return JSValue.jsUndefined(); this.poll_ref.ref(globalObject.bunVM()); return JSValue.jsUndefined(); } @@ -2069,7 +2065,7 @@ fn NewSocket(comptime ssl: bool) type { } pub fn deinit(this: *This) void { - this.markInactive(null); + this.markInactive(); this.poll_ref.unref(JSC.VirtualMachine.get()); // need to deinit event without being attached @@ -2089,16 +2085,18 @@ fn NewSocket(comptime ssl: bool) type { this.connection = null; connection.deinit(); } + if (this.socket_context) |socket_context| { + this.socket_context = null; + socket_context.deinit(ssl); + } this.destroy(); } pub fn finalize(this: *This) void { - log("finalize() {d}", .{@intFromPtr(this)}); + log("finalize() {d} {}", .{ @intFromPtr(this), this.socket_context != null }); this.flags.finalizing = true; - if (!this.flags.detached) { - if (!this.socket.isClosed()) { - this.socket.close(.failure); - } + if (!this.socket.isClosed()) { + this.closeAndDetach(.failure); } this.deref(); @@ -2112,7 +2110,7 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -2150,11 +2148,7 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - - const ssl_ptr = this.socket.ssl(); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); BoringSSL.SSL_set_renegotiate_mode(ssl_ptr, BoringSSL.ssl_renegotiate_never); return JSValue.jsUndefined(); } @@ -2167,7 +2161,7 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -2208,11 +2202,8 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - const ssl_ptr = this.socket.ssl(); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); BoringSSL.ERR_clear_error(); if (BoringSSL.SSL_renegotiate(ssl_ptr) != 1) { globalObject.throwValue(getSSLException(globalObject, "SSL_renegotiate error")); @@ -2229,11 +2220,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - - const ssl_ptr = this.socket.ssl(); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); const session = BoringSSL.SSL_get_session(ssl_ptr) orelse return JSValue.jsUndefined(); var ticket: [*c]const u8 = undefined; var length: usize = 0; @@ -2256,7 +2243,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -2302,11 +2289,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - - const ssl_ptr = this.socket.ssl(); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); const session = BoringSSL.SSL_get_session(ssl_ptr) orelse return JSValue.jsUndefined(); const size = BoringSSL.i2d_SSL_SESSION(session, null); if (size <= 0) { @@ -2330,14 +2313,10 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsBoolean(false); } - if (this.flags.detached) { - return JSValue.jsBoolean(false); - } - var alpn_proto: [*c]const u8 = null; var alpn_proto_len: u32 = 0; - const ssl_ptr = this.socket.ssl(); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsBoolean(false); BoringSSL.SSL_get0_alpn_selected(ssl_ptr, &alpn_proto, &alpn_proto_len); if (alpn_proto == null or alpn_proto_len == 0) { @@ -2362,7 +2341,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -2396,7 +2375,7 @@ fn NewSocket(comptime ssl: bool) type { defer label.deinit(); const label_slice = label.slice(); - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); if (args.len > 2) { const context_arg = args.ptr[2]; @@ -2450,17 +2429,14 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsNull(); } - if (this.flags.detached) { - return JSValue.jsNull(); - } - // only available for clients if (this.handlers.is_server) { return JSValue.jsNull(); } var result = JSValue.createEmptyObject(globalObject, 3); - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsNull(); + // TODO: investigate better option or compatible way to get the key // this implementation follows nodejs but for BoringSSL SSL_get_server_tmp_key will always return 0 // wich will result in a empty object @@ -2519,13 +2495,10 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); + const cipher = BoringSSL.SSL_get_current_cipher(ssl_ptr); var result = JSValue.createEmptyObject(globalObject, 3); - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); - const cipher = BoringSSL.SSL_get_current_cipher(ssl_ptr); if (cipher == null) { result.put(globalObject, ZigString.static("name"), JSValue.jsNull()); result.put(globalObject, ZigString.static("standardName"), JSValue.jsNull()); @@ -2566,11 +2539,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); // We cannot just pass nullptr to SSL_get_peer_finished() // because it would further be propagated to memcpy(), // where the standard requirements as described in ISO/IEC 9899:2011 @@ -2598,11 +2567,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); // We cannot just pass nullptr to SSL_get_finished() // because it would further be propagated to memcpy(), // where the standard requirements as described in ISO/IEC 9899:2011 @@ -2631,10 +2596,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsNull(); } - if (this.flags.detached) { - return JSValue.jsNull(); - } - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsNull(); const nsig = BoringSSL.SSL_get_shared_sigalgs(ssl_ptr, 0, null, null, null, null, null); @@ -2693,7 +2655,7 @@ fn NewSocket(comptime ssl: bool) type { if (hash_str != null) { const hash_str_len = bun.len(hash_str); const hash_slice = hash_str[0..hash_str_len]; - const buffer = bun.default_allocator.alloc(u8, sig_with_md.len + hash_str_len + 1) catch unreachable; + const buffer = bun.default_allocator.alloc(u8, sig_with_md.len + hash_str_len + 1) catch bun.outOfMemory(); defer bun.default_allocator.free(buffer); bun.copy(u8, buffer, sig_with_md); @@ -2701,7 +2663,7 @@ fn NewSocket(comptime ssl: bool) type { bun.copy(u8, buffer[sig_with_md.len + 1 ..], hash_slice); array.putIndex(globalObject, @as(u32, @intCast(i)), JSC.ZigString.fromUTF8(buffer).toJS(globalObject)); } else { - const buffer = bun.default_allocator.alloc(u8, sig_with_md.len + 6) catch unreachable; + const buffer = bun.default_allocator.alloc(u8, sig_with_md.len + 6) catch bun.outOfMemory(); defer bun.default_allocator.free(buffer); bun.copy(u8, buffer, sig_with_md); @@ -2722,11 +2684,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsNull(); } - if (this.flags.detached) { - return JSValue.jsNull(); - } - - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsNull(); const version = BoringSSL.SSL_get_version(ssl_ptr); if (version == null) return JSValue.jsNull(); const version_len = bun.len(version); @@ -2745,10 +2703,6 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsBoolean(false); } - if (this.flags.detached) { - return JSValue.jsBoolean(false); - } - const args = callframe.arguments(1); if (args.len < 1) { @@ -2771,7 +2725,7 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsBoolean(false); return JSValue.jsBoolean(BoringSSL.SSL_set_max_send_fragment(ssl_ptr, @as(usize, @intCast(size))) == 1); } pub fn getPeerCertificate( @@ -2784,10 +2738,6 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - const args = callframe.arguments(1); var abbreviated: bool = true; if (args.len > 0) { @@ -2799,7 +2749,7 @@ fn NewSocket(comptime ssl: bool) type { abbreviated = arg.toBoolean(); } - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); if (abbreviated) { if (this.handlers.is_server) { @@ -2837,12 +2787,7 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - - if (this.flags.detached) { - return JSValue.jsUndefined(); - } - - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + const ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); const cert = BoringSSL.SSL_get_certificate(ssl_ptr); if (cert) |x509| { @@ -2893,7 +2838,7 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch unreachable; + const slice = server_name.getZigString(globalObject).toOwnedSlice(bun.default_allocator) catch bun.outOfMemory(); if (this.server_name) |old| { this.server_name = slice; default_allocator.free(old); @@ -2901,21 +2846,16 @@ fn NewSocket(comptime ssl: bool) type { this.server_name = slice; } - if (this.flags.detached) { - // will be attached onOpen - return JSValue.jsUndefined(); - } - const host = normalizeHost(@as([]const u8, slice)); if (host.len > 0) { - var ssl_ptr = this.socket.ssl(); + var ssl_ptr = this.socket.ssl() orelse return JSValue.jsUndefined(); if (ssl_ptr.isInitFinished()) { // match node.js exceptions globalObject.throw("Already started.", .{}); return .zero; } - const host__ = default_allocator.dupeZ(u8, host) catch unreachable; + const host__ = default_allocator.dupeZ(u8, host) catch bun.outOfMemory(); defer default_allocator.free(host__); ssl_ptr.setHostname(host__); } @@ -2936,7 +2876,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.flags.detached) { + if (this.socket.isDetached()) { return JSValue.jsUndefined(); } @@ -3010,11 +2950,12 @@ fn NewSocket(comptime ssl: bool) type { var tls = TLSSocket.new(.{ .handlers = handlers_ptr, .this_value = .zero, - .socket = undefined, + .socket = TLSSocket.Socket.detached, .connection = if (this.connection) |c| c.clone() else null, .wrapped = .tls, - .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null, - .server_name = if (socket_config.server_name) |server_name| (bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable) else null, + .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch bun.outOfMemory()) else null, + .server_name = if (socket_config.server_name) |server_name| (bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch bun.outOfMemory()) else null, + .socket_context = null, // only set after the wrapTLS }); const tls_js_value = tls.getThisValue(globalObject); @@ -3024,7 +2965,6 @@ fn NewSocket(comptime ssl: bool) type { // reconfigure context to use the new wrapper handlers Socket.unsafeConfigure(this.socket.context().?, true, true, WrappedSocket, TCPHandler); - const old_context = this.socket.context(); const TLSHandler = NewWrappedHandler(true); const new_socket = this.socket.wrapTLS( options, @@ -3040,6 +2980,8 @@ fn NewSocket(comptime ssl: bool) type { }; tls.socket = new_socket; + tls.socket_context = new_socket.context(); // owns the new tls context that have a ref from the old one + tls.ref(); var raw_handlers_ptr = handlers.vm.allocator.create(Handlers) catch bun.outOfMemory(); raw_handlers_ptr.* = .{ @@ -3067,7 +3009,9 @@ fn NewSocket(comptime ssl: bool) type { .connection = if (this.connection) |c| c.clone() else null, .wrapped = .tcp, .protos = null, + .socket_context = null, // raw socket will dont own the context }); + raw.ref(); const raw_js_value = raw.getThisValue(globalObject); if (JSSocketType(ssl).dataGetCached(this.getThisValue(globalObject))) |raw_default_data| { @@ -3084,19 +3028,22 @@ fn NewSocket(comptime ssl: bool) type { tls.poll_ref.ref(this.handlers.vm); // mark both instances on socket data - new_socket.ext(WrappedSocket).* = .{ .tcp = raw, .tls = tls }; + if (new_socket.ext(WrappedSocket)) |ctx| { + ctx.* = .{ .tcp = raw, .tls = tls }; + } // start TLS handshake after we set ext new_socket.startTLS(!this.handlers.is_server); //detach and invalidate the old instance - this.flags.detached = true; + this.socket.detach(); + this.deref(); if (this.flags.is_active) { const vm = this.handlers.vm; this.flags.is_active = false; - // will free handlers and the old_context when hits 0 active connections + // will free handlers when hits 0 active connections // the connection can be upgraded inside a handler call so we need to garantee that it will be still alive - this.handlers.markInactive(ssl, old_context, this.wrapped); + this.handlers.markInactive(); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .release); } diff --git a/src/bun.js/api/bun/subprocess.zig b/src/bun.js/api/bun/subprocess.zig index c2e80ec629..6323001159 100644 --- a/src/bun.js/api/bun/subprocess.zig +++ b/src/bun.js/api/bun/subprocess.zig @@ -2131,7 +2131,9 @@ pub const Subprocess = struct { if (subprocess.ipc_data) |*ipc_data| { if (Environment.isPosix) { - posix_ipc_info.ext(*Subprocess).* = subprocess; + if (posix_ipc_info.ext(*Subprocess)) |ctx| { + ctx.* = subprocess; + } } else { if (ipc_data.configureServer( Subprocess, diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 744c888b31..93fcf58bdf 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -83,9 +83,16 @@ pub const InternalLoopData = extern struct { pub const InternalSocket = union(enum) { done: *Socket, connecting: *ConnectingSocket, - + detached: void, + pub fn isDetached(this: InternalSocket) bool { + return this == .detached; + } + pub fn detach(this: *InternalSocket) void { + this.* = .detached; + } pub fn close(this: InternalSocket, comptime is_ssl: bool, code: CloseCode) void { switch (this) { + .detached => {}, .done => |socket| { debug("us_socket_close({d})", .{@intFromPtr(socket)}); _ = us_socket_close( @@ -109,6 +116,7 @@ pub const InternalSocket = union(enum) { return switch (this) { .done => |socket| us_socket_is_closed(@intFromBool(is_ssl), socket) > 0, .connecting => |socket| us_connecting_socket_is_closed(@intFromBool(is_ssl), socket) > 0, + .detached => true, }; } @@ -116,6 +124,7 @@ pub const InternalSocket = union(enum) { return switch (this) { .done => this.done, .connecting => null, + .detached => null, }; } @@ -123,12 +132,16 @@ pub const InternalSocket = union(enum) { return switch (this) { .done => switch (other) { .done => this.done == other.done, - .connecting => false, + .connecting, .detached => false, }, .connecting => switch (other) { - .done => false, + .done, .detached => false, .connecting => this.connecting == other.connecting, }, + .detached => switch (other) { + .detached => true, + .done, .connecting => false, + }, }; } }; @@ -138,7 +151,13 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const ssl_int: i32 = @intFromBool(is_ssl); socket: InternalSocket, const ThisSocket = @This(); - + pub const detached: NewSocketHandler(is_ssl) = NewSocketHandler(is_ssl){ .socket = .{ .detached = {} } }; + pub fn detach(this: *ThisSocket) void { + this.socket.detach(); + } + pub fn isDetached(this: ThisSocket) bool { + return this.socket.isDetached(); + } pub fn verifyError(this: ThisSocket) us_bun_verify_error_t { const socket = this.socket.get() orelse return std.mem.zeroes(us_bun_verify_error_t); const ssl_error: us_bun_verify_error_t = uws.us_socket_verify_error(comptime ssl_int, socket); @@ -154,6 +173,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { switch (this.socket) { .done => |socket| us_socket_timeout(comptime ssl_int, socket, seconds), .connecting => |socket| us_connecting_socket_timeout(comptime ssl_int, socket, seconds), + .detached => {}, } } @@ -177,6 +197,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { us_connecting_socket_long_timeout(comptime ssl_int, socket, 0); } }, + .detached => {}, } } @@ -190,19 +211,23 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { us_connecting_socket_timeout(comptime ssl_int, socket, 0); us_connecting_socket_long_timeout(comptime ssl_int, socket, minutes); }, + .detached => {}, } } pub fn startTLS(this: ThisSocket, is_client: bool) void { - const socket = this.socket.get() orelse @panic("socket is not open"); + const socket = this.socket.get() orelse return; _ = us_socket_open(comptime ssl_int, socket, @intFromBool(is_client), null, 0); } - pub fn ssl(this: ThisSocket) *BoringSSL.SSL { + pub fn ssl(this: ThisSocket) ?*BoringSSL.SSL { if (comptime is_ssl) { - return @as(*BoringSSL.SSL, @ptrCast(this.getNativeHandle())); + if(this.getNativeHandle()) |handle| { + return @as(*BoringSSL.SSL, @ptrCast(handle)); + } + return null; } - @panic("socket is not a TLS socket"); + return null; } // Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context @@ -229,7 +254,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } if (comptime deref_) { - return (TLSSocket.from(socket)).ext(ContextType).*; + return (TLSSocket.from(socket)).ext(ContextType).?.*; } return (TLSSocket.from(socket)).ext(ContextType); @@ -289,7 +314,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket { Fields.onConnectError( - TLSSocket.from(socket).ext(ContextType).*, + TLSSocket.from(socket).ext(ContextType).?.*, TLSSocket.from(socket), code, ); @@ -328,7 +353,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { .on_long_timeout = SocketHandler.on_long_timeout, }; - const this_socket = this.socket.get() orelse @panic("socket is not open"); + const this_socket = this.socket.get() orelse return null; const socket = us_socket_wrap_with_tls(ssl_int, this_socket, options, events, socket_ext_size) orelse return null; return NewSocketHandler(true).from(socket); @@ -338,6 +363,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { return @ptrCast(switch (this.socket) { .done => |socket| us_socket_get_native_handle(comptime ssl_int, socket), .connecting => |socket| us_connecting_socket_get_native_handle(comptime ssl_int, socket), + .detached => null, } orelse return null); } @@ -361,7 +387,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { us_socket_sendfile_needs_more(socket); } - pub fn ext(this: ThisSocket, comptime ContextType: type) *ContextType { + pub fn ext(this: ThisSocket, comptime ContextType: type) ?*ContextType { const alignment = if (ContextType == *anyopaque) @sizeOf(usize) else @@ -370,6 +396,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const ptr = switch (this.socket) { .done => |sock| us_socket_ext(comptime ssl_int, sock), .connecting => |sock| us_connecting_socket_ext(comptime ssl_int, sock), + .detached => return null, }; return @as(*align(alignment) ContextType, @ptrCast(@alignCast(ptr))); @@ -380,6 +407,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { switch (this.socket) { .done => |socket| return us_socket_context(comptime ssl_int, socket), .connecting => |socket| return us_connecting_socket_context(comptime ssl_int, socket), + .detached => return null, } } @@ -434,6 +462,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { socket, ); }, + .detached => {}, } } @@ -453,6 +482,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { socket, ); }, + .detached => {}, } } @@ -470,6 +500,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { socket, ) > 0; }, + .detached => return true, } } @@ -495,6 +526,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { socket, ); }, + .detached => return 0, } } @@ -642,8 +674,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { ) ?ThisSocket { const socket_ = ThisSocket{ .socket = .{ .done = us_socket_from_fd(ctx, @sizeOf(*anyopaque), bun.socketcast(handle)) orelse return null } }; - const holder = socket_.ext(*anyopaque); - holder.* = this; + if(socket_.ext(*anyopaque)) |holder| { + holder.* = this; + } if (comptime socket_field_name) |field| { @field(this, field) = socket_; @@ -679,8 +712,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { return error.FailedToOpenSocket; const socket_ = ThisSocket{ .socket = .{ .done = socket } }; - const holder = socket_.ext(*anyopaque); - holder.* = ctx; + if(socket_.ext(*anyopaque)) |holder| { + holder.* = ctx; + } return socket_; } @@ -721,9 +755,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { ThisSocket{ .socket = .{ .connecting = @ptrCast(socket_ptr) }, }; - - const holder = socket.ext(*anyopaque); - holder.* = ptr; + if(socket.ext(*anyopaque)) |holder| { + holder.* = ptr; + } return socket; } @@ -751,7 +785,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } if (comptime deref_) { - return (SocketHandlerType.from(socket)).ext(ContextType).*; + return (SocketHandlerType.from(socket)).ext(ContextType).?.*; } return (SocketHandlerType.from(socket)).ext(ContextType); @@ -806,7 +840,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const val = if (comptime ContextType == anyopaque) us_connecting_socket_ext(comptime ssl_int, socket) else if (comptime deref_) - SocketHandlerType.fromConnecting(socket).ext(ContextType).* + SocketHandlerType.fromConnecting(socket).ext(ContextType).?.* else SocketHandlerType.fromConnecting(socket).ext(ContextType); Fields.onConnectError( @@ -820,7 +854,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const val = if (comptime ContextType == anyopaque) us_socket_ext(comptime ssl_int, socket) else if (comptime deref_) - SocketHandlerType.from(socket).ext(ContextType).* + SocketHandlerType.from(socket).ext(ContextType).?.* else SocketHandlerType.from(socket).ext(ContextType); Fields.onConnectError( @@ -883,7 +917,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { } if (comptime deref_) { - return (ThisSocket.from(socket)).ext(ContextType).*; + return (ThisSocket.from(socket)).ext(ContextType).?.*; } return (ThisSocket.from(socket)).ext(ContextType); @@ -945,7 +979,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const val = if (comptime ContextType == anyopaque) us_connecting_socket_ext(comptime ssl_int, socket) else if (comptime deref_) - ThisSocket.fromConnecting(socket).ext(ContextType).* + ThisSocket.fromConnecting(socket).ext(ContextType).?.* else ThisSocket.fromConnecting(socket).ext(ContextType); Fields.onConnectError( @@ -959,7 +993,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const val = if (comptime ContextType == anyopaque) us_socket_ext(comptime ssl_int, socket) else if (comptime deref_) - ThisSocket.from(socket).ext(ContextType).* + ThisSocket.from(socket).ext(ContextType).?.* else ThisSocket.from(socket).ext(ContextType); @@ -1033,14 +1067,14 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const new_socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, -1) orelse return false; bun.assert(new_socket == socket); var adopted = ThisSocket.from(new_socket); - const holder = adopted.ext(*anyopaque); - holder.* = ctx; + if(adopted.ext(*anyopaque)) |holder| { + holder.* = ctx; + } @field(ctx, socket_field_name) = adopted; return true; } }; } - pub const SocketTCP = NewSocketHandler(false); pub const SocketTLS = NewSocketHandler(true); @@ -2312,7 +2346,6 @@ pub fn NewApp(comptime ssl: bool) type { } } pub fn onAborted(res: *Response, comptime UserDataType: type, comptime handler: fn (UserDataType, *Response) void, opcional_data: UserDataType) void { - const Wrapper = struct { pub fn handle(this: *uws_res, user_data: ?*anyopaque) callconv(.C) void { if (comptime UserDataType == void) { diff --git a/src/http.zig b/src/http.zig index cfeecb7500..4c01335caa 100644 --- a/src/http.zig +++ b/src/http.zig @@ -342,7 +342,9 @@ fn NewHTTPContext(comptime ssl: bool) type { }; pub fn markSocketAsDead(socket: HTTPSocket) void { - socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + if (socket.ext(**anyopaque)) |ctx| { + ctx.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + } } fn terminateSocket(socket: HTTPSocket) void { @@ -461,7 +463,9 @@ fn NewHTTPContext(comptime ssl: bool) type { if (hostname.len <= MAX_KEEPALIVE_HOSTNAME and !socket.isClosedOrHasError() and socket.isEstablished()) { if (this.pending_sockets.get()) |pending| { - socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr()); + if (socket.ext(**anyopaque)) |ctx| { + ctx.* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr()); + } socket.flush(); socket.timeout(0); socket.setTimeoutMinutes(5); @@ -517,9 +521,9 @@ fn NewHTTPContext(comptime ssl: bool) type { // handshake completed but we may have ssl errors client.flags.did_have_handshaking_error = handshake_error.error_no != 0; if (handshake_success) { - if(client.flags.reject_unauthorized) { + if (client.flags.reject_unauthorized) { // only reject the connection if reject_unauthorized == true - if(client.flags.did_have_handshaking_error) { + if (client.flags.did_have_handshaking_error) { client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket); return; } @@ -537,7 +541,7 @@ fn NewHTTPContext(comptime ssl: bool) type { } else { // if we are here is because server rejected us, and the error_no is the cause of this // if we set reject_unauthorized == false this means the server requires custom CA aka NODE_EXTRA_CA_CERTS - if(client.flags.did_have_handshaking_error) { + if (client.flags.did_have_handshaking_error) { client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket); return; } @@ -545,7 +549,6 @@ fn NewHTTPContext(comptime ssl: bool) type { client.closeAndFail(error.ConnectionRefused, comptime ssl, socket); return; } - } if (socket.isClosed()) { @@ -745,7 +748,9 @@ fn NewHTTPContext(comptime ssl: bool) type { if (client.isKeepAlivePossible()) { if (this.existingSocket(client.flags.reject_unauthorized, hostname, port)) |sock| { - sock.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr()); + if (sock.ext(**anyopaque)) |ctx| { + ctx.* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr()); + } client.allow_retry = true; client.onOpen(comptime ssl, sock); if (comptime ssl) { @@ -2858,10 +2863,10 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s } pub fn closeAndFail(this: *HTTPClient, err: anyerror, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void { - if (!socket.isClosed()) { - NewHTTPContext(is_ssl).terminateSocket(socket); - } if (this.state.stage != .fail and this.state.stage != .done) { + if (!socket.isClosed()) { + NewHTTPContext(is_ssl).terminateSocket(socket); + } log("closeAndFail: {s}", .{@errorName(err)}); this.fail(err); } diff --git a/test/harness.ts b/test/harness.ts index 44f82baf0f..84715b6899 100644 --- a/test/harness.ts +++ b/test/harness.ts @@ -100,7 +100,7 @@ export async function expectMaxObjectTypeCount( await Bun.sleep(wait); gc(); } - expect(heapStats().objectTypeCounts[type]).toBeLessThanOrEqual(count); + expect(heapStats().objectTypeCounts[type] || 0).toBeLessThanOrEqual(count); } // we must ensure that finalizers are run diff --git a/test/js/bun/http/serve-body-leak.test.ts b/test/js/bun/http/serve-body-leak.test.ts index 1e27f1d48e..b7409ad565 100644 --- a/test/js/bun/http/serve-body-leak.test.ts +++ b/test/js/bun/http/serve-body-leak.test.ts @@ -126,9 +126,9 @@ async function calculateMemoryLeak(fn: () => Promise) { // If it was leaking the body, the memory usage would be at least 512 KB * 10_000 = 5 GB // If it ends up around 280 MB, it's probably not leaking the body. for (const test_info of [ - ["#10265 should not leak memory when ignoring the body", callIgnore, false, 48], - ["should not leak memory when buffering the body", callBuffering, false, 48], - ["should not leak memory when streaming the body", callStreaming, false, 48], + ["#10265 should not leak memory when ignoring the body", callIgnore, false, 64], + ["should not leak memory when buffering the body", callBuffering, false, 64], + ["should not leak memory when streaming the body", callStreaming, false, 64], ["should not leak memory when streaming the body incompletely", callIncompleteStreaming, false, 64], ["should not leak memory when streaming the body and echoing it back", callStreamingEcho, false, 64], ] as const) { diff --git a/test/js/bun/net/socket.test.ts b/test/js/bun/net/socket.test.ts index 18486dbb90..c1b42a7f4c 100644 --- a/test/js/bun/net/socket.test.ts +++ b/test/js/bun/net/socket.test.ts @@ -1,5 +1,5 @@ import { expect, it } from "bun:test"; -import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows } from "harness"; +import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows, tls } from "harness"; import { connect, fileURLToPath, SocketHandler, spawn } from "bun"; import type { Socket } from "bun"; it("should coerce '0' to 0", async () => { @@ -354,7 +354,7 @@ it("it should not crash when returning a Error on client socket open", async () }); it("it should only call open once", async () => { - const server = Bun.listen({ + using server = Bun.listen({ port: 0, hostname: "localhost", socket: { @@ -381,7 +381,6 @@ it("it should only call open once", async () => { expect().fail("connectError should not be called"); }, close(socket) { - server.stop(); resolve(); }, data(socket, data) {}, @@ -397,7 +396,7 @@ it.skipIf(isWindows)("should not leak file descriptors when connecting", async ( }); it("should not call open if the connection had an error", async () => { - const server = Bun.listen({ + using server = Bun.listen({ port: 0, hostname: "0.0.0.0", socket: { @@ -435,12 +434,11 @@ it("should not call open if the connection had an error", async () => { await Bun.sleep(50); await promise; - server.stop(); expect(hadError).toBe(true); }); it("should connect directly when using an ip address", async () => { - const server = Bun.listen({ + using server = Bun.listen({ port: 0, hostname: "127.0.0.1", socket: { @@ -467,7 +465,6 @@ it("should connect directly when using an ip address", async () => { expect().fail("connectError should not be called"); }, close(socket) { - server.stop(); resolve(); }, data(socket, data) {}, @@ -498,3 +495,77 @@ it("should not call drain before handshake", async () => { await promise; expect(socket.authorized).toBe(true); }); +it("should be able to upgrade to TLS", async () => { + using server = Bun.serve({ + tls, + async fetch(req) { + return new Response("Hello World"); + }, + }); + const { promise: tlsSocketPromise, resolve, reject } = Promise.withResolvers(); + const { promise: rawSocketPromise, resolve: rawSocketResolve, reject: rawSocketReject } = Promise.withResolvers(); + { + let body = ""; + let rawBody = Buffer.alloc(0); + const socket = await Bun.connect({ + hostname: "localhost", + port: server.port, + socket: { + data(socket, data) { + rawBody = Buffer.concat([rawBody, data]); + }, + close() { + rawSocketResolve(rawBody); + }, + error(err) { + rawSocketReject(err); + }, + }, + }); + const result = socket.upgradeTLS({ + data: Buffer.from("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"), + tls, + socket: { + data(socket, data) { + body += data.toString("utf8"); + if (body.includes("\r\n\r\n")) { + socket.end(); + } + }, + close() { + resolve(body); + }, + drain(socket) { + while (socket.data.byteLength > 0) { + const written = socket.write(socket.data); + if (written === 0) { + break; + } + socket.data = socket.data.slice(written); + } + socket.flush(); + }, + error(err) { + reject(err); + }, + }, + }); + + const [raw, tls_socket] = result; + expect(raw).toBeDefined(); + expect(tls_socket).toBeDefined(); + } + const [tlsData, rawData] = await Promise.all([tlsSocketPromise, rawSocketPromise]); + expect(tlsData).toContain("HTTP/1.1 200 OK"); + expect(tlsData).toContain("Content-Length: 11"); + expect(tlsData).toContain("\r\nHello World"); + expect(rawData.byteLength).toBeGreaterThanOrEqual(1980); +}); + +it("should not leak memory", async () => { + // assert we don't leak the sockets + // we expect 1 or 2 because that's the prototype / structure + await expectMaxObjectTypeCount(expect, "Listener", 2); + await expectMaxObjectTypeCount(expect, "TCPSocket", 2); + await expectMaxObjectTypeCount(expect, "TLSSocket", 2); +}); diff --git a/test/js/bun/net/tcp-server.test.ts b/test/js/bun/net/tcp-server.test.ts index fb6d46b8ac..4af228559f 100644 --- a/test/js/bun/net/tcp-server.test.ts +++ b/test/js/bun/net/tcp-server.test.ts @@ -15,7 +15,7 @@ it("remoteAddress works", async () => { }; reject = reject1; }); - let server = Bun.listen({ + using server = Bun.listen({ socket: { open(ws) { try { @@ -25,8 +25,6 @@ it("remoteAddress works", async () => { reject(e); return; - } finally { - setTimeout(() => server.stop(true), 0); } }, close() {}, @@ -63,7 +61,7 @@ it("should not allow invalid tls option", () => { [1, "string", Symbol("symbol")].forEach(value => { expect(() => { // @ts-ignore - const server = Bun.listen({ + using server = Bun.listen({ socket: { open(ws) {}, close() {}, @@ -73,7 +71,6 @@ it("should not allow invalid tls option", () => { hostname: "localhost", tls: value, }); - server.stop(true); }).toThrow("tls option expects an object"); }); }); @@ -82,7 +79,7 @@ it("should allow using false, null or undefined tls option", () => { [false, null, undefined].forEach(value => { expect(() => { // @ts-ignore - const server = Bun.listen({ + using server = Bun.listen({ socket: { open(ws) {}, close() {}, @@ -92,7 +89,6 @@ it("should allow using false, null or undefined tls option", () => { hostname: "localhost", tls: value, }); - server.stop(true); }).not.toThrow("tls option expects an object"); }); }); @@ -167,7 +163,7 @@ it("echo server 1 on 1", async () => { }, } as SocketHandler; - var server: TCPSocketListener | undefined = listen({ + using server: TCPSocketListener | undefined = listen({ socket: handlers, hostname: "localhost", port: 0, @@ -186,8 +182,6 @@ it("echo server 1 on 1", async () => { }, }); await Promise.all([prom, clientProm, serverProm]); - server.stop(true); - server = serverData = clientData = undefined; })(); }); @@ -276,7 +270,7 @@ describe("tcp socket binaryType", () => { binaryType: type, } as SocketHandler; - var server: TCPSocketListener | undefined = listen({ + using server: TCPSocketListener | undefined = listen({ socket: handlers, hostname: "localhost", port: 0, @@ -296,8 +290,6 @@ describe("tcp socket binaryType", () => { }); await Promise.all([prom, clientProm, serverProm]); - server.stop(true); - server = serverData = clientData = undefined; })(); }); }