From f76addac437ecf6ec8436cd4a4d054c5a03c837b Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 18 Apr 2025 14:40:26 -0700 Subject: [PATCH] Remove usage of protect() & unprotect() from Listener --- src/bun.js/api/bun/SocketHandlers.zig | 200 ++++++++++++++ src/bun.js/api/bun/socket.zig | 248 ++---------------- src/bun.js/api/sockets.classes.ts | 24 ++ src/bun.js/bindings/JSRef.zig | 19 +- .../bindings/generated_classes_list.zig | 1 + src/jsc.zig | 1 + 6 files changed, 262 insertions(+), 231 deletions(-) create mode 100644 src/bun.js/api/bun/SocketHandlers.zig diff --git a/src/bun.js/api/bun/SocketHandlers.zig b/src/bun.js/api/bun/SocketHandlers.zig new file mode 100644 index 0000000000..5472bbc044 --- /dev/null +++ b/src/bun.js/api/bun/SocketHandlers.zig @@ -0,0 +1,200 @@ +binary_type: BinaryType = .Buffer, +vm: *JSC.VirtualMachine, +globalObject: *JSC.JSGlobalObject, +active_connections: u32 = 0, +is_server: bool = false, +protection_count: bun.DebugOnly(u32) = if (Environment.isDebug) 0, + +pub const js = JSC.Codegen.JSSocketHandlers; + +pub const Options = struct { + onData: JSValue = .zero, + onWritable: JSValue = .zero, + onOpen: JSValue = .zero, + onClose: JSValue = .zero, + onTimeout: JSValue = .zero, + onConnectError: JSValue = .zero, + onEnd: JSValue = .zero, + onError: JSValue = .zero, + onHandshake: JSValue = .zero, + promise: JSValue = .zero, +}; + +fn toJS(vm: *JSC.VirtualMachine, globalObject: *JSC.JSGlobalObject, is_server: bool, binary_type: BinaryType, opts: *const Options) bun.JSError!JSValue { + const handlers = bun.new(SocketHandlers, .{ + .vm = vm, + .globalObject = globalObject, + .is_server = is_server, + .binary_type = binary_type, + }); + + const as_js = js.toJS(handlers, globalObject); + if (opts.onData != .zero) js.onDataSetCached(as_js, opts.onData, globalObject); + if (opts.onWritable != .zero) js.onWritableSetCached(as_js, opts.onWritable, globalObject); + if (opts.onOpen != .zero) js.onOpenSetCached(as_js, opts.onOpen, globalObject); + if (opts.onClose != .zero) js.onCloseSetCached(as_js, opts.onClose, globalObject); + if (opts.onTimeout != .zero) js.onTimeoutSetCached(as_js, opts.onTimeout, globalObject); + if (opts.onConnectError != .zero) js.onConnectErrorSetCached(as_js, opts.onConnectError, globalObject); + if (opts.onEnd != .zero) js.onEndSetCached(as_js, opts.onEnd, globalObject); + if (opts.onError != .zero) js.onErrorSetCached(as_js, opts.onError, globalObject); + if (opts.onHandshake != .zero) js.onHandshakeSetCached(as_js, opts.onHandshake, globalObject); + if (opts.promise != .zero) js.promiseSetCached(as_js, opts.promise, globalObject); + + return as_js; +} + +pub fn markActive(this: *SocketHandlers) void { + Listener.log("markActive", .{}); + + this.active_connections += 1; +} + +pub const Scope = struct { + handlers: *SocketHandlers, + + pub fn exit(this: *Scope) void { + var vm = this.handlers.vm; + defer vm.eventLoop().exit(); + this.handlers.markInactive(); + } +}; + +pub fn enter(this: *SocketHandlers) Scope { + this.markActive(); + this.vm.eventLoop().enter(); + return .{ + .handlers = this, + }; +} + +// corker: Corker = .{}, + +fn getPromise(this_value: JSValue, globalObject: *JSC.JSGlobalObject) ?JSC.AnyPromise { + if (js.promiseGetCached(this_value)) |promise| { + js.promiseSetCached(this_value, .zero, globalObject); + return promise.asAnyPromise(); + } + + return null; +} + +pub fn resolvePromise(this: *SocketHandlers, this_value: JSValue, value: JSValue) void { + const vm = this.vm; + if (vm.isShuttingDown()) { + return; + } + + const promise = getPromise(this_value, this.globalObject) orelse return; + promise.resolve(this.globalObject, value); +} + +pub fn rejectPromise(this: *SocketHandlers, this_value: JSValue, value: JSValue) bool { + const vm = this.vm; + if (vm.isShuttingDown()) { + return true; + } + + const promise = getPromise(this_value, this.globalObject) orelse return false; + promise.reject(this.globalObject, value); + return true; +} + +pub fn markInactive(this: *SocketHandlers) void { + this.active_connections -= 1; + if (this.active_connections == 0) { + if (this.is_server) { + const listen_socket: *Listener = @fieldParentPtr("handlers", this); + // allow it to be GC'd once the last connection is closed and it's not listening anymore + if (listen_socket.listener == .none) { + listen_socket.poll_ref.unref(this.vm); + listen_socket.this_value.deinit(); + } + } + } +} + +pub fn callErrorHandler(this: *SocketHandlers, this_handler: JSValue, thisValue: JSValue, err: []const JSValue) bool { + const vm = this.vm; + if (vm.isShuttingDown()) { + return false; + } + + const globalObject = this.globalObject; + const onError = js.onErrorGetCached(this_handler) orelse return false; + + if (onError == .zero) { + if (err.len > 0) + _ = vm.uncaughtException(globalObject, err[0], false); + + return false; + } + + _ = onError.call(globalObject, thisValue, err) catch |e| + globalObject.reportActiveExceptionAsUnhandled(e); + + return true; +} + +pub fn create(globalObject: *JSC.JSGlobalObject, opts: JSValue) bun.JSError!JSValue { + var handlers = SocketHandlers{ + .vm = globalObject.bunVM(), + .globalObject = globalObject, + }; + + if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) { + return globalObject.throwInvalidArguments("Expected \"socket\" to be an object", .{}); + } + + var options = Options{}; + + const pairs = .{ + .{ "onData", "data" }, + .{ "onWritable", "drain" }, + .{ "onOpen", "open" }, + .{ "onClose", "close" }, + .{ "onTimeout", "timeout" }, + .{ "onConnectError", "connectError" }, + .{ "onEnd", "end" }, + .{ "onError", "error" }, + .{ "onHandshake", "handshake" }, + }; + inline for (pairs) |pair| { + if (try opts.getTruthyComptime(globalObject, pair.@"1")) |callback_value| { + if (!callback_value.isCell() or !callback_value.isCallable()) { + return globalObject.throwInvalidArguments("Expected \"{s}\" callback to be a function", .{pair[1]}); + } + + @field(options, pair.@"0") = callback_value; + } + } + + if (options.onData == .zero and options.onWritable == .zero) { + return globalObject.throwInvalidArguments("Expected at least \"data\" or \"drain\" callback", .{}); + } + + if (try opts.getTruthy(globalObject, "binaryType")) |binary_type_value| { + if (!binary_type_value.isString()) { + return globalObject.throwInvalidArguments("Expected \"binaryType\" to be a string", .{}); + } + + handlers.binary_type = try BinaryType.fromJSValue(globalObject, binary_type_value) orelse { + return globalObject.throwInvalidArguments("Expected 'binaryType' to be 'ArrayBuffer', 'Uint8Array', or 'Buffer'", .{}); + }; + } + + return toJS(globalObject.bunVM(), globalObject, false, handlers.binary_type, &options); +} + +pub fn finalize(this: *SocketHandlers) void { + bun.destroy(this); +} + +const bun = @import("bun"); +const JSC = bun.JSC; +const BinaryType = JSC.BinaryType; + +const Environment = bun.Environment; +const Listener = JSC.API.Listener; +const JSValue = JSC.JSValue; + +const SocketHandlers = @This(); diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 82b5602534..58b3d98626 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -112,205 +112,12 @@ const WrappedType = enum { tls, tcp, }; -const Handlers = struct { - onOpen: JSC.JSValue = .zero, - onClose: JSC.JSValue = .zero, - onData: JSC.JSValue = .zero, - onWritable: JSC.JSValue = .zero, - onTimeout: JSC.JSValue = .zero, - onConnectError: JSC.JSValue = .zero, - onEnd: JSC.JSValue = .zero, - onError: JSC.JSValue = .zero, - onHandshake: JSC.JSValue = .zero, - - binary_type: BinaryType = .Buffer, - - vm: *JSC.VirtualMachine, - globalObject: *JSC.JSGlobalObject, - active_connections: u32 = 0, - is_server: bool = false, - promise: JSC.Strong = .empty, - - protection_count: bun.DebugOnly(u32) = if (Environment.isDebug) 0, - - pub fn markActive(this: *Handlers) void { - Listener.log("markActive", .{}); - - this.active_connections += 1; - } - - pub const Scope = struct { - handlers: *Handlers, - - pub fn exit(this: *Scope) void { - var vm = this.handlers.vm; - defer vm.eventLoop().exit(); - this.handlers.markInactive(); - } - }; - - pub fn enter(this: *Handlers) Scope { - this.markActive(); - this.vm.eventLoop().enter(); - return .{ - .handlers = this, - }; - } - - // corker: Corker = .{}, - - pub fn resolvePromise(this: *Handlers, value: JSValue) void { - const vm = this.vm; - if (vm.isShuttingDown()) { - return; - } - - const promise = this.promise.trySwap() orelse return; - const anyPromise = promise.asAnyPromise() orelse return; - anyPromise.resolve(this.globalObject, value); - } - - pub fn rejectPromise(this: *Handlers, value: JSValue) bool { - const vm = this.vm; - if (vm.isShuttingDown()) { - return true; - } - - const promise = this.promise.trySwap() orelse return false; - const anyPromise = promise.asAnyPromise() orelse return false; - anyPromise.reject(this.globalObject, value); - return true; - } - - pub fn markInactive(this: *Handlers) void { - Listener.log("markInactive", .{}); - this.active_connections -= 1; - if (this.active_connections == 0) { - if (this.is_server) { - const listen_socket: *Listener = @fieldParentPtr("handlers", this); - // allow it to be GC'd once the last connection is closed and it's not listening anymore - if (listen_socket.listener == .none) { - listen_socket.poll_ref.unref(this.vm); - listen_socket.strong_self.deinit(); - } - } else { - this.unprotect(); - bun.default_allocator.destroy(this); - } - } - } - - pub fn callErrorHandler(this: *Handlers, thisValue: JSValue, err: []const JSValue) bool { - const vm = this.vm; - if (vm.isShuttingDown()) { - return false; - } - - const globalObject = this.globalObject; - const onError = this.onError; - - if (onError == .zero) { - if (err.len > 0) - _ = vm.uncaughtException(globalObject, err[0], false); - - return false; - } - - _ = onError.call(globalObject, thisValue, err) catch |e| - globalObject.reportActiveExceptionAsUnhandled(e); - - return true; - } - - pub fn fromJS(globalObject: *JSC.JSGlobalObject, opts: JSC.JSValue) bun.JSError!Handlers { - var handlers = Handlers{ - .vm = globalObject.bunVM(), - .globalObject = globalObject, - }; - - if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) { - return globalObject.throwInvalidArguments("Expected \"socket\" to be an object", .{}); - } - - const pairs = .{ - .{ "onData", "data" }, - .{ "onWritable", "drain" }, - .{ "onOpen", "open" }, - .{ "onClose", "close" }, - .{ "onTimeout", "timeout" }, - .{ "onConnectError", "connectError" }, - .{ "onEnd", "end" }, - .{ "onError", "error" }, - .{ "onHandshake", "handshake" }, - }; - inline for (pairs) |pair| { - if (try opts.getTruthyComptime(globalObject, pair.@"1")) |callback_value| { - if (!callback_value.isCell() or !callback_value.isCallable()) { - return globalObject.throwInvalidArguments("Expected \"{s}\" callback to be a function", .{pair[1]}); - } - - @field(handlers, pair.@"0") = callback_value; - } - } - - if (handlers.onData == .zero and handlers.onWritable == .zero) { - return globalObject.throwInvalidArguments("Expected at least \"data\" or \"drain\" callback", .{}); - } - - if (try opts.getTruthy(globalObject, "binaryType")) |binary_type_value| { - if (!binary_type_value.isString()) { - return globalObject.throwInvalidArguments("Expected \"binaryType\" to be a string", .{}); - } - - handlers.binary_type = try BinaryType.fromJSValue(globalObject, binary_type_value) orelse { - return globalObject.throwInvalidArguments("Expected 'binaryType' to be 'ArrayBuffer', 'Uint8Array', or 'Buffer'", .{}); - }; - } - - return handlers; - } - - pub fn unprotect(this: *Handlers) void { - if (this.vm.isShuttingDown()) { - return; - } - - if (comptime Environment.isDebug) { - bun.assert(this.protection_count > 0); - this.protection_count -= 1; - } - this.onOpen.unprotect(); - this.onClose.unprotect(); - this.onData.unprotect(); - this.onWritable.unprotect(); - this.onTimeout.unprotect(); - this.onConnectError.unprotect(); - this.onEnd.unprotect(); - this.onError.unprotect(); - this.onHandshake.unprotect(); - } - - pub fn protect(this: *Handlers) void { - if (comptime Environment.isDebug) { - this.protection_count += 1; - } - this.onOpen.protect(); - this.onClose.protect(); - this.onData.protect(); - this.onWritable.protect(); - this.onTimeout.protect(); - this.onConnectError.protect(); - this.onEnd.protect(); - this.onError.protect(); - this.onHandshake.protect(); - } -}; pub const SocketConfig = struct { hostname_or_unix: JSC.ZigString.Slice, port: ?u16 = null, ssl: ?JSC.API.ServerConfig.SSLConfig = null, - handlers: Handlers, + handlers_jsvalue: JSC.JSValue = .zero, default_data: JSC.JSValue = .zero, exclusive: bool = false, allowHalfOpen: bool = false, @@ -455,19 +262,18 @@ pub const SocketConfig = struct { return globalObject.throwInvalidArguments("Expected either \"hostname\" or \"unix\"", .{}); } - var handlers = try Handlers.fromJS(globalObject, try opts.get(globalObject, "socket") orelse JSValue.zero); + const handlers = try SocketHandlers.create(globalObject, try opts.get(globalObject, "socket") orelse JSValue.zero); + defer handlers.ensureStillAlive(); if (opts.fastGet(globalObject, .data)) |default_data_value| { default_data = default_data_value; } - handlers.protect(); - return SocketConfig{ .hostname_or_unix = hostname_or_unix, .port = port, .ssl = ssl, - .handlers = handlers, + .handlers_jsvalue = handlers, .default_data = default_data, .exclusive = exclusive, .allowHalfOpen = allowHalfOpen, @@ -512,7 +318,6 @@ fn normalizePipeName(pipe_name: []const u8, buffer: []u8) ?[]const u8 { pub const Listener = struct { pub const log = Output.scoped(.Listener, false); - handlers: Handlers, listener: ListenerType = .none, poll_ref: Async.KeepAlive = Async.KeepAlive.init(), @@ -522,7 +327,7 @@ pub const Listener = struct { protos: ?[]const u8 = null, strong_data: JSC.Strong = .empty, - strong_self: JSC.Strong = .empty, + this_value: JSC.JSRef = .empty, pub const js = JSC.Codegen.JSListener; pub const toJS = js.toJS; @@ -628,7 +433,9 @@ pub const Listener = struct { var hostname_or_unix = socket_config.hostname_or_unix; const port = socket_config.port; var ssl = socket_config.ssl; - var handlers = socket_config.handlers; + const handlers_jsvalue = socket_config.handlers_jsvalue; + defer handlers_jsvalue.ensureStillAlive(); + const handlers = handlers_jsvalue.as(SocketHandlers) orelse return globalObject.throwInvalidArguments("Expected \"socket\" object", .{}); var protos: ?[]const u8 = null; handlers.is_server = true; @@ -650,7 +457,6 @@ pub const Listener = struct { } } var socket = Listener{ - .handlers = handlers, .connection = connection, .ssl = ssl_enabled, .socket_context = null, @@ -660,8 +466,6 @@ pub const Listener = struct { vm.eventLoop().ensureWaker(); - socket.handlers.protect(); - if (socket_config.default_data != .zero) { socket.strong_data = JSC.Strong.create(socket_config.default_data, globalObject); } @@ -679,9 +483,10 @@ pub const Listener = struct { }; const this_value = this.toJS(globalObject); - this.strong_self.set(globalObject, this_value); + this.this_value.setStrong(globalObject, this_value); this.poll_ref.ref(handlers.vm); + js.handlersSetCached(this_value, globalObject, handlers_jsvalue); return this_value; } } @@ -703,7 +508,6 @@ pub const Listener = struct { ) orelse { var err = globalObject.createErrorInstance("Failed to listen on {s}:{d}", .{ hostname_or_unix.slice(), port orelse 0 }); defer { - socket_config.handlers.unprotect(); hostname_or_unix.deinit(); } @@ -821,8 +625,6 @@ pub const Listener = struct { .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch bun.outOfMemory()) else null, }; - socket.handlers.protect(); - if (socket_config.default_data != .zero) { socket.strong_data = JSC.Strong.create(socket_config.default_data, globalObject); } @@ -840,8 +642,9 @@ pub const Listener = struct { this.socket_context.?.ext(ssl_enabled, *Listener).?.* = this; const this_value = this.toJS(globalObject); - this.strong_self.set(globalObject, this_value); + this.this_value.setStrong(globalObject, this_value); this.poll_ref.ref(handlers.vm); + js.handlersSetCached(this_value, globalObject, handlers_jsvalue); return this_value; } @@ -864,7 +667,6 @@ pub const Listener = struct { var this_socket = Socket.new(.{ .ref_count = .init(), - .handlers = &listener.handlers, .this_value = .zero, // here we start with a detached socket and attach it later after accept .socket = Socket.Socket.detached, @@ -877,6 +679,8 @@ pub const Listener = struct { const globalObject = listener.handlers.globalObject; Socket.js.dataSetCached(this_socket.getThisValue(globalObject), globalObject, default_data); } + const this_value = this_socket.getThisValue(globalObject); + js.handlersSetCached(listener.this(), globalObject, js.handlersGetCached(this_value).?); return this_socket; } @@ -965,7 +769,7 @@ pub const Listener = struct { this.socket_context = null; ctx.deinit(this.ssl); } - this.strong_self.clearWithoutDeallocation(); + this.this_value.downgrade(); this.strong_data.clearWithoutDeallocation(); } else { if (force_close) { @@ -998,22 +802,10 @@ pub const Listener = struct { pub fn deinit(this: *Listener) void { log("deinit", .{}); - this.strong_self.deinit(); + this.this_value.deinit(); this.strong_data.deinit(); this.poll_ref.unref(this.handlers.vm); bun.assert(this.listener == .none); - this.handlers.unprotect(); - - if (this.handlers.active_connections > 0) { - if (this.socket_context) |ctx| { - ctx.close(this.ssl); - } - // TODO: fix this leak. - } else { - if (this.socket_context) |ctx| { - ctx.deinit(this.ssl); - } - } this.connection.deinit(); if (this.protos) |protos| { @@ -1053,14 +845,14 @@ pub const Listener = struct { const this_value = callframe.this(); if (this.listener == .none) return JSValue.jsUndefined(); this.poll_ref.ref(globalObject.bunVM()); - this.strong_self.set(globalObject, this_value); + this.this_value.setStrong(globalObject, this_value); return JSValue.jsUndefined(); } pub fn unref(this: *Listener, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { this.poll_ref.unref(globalObject.bunVM()); if (this.handlers.active_connections == 0) { - this.strong_self.clearWithoutDeallocation(); + this.this_value.downgrade(); } return JSValue.jsUndefined(); } @@ -2030,7 +1822,7 @@ fn NewSocket(comptime ssl: bool) type { } const l: *Listener = @fieldParentPtr("handlers", this.handlers); - return l.strong_self.get() orelse JSValue.jsUndefined(); + return l.this_value.get() orelse JSValue.jsUndefined(); } pub fn getReadyState( @@ -4512,3 +4304,5 @@ pub fn jsCreateSocketPair(global: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JS array.putIndex(global, 1, JSC.jsNumber(fds_[1])); return array; } + +pub const SocketHandlers = @import("./SocketHandlers.zig"); diff --git a/src/bun.js/api/sockets.classes.ts b/src/bun.js/api/sockets.classes.ts index b85d79176e..4646ebca0e 100644 --- a/src/bun.js/api/sockets.classes.ts +++ b/src/bun.js/api/sockets.classes.ts @@ -221,6 +221,7 @@ function generate(ssl) { finalize: true, construct: true, klass: {}, + values: ["handlers"], }); } const sslOnly = { @@ -290,6 +291,7 @@ export default [ finalize: true, construct: true, klass: {}, + values: ["handlers"], }), define({ @@ -448,4 +450,26 @@ export default [ }, }, }), + define({ + name: "SocketHandlers", + construct: false, + call: false, + noConstructor: true, + finalize: true, + proto: {}, + klass: {}, + + values: [ + "onData", + "onWritable", + "onOpen", + "onClose", + "onTimeout", + "onConnectError", + "onEnd", + "onError", + "onHandshake", + "promise", + ], + }), ]; diff --git a/src/bun.js/bindings/JSRef.zig b/src/bun.js/bindings/JSRef.zig index 0e49d74fdd..54291cf4f4 100644 --- a/src/bun.js/bindings/JSRef.zig +++ b/src/bun.js/bindings/JSRef.zig @@ -11,10 +11,6 @@ pub const JSRef = union(enum) { return .{ .strong = JSC.Strong.create(value, globalThis) }; } - pub fn empty() @This() { - return .{ .weak = .zero }; - } - pub fn get(this: *@This()) JSC.JSValue { return switch (this.*) { .weak => this.weak, @@ -75,6 +71,21 @@ pub const JSRef = union(enum) { .finalized => {}, } } + + pub fn downgrade(this: *@This()) void { + switch (this.*) { + .weak => {}, + .strong => { + const value = this.strong.get() orelse { + this.* = .{ .weak = .zero }; + return; + }; + this.strong.deinit(); + this.* = .{ .weak = value }; + }, + .finalized => {}, + } + } }; const JSC = bun.JSC; diff --git a/src/bun.js/bindings/generated_classes_list.zig b/src/bun.js/bindings/generated_classes_list.zig index 1d0cc94c9c..9a66f391cf 100644 --- a/src/bun.js/bindings/generated_classes_list.zig +++ b/src/bun.js/bindings/generated_classes_list.zig @@ -83,4 +83,5 @@ pub const Classes = struct { pub const S3Stat = JSC.WebCore.S3Stat; pub const HTMLBundle = JSC.API.HTMLBundle; pub const RedisClient = JSC.API.Valkey; + pub const SocketHandlers = JSC.API.SocketHandlers; }; diff --git a/src/jsc.zig b/src/jsc.zig index 0ce27f5860..faec6f7e90 100644 --- a/src/jsc.zig +++ b/src/jsc.zig @@ -59,6 +59,7 @@ pub const API = struct { pub const NativeBrotli = @import("./bun.js/node/node_zlib_binding.zig").SNativeBrotli; pub const HTMLBundle = @import("./bun.js/api/server/HTMLBundle.zig"); pub const Valkey = @import("./valkey/js_valkey.zig").JSValkeyClient; + pub const SocketHandlers = @import("./bun.js/api/bun/SocketHandlers.zig"); }; pub const Postgres = @import("./sql/postgres.zig"); pub const DNS = @import("./bun.js/api/bun/dns_resolver.zig");