Remove usage of protect() & unprotect() from Listener

This commit is contained in:
Jarred Sumner
2025-04-18 14:40:26 -07:00
parent 7d7512076b
commit f76addac43
6 changed files with 262 additions and 231 deletions

View File

@@ -0,0 +1,200 @@
binary_type: BinaryType = .Buffer,
vm: *JSC.VirtualMachine,
globalObject: *JSC.JSGlobalObject,
active_connections: u32 = 0,
is_server: bool = false,
protection_count: bun.DebugOnly(u32) = if (Environment.isDebug) 0,
pub const js = JSC.Codegen.JSSocketHandlers;
pub const Options = struct {
onData: JSValue = .zero,
onWritable: JSValue = .zero,
onOpen: JSValue = .zero,
onClose: JSValue = .zero,
onTimeout: JSValue = .zero,
onConnectError: JSValue = .zero,
onEnd: JSValue = .zero,
onError: JSValue = .zero,
onHandshake: JSValue = .zero,
promise: JSValue = .zero,
};
fn toJS(vm: *JSC.VirtualMachine, globalObject: *JSC.JSGlobalObject, is_server: bool, binary_type: BinaryType, opts: *const Options) bun.JSError!JSValue {
const handlers = bun.new(SocketHandlers, .{
.vm = vm,
.globalObject = globalObject,
.is_server = is_server,
.binary_type = binary_type,
});
const as_js = js.toJS(handlers, globalObject);
if (opts.onData != .zero) js.onDataSetCached(as_js, opts.onData, globalObject);
if (opts.onWritable != .zero) js.onWritableSetCached(as_js, opts.onWritable, globalObject);
if (opts.onOpen != .zero) js.onOpenSetCached(as_js, opts.onOpen, globalObject);
if (opts.onClose != .zero) js.onCloseSetCached(as_js, opts.onClose, globalObject);
if (opts.onTimeout != .zero) js.onTimeoutSetCached(as_js, opts.onTimeout, globalObject);
if (opts.onConnectError != .zero) js.onConnectErrorSetCached(as_js, opts.onConnectError, globalObject);
if (opts.onEnd != .zero) js.onEndSetCached(as_js, opts.onEnd, globalObject);
if (opts.onError != .zero) js.onErrorSetCached(as_js, opts.onError, globalObject);
if (opts.onHandshake != .zero) js.onHandshakeSetCached(as_js, opts.onHandshake, globalObject);
if (opts.promise != .zero) js.promiseSetCached(as_js, opts.promise, globalObject);
return as_js;
}
pub fn markActive(this: *SocketHandlers) void {
Listener.log("markActive", .{});
this.active_connections += 1;
}
pub const Scope = struct {
handlers: *SocketHandlers,
pub fn exit(this: *Scope) void {
var vm = this.handlers.vm;
defer vm.eventLoop().exit();
this.handlers.markInactive();
}
};
pub fn enter(this: *SocketHandlers) Scope {
this.markActive();
this.vm.eventLoop().enter();
return .{
.handlers = this,
};
}
// corker: Corker = .{},
fn getPromise(this_value: JSValue, globalObject: *JSC.JSGlobalObject) ?JSC.AnyPromise {
if (js.promiseGetCached(this_value)) |promise| {
js.promiseSetCached(this_value, .zero, globalObject);
return promise.asAnyPromise();
}
return null;
}
pub fn resolvePromise(this: *SocketHandlers, this_value: JSValue, value: JSValue) void {
const vm = this.vm;
if (vm.isShuttingDown()) {
return;
}
const promise = getPromise(this_value, this.globalObject) orelse return;
promise.resolve(this.globalObject, value);
}
pub fn rejectPromise(this: *SocketHandlers, this_value: JSValue, value: JSValue) bool {
const vm = this.vm;
if (vm.isShuttingDown()) {
return true;
}
const promise = getPromise(this_value, this.globalObject) orelse return false;
promise.reject(this.globalObject, value);
return true;
}
pub fn markInactive(this: *SocketHandlers) void {
this.active_connections -= 1;
if (this.active_connections == 0) {
if (this.is_server) {
const listen_socket: *Listener = @fieldParentPtr("handlers", this);
// allow it to be GC'd once the last connection is closed and it's not listening anymore
if (listen_socket.listener == .none) {
listen_socket.poll_ref.unref(this.vm);
listen_socket.this_value.deinit();
}
}
}
}
pub fn callErrorHandler(this: *SocketHandlers, this_handler: JSValue, thisValue: JSValue, err: []const JSValue) bool {
const vm = this.vm;
if (vm.isShuttingDown()) {
return false;
}
const globalObject = this.globalObject;
const onError = js.onErrorGetCached(this_handler) orelse return false;
if (onError == .zero) {
if (err.len > 0)
_ = vm.uncaughtException(globalObject, err[0], false);
return false;
}
_ = onError.call(globalObject, thisValue, err) catch |e|
globalObject.reportActiveExceptionAsUnhandled(e);
return true;
}
pub fn create(globalObject: *JSC.JSGlobalObject, opts: JSValue) bun.JSError!JSValue {
var handlers = SocketHandlers{
.vm = globalObject.bunVM(),
.globalObject = globalObject,
};
if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) {
return globalObject.throwInvalidArguments("Expected \"socket\" to be an object", .{});
}
var options = Options{};
const pairs = .{
.{ "onData", "data" },
.{ "onWritable", "drain" },
.{ "onOpen", "open" },
.{ "onClose", "close" },
.{ "onTimeout", "timeout" },
.{ "onConnectError", "connectError" },
.{ "onEnd", "end" },
.{ "onError", "error" },
.{ "onHandshake", "handshake" },
};
inline for (pairs) |pair| {
if (try opts.getTruthyComptime(globalObject, pair.@"1")) |callback_value| {
if (!callback_value.isCell() or !callback_value.isCallable()) {
return globalObject.throwInvalidArguments("Expected \"{s}\" callback to be a function", .{pair[1]});
}
@field(options, pair.@"0") = callback_value;
}
}
if (options.onData == .zero and options.onWritable == .zero) {
return globalObject.throwInvalidArguments("Expected at least \"data\" or \"drain\" callback", .{});
}
if (try opts.getTruthy(globalObject, "binaryType")) |binary_type_value| {
if (!binary_type_value.isString()) {
return globalObject.throwInvalidArguments("Expected \"binaryType\" to be a string", .{});
}
handlers.binary_type = try BinaryType.fromJSValue(globalObject, binary_type_value) orelse {
return globalObject.throwInvalidArguments("Expected 'binaryType' to be 'ArrayBuffer', 'Uint8Array', or 'Buffer'", .{});
};
}
return toJS(globalObject.bunVM(), globalObject, false, handlers.binary_type, &options);
}
pub fn finalize(this: *SocketHandlers) void {
bun.destroy(this);
}
const bun = @import("bun");
const JSC = bun.JSC;
const BinaryType = JSC.BinaryType;
const Environment = bun.Environment;
const Listener = JSC.API.Listener;
const JSValue = JSC.JSValue;
const SocketHandlers = @This();

View File

@@ -112,205 +112,12 @@ const WrappedType = enum {
tls,
tcp,
};
const Handlers = struct {
onOpen: JSC.JSValue = .zero,
onClose: JSC.JSValue = .zero,
onData: JSC.JSValue = .zero,
onWritable: JSC.JSValue = .zero,
onTimeout: JSC.JSValue = .zero,
onConnectError: JSC.JSValue = .zero,
onEnd: JSC.JSValue = .zero,
onError: JSC.JSValue = .zero,
onHandshake: JSC.JSValue = .zero,
binary_type: BinaryType = .Buffer,
vm: *JSC.VirtualMachine,
globalObject: *JSC.JSGlobalObject,
active_connections: u32 = 0,
is_server: bool = false,
promise: JSC.Strong = .empty,
protection_count: bun.DebugOnly(u32) = if (Environment.isDebug) 0,
pub fn markActive(this: *Handlers) void {
Listener.log("markActive", .{});
this.active_connections += 1;
}
pub const Scope = struct {
handlers: *Handlers,
pub fn exit(this: *Scope) void {
var vm = this.handlers.vm;
defer vm.eventLoop().exit();
this.handlers.markInactive();
}
};
pub fn enter(this: *Handlers) Scope {
this.markActive();
this.vm.eventLoop().enter();
return .{
.handlers = this,
};
}
// corker: Corker = .{},
pub fn resolvePromise(this: *Handlers, value: JSValue) void {
const vm = this.vm;
if (vm.isShuttingDown()) {
return;
}
const promise = this.promise.trySwap() orelse return;
const anyPromise = promise.asAnyPromise() orelse return;
anyPromise.resolve(this.globalObject, value);
}
pub fn rejectPromise(this: *Handlers, value: JSValue) bool {
const vm = this.vm;
if (vm.isShuttingDown()) {
return true;
}
const promise = this.promise.trySwap() orelse return false;
const anyPromise = promise.asAnyPromise() orelse return false;
anyPromise.reject(this.globalObject, value);
return true;
}
pub fn markInactive(this: *Handlers) void {
Listener.log("markInactive", .{});
this.active_connections -= 1;
if (this.active_connections == 0) {
if (this.is_server) {
const listen_socket: *Listener = @fieldParentPtr("handlers", this);
// allow it to be GC'd once the last connection is closed and it's not listening anymore
if (listen_socket.listener == .none) {
listen_socket.poll_ref.unref(this.vm);
listen_socket.strong_self.deinit();
}
} else {
this.unprotect();
bun.default_allocator.destroy(this);
}
}
}
pub fn callErrorHandler(this: *Handlers, thisValue: JSValue, err: []const JSValue) bool {
const vm = this.vm;
if (vm.isShuttingDown()) {
return false;
}
const globalObject = this.globalObject;
const onError = this.onError;
if (onError == .zero) {
if (err.len > 0)
_ = vm.uncaughtException(globalObject, err[0], false);
return false;
}
_ = onError.call(globalObject, thisValue, err) catch |e|
globalObject.reportActiveExceptionAsUnhandled(e);
return true;
}
pub fn fromJS(globalObject: *JSC.JSGlobalObject, opts: JSC.JSValue) bun.JSError!Handlers {
var handlers = Handlers{
.vm = globalObject.bunVM(),
.globalObject = globalObject,
};
if (opts.isEmptyOrUndefinedOrNull() or opts.isBoolean() or !opts.isObject()) {
return globalObject.throwInvalidArguments("Expected \"socket\" to be an object", .{});
}
const pairs = .{
.{ "onData", "data" },
.{ "onWritable", "drain" },
.{ "onOpen", "open" },
.{ "onClose", "close" },
.{ "onTimeout", "timeout" },
.{ "onConnectError", "connectError" },
.{ "onEnd", "end" },
.{ "onError", "error" },
.{ "onHandshake", "handshake" },
};
inline for (pairs) |pair| {
if (try opts.getTruthyComptime(globalObject, pair.@"1")) |callback_value| {
if (!callback_value.isCell() or !callback_value.isCallable()) {
return globalObject.throwInvalidArguments("Expected \"{s}\" callback to be a function", .{pair[1]});
}
@field(handlers, pair.@"0") = callback_value;
}
}
if (handlers.onData == .zero and handlers.onWritable == .zero) {
return globalObject.throwInvalidArguments("Expected at least \"data\" or \"drain\" callback", .{});
}
if (try opts.getTruthy(globalObject, "binaryType")) |binary_type_value| {
if (!binary_type_value.isString()) {
return globalObject.throwInvalidArguments("Expected \"binaryType\" to be a string", .{});
}
handlers.binary_type = try BinaryType.fromJSValue(globalObject, binary_type_value) orelse {
return globalObject.throwInvalidArguments("Expected 'binaryType' to be 'ArrayBuffer', 'Uint8Array', or 'Buffer'", .{});
};
}
return handlers;
}
pub fn unprotect(this: *Handlers) void {
if (this.vm.isShuttingDown()) {
return;
}
if (comptime Environment.isDebug) {
bun.assert(this.protection_count > 0);
this.protection_count -= 1;
}
this.onOpen.unprotect();
this.onClose.unprotect();
this.onData.unprotect();
this.onWritable.unprotect();
this.onTimeout.unprotect();
this.onConnectError.unprotect();
this.onEnd.unprotect();
this.onError.unprotect();
this.onHandshake.unprotect();
}
pub fn protect(this: *Handlers) void {
if (comptime Environment.isDebug) {
this.protection_count += 1;
}
this.onOpen.protect();
this.onClose.protect();
this.onData.protect();
this.onWritable.protect();
this.onTimeout.protect();
this.onConnectError.protect();
this.onEnd.protect();
this.onError.protect();
this.onHandshake.protect();
}
};
pub const SocketConfig = struct {
hostname_or_unix: JSC.ZigString.Slice,
port: ?u16 = null,
ssl: ?JSC.API.ServerConfig.SSLConfig = null,
handlers: Handlers,
handlers_jsvalue: JSC.JSValue = .zero,
default_data: JSC.JSValue = .zero,
exclusive: bool = false,
allowHalfOpen: bool = false,
@@ -455,19 +262,18 @@ pub const SocketConfig = struct {
return globalObject.throwInvalidArguments("Expected either \"hostname\" or \"unix\"", .{});
}
var handlers = try Handlers.fromJS(globalObject, try opts.get(globalObject, "socket") orelse JSValue.zero);
const handlers = try SocketHandlers.create(globalObject, try opts.get(globalObject, "socket") orelse JSValue.zero);
defer handlers.ensureStillAlive();
if (opts.fastGet(globalObject, .data)) |default_data_value| {
default_data = default_data_value;
}
handlers.protect();
return SocketConfig{
.hostname_or_unix = hostname_or_unix,
.port = port,
.ssl = ssl,
.handlers = handlers,
.handlers_jsvalue = handlers,
.default_data = default_data,
.exclusive = exclusive,
.allowHalfOpen = allowHalfOpen,
@@ -512,7 +318,6 @@ fn normalizePipeName(pipe_name: []const u8, buffer: []u8) ?[]const u8 {
pub const Listener = struct {
pub const log = Output.scoped(.Listener, false);
handlers: Handlers,
listener: ListenerType = .none,
poll_ref: Async.KeepAlive = Async.KeepAlive.init(),
@@ -522,7 +327,7 @@ pub const Listener = struct {
protos: ?[]const u8 = null,
strong_data: JSC.Strong = .empty,
strong_self: JSC.Strong = .empty,
this_value: JSC.JSRef = .empty,
pub const js = JSC.Codegen.JSListener;
pub const toJS = js.toJS;
@@ -628,7 +433,9 @@ pub const Listener = struct {
var hostname_or_unix = socket_config.hostname_or_unix;
const port = socket_config.port;
var ssl = socket_config.ssl;
var handlers = socket_config.handlers;
const handlers_jsvalue = socket_config.handlers_jsvalue;
defer handlers_jsvalue.ensureStillAlive();
const handlers = handlers_jsvalue.as(SocketHandlers) orelse return globalObject.throwInvalidArguments("Expected \"socket\" object", .{});
var protos: ?[]const u8 = null;
handlers.is_server = true;
@@ -650,7 +457,6 @@ pub const Listener = struct {
}
}
var socket = Listener{
.handlers = handlers,
.connection = connection,
.ssl = ssl_enabled,
.socket_context = null,
@@ -660,8 +466,6 @@ pub const Listener = struct {
vm.eventLoop().ensureWaker();
socket.handlers.protect();
if (socket_config.default_data != .zero) {
socket.strong_data = JSC.Strong.create(socket_config.default_data, globalObject);
}
@@ -679,9 +483,10 @@ pub const Listener = struct {
};
const this_value = this.toJS(globalObject);
this.strong_self.set(globalObject, this_value);
this.this_value.setStrong(globalObject, this_value);
this.poll_ref.ref(handlers.vm);
js.handlersSetCached(this_value, globalObject, handlers_jsvalue);
return this_value;
}
}
@@ -703,7 +508,6 @@ pub const Listener = struct {
) orelse {
var err = globalObject.createErrorInstance("Failed to listen on {s}:{d}", .{ hostname_or_unix.slice(), port orelse 0 });
defer {
socket_config.handlers.unprotect();
hostname_or_unix.deinit();
}
@@ -821,8 +625,6 @@ pub const Listener = struct {
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch bun.outOfMemory()) else null,
};
socket.handlers.protect();
if (socket_config.default_data != .zero) {
socket.strong_data = JSC.Strong.create(socket_config.default_data, globalObject);
}
@@ -840,8 +642,9 @@ pub const Listener = struct {
this.socket_context.?.ext(ssl_enabled, *Listener).?.* = this;
const this_value = this.toJS(globalObject);
this.strong_self.set(globalObject, this_value);
this.this_value.setStrong(globalObject, this_value);
this.poll_ref.ref(handlers.vm);
js.handlersSetCached(this_value, globalObject, handlers_jsvalue);
return this_value;
}
@@ -864,7 +667,6 @@ pub const Listener = struct {
var this_socket = Socket.new(.{
.ref_count = .init(),
.handlers = &listener.handlers,
.this_value = .zero,
// here we start with a detached socket and attach it later after accept
.socket = Socket.Socket.detached,
@@ -877,6 +679,8 @@ pub const Listener = struct {
const globalObject = listener.handlers.globalObject;
Socket.js.dataSetCached(this_socket.getThisValue(globalObject), globalObject, default_data);
}
const this_value = this_socket.getThisValue(globalObject);
js.handlersSetCached(listener.this(), globalObject, js.handlersGetCached(this_value).?);
return this_socket;
}
@@ -965,7 +769,7 @@ pub const Listener = struct {
this.socket_context = null;
ctx.deinit(this.ssl);
}
this.strong_self.clearWithoutDeallocation();
this.this_value.downgrade();
this.strong_data.clearWithoutDeallocation();
} else {
if (force_close) {
@@ -998,22 +802,10 @@ pub const Listener = struct {
pub fn deinit(this: *Listener) void {
log("deinit", .{});
this.strong_self.deinit();
this.this_value.deinit();
this.strong_data.deinit();
this.poll_ref.unref(this.handlers.vm);
bun.assert(this.listener == .none);
this.handlers.unprotect();
if (this.handlers.active_connections > 0) {
if (this.socket_context) |ctx| {
ctx.close(this.ssl);
}
// TODO: fix this leak.
} else {
if (this.socket_context) |ctx| {
ctx.deinit(this.ssl);
}
}
this.connection.deinit();
if (this.protos) |protos| {
@@ -1053,14 +845,14 @@ pub const Listener = struct {
const this_value = callframe.this();
if (this.listener == .none) return JSValue.jsUndefined();
this.poll_ref.ref(globalObject.bunVM());
this.strong_self.set(globalObject, this_value);
this.this_value.setStrong(globalObject, this_value);
return JSValue.jsUndefined();
}
pub fn unref(this: *Listener, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue {
this.poll_ref.unref(globalObject.bunVM());
if (this.handlers.active_connections == 0) {
this.strong_self.clearWithoutDeallocation();
this.this_value.downgrade();
}
return JSValue.jsUndefined();
}
@@ -2030,7 +1822,7 @@ fn NewSocket(comptime ssl: bool) type {
}
const l: *Listener = @fieldParentPtr("handlers", this.handlers);
return l.strong_self.get() orelse JSValue.jsUndefined();
return l.this_value.get() orelse JSValue.jsUndefined();
}
pub fn getReadyState(
@@ -4512,3 +4304,5 @@ pub fn jsCreateSocketPair(global: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JS
array.putIndex(global, 1, JSC.jsNumber(fds_[1]));
return array;
}
pub const SocketHandlers = @import("./SocketHandlers.zig");

View File

@@ -221,6 +221,7 @@ function generate(ssl) {
finalize: true,
construct: true,
klass: {},
values: ["handlers"],
});
}
const sslOnly = {
@@ -290,6 +291,7 @@ export default [
finalize: true,
construct: true,
klass: {},
values: ["handlers"],
}),
define({
@@ -448,4 +450,26 @@ export default [
},
},
}),
define({
name: "SocketHandlers",
construct: false,
call: false,
noConstructor: true,
finalize: true,
proto: {},
klass: {},
values: [
"onData",
"onWritable",
"onOpen",
"onClose",
"onTimeout",
"onConnectError",
"onEnd",
"onError",
"onHandshake",
"promise",
],
}),
];

View File

@@ -11,10 +11,6 @@ pub const JSRef = union(enum) {
return .{ .strong = JSC.Strong.create(value, globalThis) };
}
pub fn empty() @This() {
return .{ .weak = .zero };
}
pub fn get(this: *@This()) JSC.JSValue {
return switch (this.*) {
.weak => this.weak,
@@ -75,6 +71,21 @@ pub const JSRef = union(enum) {
.finalized => {},
}
}
pub fn downgrade(this: *@This()) void {
switch (this.*) {
.weak => {},
.strong => {
const value = this.strong.get() orelse {
this.* = .{ .weak = .zero };
return;
};
this.strong.deinit();
this.* = .{ .weak = value };
},
.finalized => {},
}
}
};
const JSC = bun.JSC;

View File

@@ -83,4 +83,5 @@ pub const Classes = struct {
pub const S3Stat = JSC.WebCore.S3Stat;
pub const HTMLBundle = JSC.API.HTMLBundle;
pub const RedisClient = JSC.API.Valkey;
pub const SocketHandlers = JSC.API.SocketHandlers;
};

View File

@@ -59,6 +59,7 @@ pub const API = struct {
pub const NativeBrotli = @import("./bun.js/node/node_zlib_binding.zig").SNativeBrotli;
pub const HTMLBundle = @import("./bun.js/api/server/HTMLBundle.zig");
pub const Valkey = @import("./valkey/js_valkey.zig").JSValkeyClient;
pub const SocketHandlers = @import("./bun.js/api/bun/SocketHandlers.zig");
};
pub const Postgres = @import("./sql/postgres.zig");
pub const DNS = @import("./bun.js/api/bun/dns_resolver.zig");