From a73bbcd150e72dd96c59623e0a5f93bb5d33371d Mon Sep 17 00:00:00 2001 From: Kai Tamkun Date: Mon, 21 Jul 2025 13:56:21 -0700 Subject: [PATCH] Make TCPSocket/TLSSocket.handlers optional (please don't free undefined) --- src/bun.js/api/bun/socket.zig | 111 ++++++++++-------- .../api/bun/socket/tls_socket_functions.zig | 10 +- src/bun.js/node/node_net_binding.zig | 4 +- 3 files changed, 69 insertions(+), 56 deletions(-) diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index edd8c0e9d8..b35cea1f00 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -76,8 +76,7 @@ pub fn NewSocket(comptime ssl: bool) type { flags: Flags = .{}, ref_count: RefCount, wrapped: WrappedType = .none, - // TODO: make this optional - handlers: *Handlers, + handlers: ?*Handlers, this_value: JSC.JSValue = .zero, poll_ref: Async.KeepAlive = Async.KeepAlive.init(), ref_pollref_on_connect: bool = true, @@ -224,7 +223,7 @@ pub fn NewSocket(comptime ssl: bool) type { pub fn handleError(this: *This, err_value: JSC.JSValue) void { log("handleError", .{}); - const handlers = this.handlers; + const handlers = this.getHandlers(); var vm = handlers.vm; if (vm.isShuttingDown()) { return; @@ -242,7 +241,7 @@ pub fn NewSocket(comptime ssl: bool) type { JSC.markBinding(@src()); if (this.socket.isDetached()) return; if (this.native_callback.onWritable()) return; - const handlers = this.handlers; + const handlers = this.getHandlers(); const callback = handlers.onWritable; if (callback == .zero) return; @@ -272,8 +271,8 @@ pub fn NewSocket(comptime ssl: bool) type { pub fn onTimeout(this: *This, _: Socket) void { JSC.markBinding(@src()); if (this.socket.isDetached()) return; - log("onTimeout {s}", .{if (this.handlers.is_server) "S" else "C"}); - const handlers = this.handlers; + const handlers = this.getHandlers(); + log("onTimeout {s}", .{if (handlers.is_server) "S" else "C"}); const callback = handlers.onTimeout; if (callback == .zero or this.flags.finalizing) return; if (handlers.vm.isShuttingDown()) { @@ -292,8 +291,13 @@ pub fn NewSocket(comptime ssl: bool) type { }; } + pub fn getHandlers(this: *const This) *Handlers { + return this.handlers orelse @panic("No handlers set on Socket"); + } + pub fn handleConnectError(this: *This, errno: c_int) void { - log("onConnectError {s} ({d}, {d})", .{ if (this.handlers.is_server) "S" else "C", errno, this.ref_count.active_counts }); + const handlers = this.getHandlers(); + log("onConnectError {s} ({d}, {d})", .{ if (handlers.is_server) "S" else "C", errno, this.ref_count.active_counts }); // Ensure the socket is still alive for any defer's we have this.ref(); defer this.deref(); @@ -304,7 +308,6 @@ pub fn NewSocket(comptime ssl: bool) type { defer this.markInactive(); defer if (needs_deref) this.deref(); - const handlers = this.handlers; const vm = handlers.vm; this.poll_ref.unrefOnNextTick(vm); if (vm.isShuttingDown()) { @@ -373,7 +376,7 @@ pub fn NewSocket(comptime ssl: bool) type { pub fn markActive(this: *This) void { if (!this.flags.is_active) { - this.handlers.markActive(); + this.getHandlers().markActive(); this.flags.is_active = true; this.has_pending_activity.store(true, .release); } @@ -401,15 +404,20 @@ pub fn NewSocket(comptime ssl: bool) type { } this.flags.is_active = false; - const vm = this.handlers.vm; - this.handlers.markInactive(); + const handlers = this.getHandlers(); + const vm = handlers.vm; + handlers.markInactive(); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .release); } } + pub fn isServer(this: *const This) bool { + return this.getHandlers().is_server; + } + pub fn onOpen(this: *This, socket: Socket) void { - log("onOpen {s} {*} {} {}", .{ if (this.handlers.is_server) "S" else "C", this, this.socket.isDetached(), this.ref_count.active_counts }); + log("onOpen {s} {*} {} {}", .{ if (this.isServer()) "S" else "C", this, this.socket.isDetached(), this.ref_count.active_counts }); // Ensure the socket remains alive until this is finished this.ref(); defer this.deref(); @@ -441,7 +449,7 @@ pub fn NewSocket(comptime ssl: bool) type { } } if (this.protos) |protos| { - if (this.handlers.is_server) { + if (this.isServer()) { BoringSSL.SSL_CTX_set_alpn_select_cb(BoringSSL.SSL_get_SSL_CTX(ssl_ptr), selectALPNCallback, bun.cast(*anyopaque, this)); } else { _ = BoringSSL.SSL_set_alpn_protos(ssl_ptr, protos.ptr, @as(c_uint, @intCast(protos.len))); @@ -457,7 +465,7 @@ pub fn NewSocket(comptime ssl: bool) type { } } - const handlers = this.handlers; + const handlers = this.getHandlers(); const callback = handlers.onOpen; const handshake_callback = handlers.onHandshake; @@ -509,13 +517,12 @@ pub fn NewSocket(comptime ssl: bool) type { pub fn onEnd(this: *This, _: Socket) void { JSC.markBinding(@src()); if (this.socket.isDetached()) return; - log("onEnd {s}", .{if (this.handlers.is_server) "S" else "C"}); + const handlers = this.getHandlers(); + log("onEnd {s}", .{if (handlers.is_server) "S" else "C"}); // Ensure the socket remains alive until this is finished this.ref(); defer this.deref(); - const handlers = this.handlers; - const callback = handlers.onEnd; if (callback == .zero or handlers.vm.isShuttingDown()) { this.poll_ref.unref(handlers.vm); @@ -541,13 +548,13 @@ pub fn NewSocket(comptime ssl: bool) type { JSC.markBinding(@src()); this.flags.handshake_complete = true; if (this.socket.isDetached()) return; - log("onHandshake {s} ({d})", .{ if (this.handlers.is_server) "S" else "C", success }); + const handlers = this.getHandlers(); + log("onHandshake {s} ({d})", .{ if (handlers.is_server) "S" else "C", success }); const authorized = if (success == 1) true else false; this.flags.authorized = authorized; - const handlers = this.handlers; var callback = handlers.onHandshake; var is_open = false; @@ -583,8 +590,8 @@ pub fn NewSocket(comptime ssl: bool) type { // clean onOpen callback so only called in the first handshake and not in every renegotiation // on servers this would require a different approach but it's not needed because our servers will not call handshake multiple times // servers don't support renegotiation - this.handlers.onOpen.unprotect(); - this.handlers.onOpen = .zero; + this.handlers.?.onOpen.unprotect(); + this.handlers.?.onOpen = .zero; } } else { // call handhsake callback with authorized and authorization error if has one @@ -607,7 +614,8 @@ pub fn NewSocket(comptime ssl: bool) type { pub fn onClose(this: *This, _: Socket, err: c_int, _: ?*anyopaque) void { JSC.markBinding(@src()); - log("onClose {s}", .{if (this.handlers.is_server) "S" else "C"}); + const handlers = this.getHandlers(); + log("onClose {s}", .{if (handlers.is_server) "S" else "C"}); this.detachNativeCallback(); this.socket.detach(); defer this.deref(); @@ -617,7 +625,6 @@ pub fn NewSocket(comptime ssl: bool) type { return; } - const handlers = this.handlers; const vm = handlers.vm; this.poll_ref.unref(vm); @@ -654,10 +661,10 @@ pub fn NewSocket(comptime ssl: bool) type { pub fn onData(this: *This, _: Socket, data: []const u8) void { JSC.markBinding(@src()); if (this.socket.isDetached()) return; - log("onData {s} ({d})", .{ if (this.handlers.is_server) "S" else "C", data.len }); + const handlers = this.getHandlers(); + log("onData {s} ({d})", .{ if (handlers.is_server) "S" else "C", data.len }); if (this.native_callback.onData(data)) return; - const handlers = this.handlers; const callback = handlers.onData; if (callback == .zero or this.flags.finalizing) return; if (handlers.vm.isShuttingDown()) { @@ -696,11 +703,13 @@ pub fn NewSocket(comptime ssl: bool) type { } pub fn getListener(this: *This, _: *JSC.JSGlobalObject) JSValue { - if (!this.handlers.is_server or this.socket.isDetached()) { + const handlers = this.getHandlers(); + + if (!handlers.is_server or this.socket.isDetached()) { return .js_undefined; } - const l: *Listener = @fieldParentPtr("handlers", this.handlers); + const l: *Listener = @fieldParentPtr("handlers", handlers); return l.strong_self.get() orelse .js_undefined; } @@ -1357,13 +1366,14 @@ pub fn NewSocket(comptime ssl: bool) type { return globalObject.throw("Expected \"socket\" option", .{}); }; - const handlers = try Handlers.fromJS(globalObject, socket_obj, this.handlers.is_server); + var prev_handlers = this.getHandlers(); + + const handlers = try Handlers.fromJS(globalObject, socket_obj, prev_handlers.is_server); - var prev_handlers = this.handlers; prev_handlers.unprotect(); - this.handlers.* = handlers; // TODO: this is a memory leak - this.handlers.withAsyncContextIfNeeded(globalObject); - this.handlers.protect(); + this.handlers.?.* = handlers; // TODO: this is a memory leak + this.handlers.?.withAsyncContextIfNeeded(globalObject); + this.handlers.?.protect(); return .js_undefined; } @@ -1405,7 +1415,7 @@ pub fn NewSocket(comptime ssl: bool) type { return .zero; } - var handlers = try Handlers.fromJS(globalObject, socket_obj, this.handlers.is_server); + var handlers = try Handlers.fromJS(globalObject, socket_obj, this.isServer()); if (globalObject.hasException()) { return .zero; @@ -1535,20 +1545,23 @@ pub fn NewSocket(comptime ssl: bool) type { const vm = handlers.vm; var raw_handlers_ptr = bun.default_allocator.create(Handlers) catch bun.outOfMemory(); - raw_handlers_ptr.* = .{ - .vm = vm, - .globalObject = globalObject, - .onOpen = this.handlers.onOpen, - .onClose = this.handlers.onClose, - .onData = this.handlers.onData, - .onWritable = this.handlers.onWritable, - .onTimeout = this.handlers.onTimeout, - .onConnectError = this.handlers.onConnectError, - .onEnd = this.handlers.onEnd, - .onError = this.handlers.onError, - .onHandshake = this.handlers.onHandshake, - .binary_type = this.handlers.binary_type, - .is_server = this.handlers.is_server, + raw_handlers_ptr.* = blk: { + const this_handlers = this.getHandlers(); + break :blk .{ + .vm = vm, + .globalObject = globalObject, + .onOpen = this_handlers.onOpen, + .onClose = this_handlers.onClose, + .onData = this_handlers.onData, + .onWritable = this_handlers.onWritable, + .onTimeout = this_handlers.onTimeout, + .onConnectError = this_handlers.onConnectError, + .onEnd = this_handlers.onEnd, + .onError = this_handlers.onError, + .onHandshake = this_handlers.onHandshake, + .binary_type = this_handlers.binary_type, + .is_server = this_handlers.is_server, + }; }; raw_handlers_ptr.protect(); @@ -1578,7 +1591,7 @@ pub fn NewSocket(comptime ssl: bool) type { tls.markActive(); // we're unrefing the original instance and refing the TLS instance - tls.poll_ref.ref(this.handlers.vm); + tls.poll_ref.ref(this.getHandlers().vm); // mark both instances on socket data if (new_socket.ext(WrappedSocket)) |ctx| { @@ -1590,7 +1603,7 @@ pub fn NewSocket(comptime ssl: bool) type { this.flags.is_active = false; // will free handlers when hits 0 active connections // the connection can be upgraded inside a handler call so we need to guarantee that it will be still alive - this.handlers.markInactive(); + this.getHandlers().markInactive(); this.has_pending_activity.store(false, .release); } diff --git a/src/bun.js/api/bun/socket/tls_socket_functions.zig b/src/bun.js/api/bun/socket/tls_socket_functions.zig index bc1443f212..2cd4d94b1c 100644 --- a/src/bun.js/api/bun/socket/tls_socket_functions.zig +++ b/src/bun.js/api/bun/socket/tls_socket_functions.zig @@ -11,7 +11,7 @@ pub fn getServername(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.Cal } pub fn setServername(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { - if (this.handlers.is_server) { + if (this.isServer()) { return globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{}); } @@ -120,7 +120,7 @@ pub fn getPeerCertificate(this: *This, globalObject: *JSC.JSGlobalObject, callfr const ssl_ptr = this.socket.ssl() orelse return .js_undefined; if (abbreviated) { - if (this.handlers.is_server) { + if (this.isServer()) { const cert = BoringSSL.SSL_get_peer_certificate(ssl_ptr); if (cert) |x509| { return X509.toJS(x509, globalObject); @@ -132,7 +132,7 @@ pub fn getPeerCertificate(this: *This, globalObject: *JSC.JSGlobalObject, callfr return X509.toJS(cert, globalObject); } var cert: ?*BoringSSL.X509 = null; - if (this.handlers.is_server) { + if (this.isServer()) { cert = BoringSSL.SSL_get_peer_certificate(ssl_ptr); } @@ -382,7 +382,7 @@ pub fn exportKeyingMaterial(this: *This, globalObject: *JSC.JSGlobalObject, call pub fn getEphemeralKeyInfo(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { // only available for clients - if (this.handlers.is_server) { + if (this.isServer()) { return JSValue.jsNull(); } var result = JSValue.createEmptyObject(globalObject, 3); @@ -555,7 +555,7 @@ pub fn setVerifyMode(this: *This, globalObject: *JSC.JSGlobalObject, callframe: const request_cert = request_cert_js.toBoolean(); const reject_unauthorized = request_cert_js.toBoolean(); var verify_mode: c_int = BoringSSL.SSL_VERIFY_NONE; - if (this.handlers.is_server) { + if (this.isServer()) { if (request_cert) { verify_mode = BoringSSL.SSL_VERIFY_PEER; if (reject_unauthorized) diff --git a/src/bun.js/node/node_net_binding.zig b/src/bun.js/node/node_net_binding.zig index 7978bd137c..b857d0b03a 100644 --- a/src/bun.js/node/node_net_binding.zig +++ b/src/bun.js/node/node_net_binding.zig @@ -84,7 +84,7 @@ pub fn newDetachedSocket(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFr .socket_context = null, .ref_count = .init(), .protos = null, - .handlers = undefined, + .handlers = null, }); return socket.getThisValue(globalThis); } else { @@ -93,7 +93,7 @@ pub fn newDetachedSocket(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFr .socket_context = null, .ref_count = .init(), .protos = null, - .handlers = undefined, + .handlers = null, }); return socket.getThisValue(globalThis); }