Make TCPSocket/TLSSocket.handlers optional (please don't free undefined)

This commit is contained in:
Kai Tamkun
2025-07-21 13:56:21 -07:00
parent 71d9084252
commit a73bbcd150
3 changed files with 69 additions and 56 deletions

View File

@@ -76,8 +76,7 @@ pub fn NewSocket(comptime ssl: bool) type {
flags: Flags = .{},
ref_count: RefCount,
wrapped: WrappedType = .none,
// TODO: make this optional
handlers: *Handlers,
handlers: ?*Handlers,
this_value: JSC.JSValue = .zero,
poll_ref: Async.KeepAlive = Async.KeepAlive.init(),
ref_pollref_on_connect: bool = true,
@@ -224,7 +223,7 @@ pub fn NewSocket(comptime ssl: bool) type {
pub fn handleError(this: *This, err_value: JSC.JSValue) void {
log("handleError", .{});
const handlers = this.handlers;
const handlers = this.getHandlers();
var vm = handlers.vm;
if (vm.isShuttingDown()) {
return;
@@ -242,7 +241,7 @@ pub fn NewSocket(comptime ssl: bool) type {
JSC.markBinding(@src());
if (this.socket.isDetached()) return;
if (this.native_callback.onWritable()) return;
const handlers = this.handlers;
const handlers = this.getHandlers();
const callback = handlers.onWritable;
if (callback == .zero) return;
@@ -272,8 +271,8 @@ pub fn NewSocket(comptime ssl: bool) type {
pub fn onTimeout(this: *This, _: Socket) void {
JSC.markBinding(@src());
if (this.socket.isDetached()) return;
log("onTimeout {s}", .{if (this.handlers.is_server) "S" else "C"});
const handlers = this.handlers;
const handlers = this.getHandlers();
log("onTimeout {s}", .{if (handlers.is_server) "S" else "C"});
const callback = handlers.onTimeout;
if (callback == .zero or this.flags.finalizing) return;
if (handlers.vm.isShuttingDown()) {
@@ -292,8 +291,13 @@ pub fn NewSocket(comptime ssl: bool) type {
};
}
pub fn getHandlers(this: *const This) *Handlers {
return this.handlers orelse @panic("No handlers set on Socket");
}
pub fn handleConnectError(this: *This, errno: c_int) void {
log("onConnectError {s} ({d}, {d})", .{ if (this.handlers.is_server) "S" else "C", errno, this.ref_count.active_counts });
const handlers = this.getHandlers();
log("onConnectError {s} ({d}, {d})", .{ if (handlers.is_server) "S" else "C", errno, this.ref_count.active_counts });
// Ensure the socket is still alive for any defer's we have
this.ref();
defer this.deref();
@@ -304,7 +308,6 @@ pub fn NewSocket(comptime ssl: bool) type {
defer this.markInactive();
defer if (needs_deref) this.deref();
const handlers = this.handlers;
const vm = handlers.vm;
this.poll_ref.unrefOnNextTick(vm);
if (vm.isShuttingDown()) {
@@ -373,7 +376,7 @@ pub fn NewSocket(comptime ssl: bool) type {
pub fn markActive(this: *This) void {
if (!this.flags.is_active) {
this.handlers.markActive();
this.getHandlers().markActive();
this.flags.is_active = true;
this.has_pending_activity.store(true, .release);
}
@@ -401,15 +404,20 @@ pub fn NewSocket(comptime ssl: bool) type {
}
this.flags.is_active = false;
const vm = this.handlers.vm;
this.handlers.markInactive();
const handlers = this.getHandlers();
const vm = handlers.vm;
handlers.markInactive();
this.poll_ref.unref(vm);
this.has_pending_activity.store(false, .release);
}
}
pub fn isServer(this: *const This) bool {
return this.getHandlers().is_server;
}
pub fn onOpen(this: *This, socket: Socket) void {
log("onOpen {s} {*} {} {}", .{ if (this.handlers.is_server) "S" else "C", this, this.socket.isDetached(), this.ref_count.active_counts });
log("onOpen {s} {*} {} {}", .{ if (this.isServer()) "S" else "C", this, this.socket.isDetached(), this.ref_count.active_counts });
// Ensure the socket remains alive until this is finished
this.ref();
defer this.deref();
@@ -441,7 +449,7 @@ pub fn NewSocket(comptime ssl: bool) type {
}
}
if (this.protos) |protos| {
if (this.handlers.is_server) {
if (this.isServer()) {
BoringSSL.SSL_CTX_set_alpn_select_cb(BoringSSL.SSL_get_SSL_CTX(ssl_ptr), selectALPNCallback, bun.cast(*anyopaque, this));
} else {
_ = BoringSSL.SSL_set_alpn_protos(ssl_ptr, protos.ptr, @as(c_uint, @intCast(protos.len)));
@@ -457,7 +465,7 @@ pub fn NewSocket(comptime ssl: bool) type {
}
}
const handlers = this.handlers;
const handlers = this.getHandlers();
const callback = handlers.onOpen;
const handshake_callback = handlers.onHandshake;
@@ -509,13 +517,12 @@ pub fn NewSocket(comptime ssl: bool) type {
pub fn onEnd(this: *This, _: Socket) void {
JSC.markBinding(@src());
if (this.socket.isDetached()) return;
log("onEnd {s}", .{if (this.handlers.is_server) "S" else "C"});
const handlers = this.getHandlers();
log("onEnd {s}", .{if (handlers.is_server) "S" else "C"});
// Ensure the socket remains alive until this is finished
this.ref();
defer this.deref();
const handlers = this.handlers;
const callback = handlers.onEnd;
if (callback == .zero or handlers.vm.isShuttingDown()) {
this.poll_ref.unref(handlers.vm);
@@ -541,13 +548,13 @@ pub fn NewSocket(comptime ssl: bool) type {
JSC.markBinding(@src());
this.flags.handshake_complete = true;
if (this.socket.isDetached()) return;
log("onHandshake {s} ({d})", .{ if (this.handlers.is_server) "S" else "C", success });
const handlers = this.getHandlers();
log("onHandshake {s} ({d})", .{ if (handlers.is_server) "S" else "C", success });
const authorized = if (success == 1) true else false;
this.flags.authorized = authorized;
const handlers = this.handlers;
var callback = handlers.onHandshake;
var is_open = false;
@@ -583,8 +590,8 @@ pub fn NewSocket(comptime ssl: bool) type {
// clean onOpen callback so only called in the first handshake and not in every renegotiation
// on servers this would require a different approach but it's not needed because our servers will not call handshake multiple times
// servers don't support renegotiation
this.handlers.onOpen.unprotect();
this.handlers.onOpen = .zero;
this.handlers.?.onOpen.unprotect();
this.handlers.?.onOpen = .zero;
}
} else {
// call handhsake callback with authorized and authorization error if has one
@@ -607,7 +614,8 @@ pub fn NewSocket(comptime ssl: bool) type {
pub fn onClose(this: *This, _: Socket, err: c_int, _: ?*anyopaque) void {
JSC.markBinding(@src());
log("onClose {s}", .{if (this.handlers.is_server) "S" else "C"});
const handlers = this.getHandlers();
log("onClose {s}", .{if (handlers.is_server) "S" else "C"});
this.detachNativeCallback();
this.socket.detach();
defer this.deref();
@@ -617,7 +625,6 @@ pub fn NewSocket(comptime ssl: bool) type {
return;
}
const handlers = this.handlers;
const vm = handlers.vm;
this.poll_ref.unref(vm);
@@ -654,10 +661,10 @@ pub fn NewSocket(comptime ssl: bool) type {
pub fn onData(this: *This, _: Socket, data: []const u8) void {
JSC.markBinding(@src());
if (this.socket.isDetached()) return;
log("onData {s} ({d})", .{ if (this.handlers.is_server) "S" else "C", data.len });
const handlers = this.getHandlers();
log("onData {s} ({d})", .{ if (handlers.is_server) "S" else "C", data.len });
if (this.native_callback.onData(data)) return;
const handlers = this.handlers;
const callback = handlers.onData;
if (callback == .zero or this.flags.finalizing) return;
if (handlers.vm.isShuttingDown()) {
@@ -696,11 +703,13 @@ pub fn NewSocket(comptime ssl: bool) type {
}
pub fn getListener(this: *This, _: *JSC.JSGlobalObject) JSValue {
if (!this.handlers.is_server or this.socket.isDetached()) {
const handlers = this.getHandlers();
if (!handlers.is_server or this.socket.isDetached()) {
return .js_undefined;
}
const l: *Listener = @fieldParentPtr("handlers", this.handlers);
const l: *Listener = @fieldParentPtr("handlers", handlers);
return l.strong_self.get() orelse .js_undefined;
}
@@ -1357,13 +1366,14 @@ pub fn NewSocket(comptime ssl: bool) type {
return globalObject.throw("Expected \"socket\" option", .{});
};
const handlers = try Handlers.fromJS(globalObject, socket_obj, this.handlers.is_server);
var prev_handlers = this.getHandlers();
const handlers = try Handlers.fromJS(globalObject, socket_obj, prev_handlers.is_server);
var prev_handlers = this.handlers;
prev_handlers.unprotect();
this.handlers.* = handlers; // TODO: this is a memory leak
this.handlers.withAsyncContextIfNeeded(globalObject);
this.handlers.protect();
this.handlers.?.* = handlers; // TODO: this is a memory leak
this.handlers.?.withAsyncContextIfNeeded(globalObject);
this.handlers.?.protect();
return .js_undefined;
}
@@ -1405,7 +1415,7 @@ pub fn NewSocket(comptime ssl: bool) type {
return .zero;
}
var handlers = try Handlers.fromJS(globalObject, socket_obj, this.handlers.is_server);
var handlers = try Handlers.fromJS(globalObject, socket_obj, this.isServer());
if (globalObject.hasException()) {
return .zero;
@@ -1535,20 +1545,23 @@ pub fn NewSocket(comptime ssl: bool) type {
const vm = handlers.vm;
var raw_handlers_ptr = bun.default_allocator.create(Handlers) catch bun.outOfMemory();
raw_handlers_ptr.* = .{
.vm = vm,
.globalObject = globalObject,
.onOpen = this.handlers.onOpen,
.onClose = this.handlers.onClose,
.onData = this.handlers.onData,
.onWritable = this.handlers.onWritable,
.onTimeout = this.handlers.onTimeout,
.onConnectError = this.handlers.onConnectError,
.onEnd = this.handlers.onEnd,
.onError = this.handlers.onError,
.onHandshake = this.handlers.onHandshake,
.binary_type = this.handlers.binary_type,
.is_server = this.handlers.is_server,
raw_handlers_ptr.* = blk: {
const this_handlers = this.getHandlers();
break :blk .{
.vm = vm,
.globalObject = globalObject,
.onOpen = this_handlers.onOpen,
.onClose = this_handlers.onClose,
.onData = this_handlers.onData,
.onWritable = this_handlers.onWritable,
.onTimeout = this_handlers.onTimeout,
.onConnectError = this_handlers.onConnectError,
.onEnd = this_handlers.onEnd,
.onError = this_handlers.onError,
.onHandshake = this_handlers.onHandshake,
.binary_type = this_handlers.binary_type,
.is_server = this_handlers.is_server,
};
};
raw_handlers_ptr.protect();
@@ -1578,7 +1591,7 @@ pub fn NewSocket(comptime ssl: bool) type {
tls.markActive();
// we're unrefing the original instance and refing the TLS instance
tls.poll_ref.ref(this.handlers.vm);
tls.poll_ref.ref(this.getHandlers().vm);
// mark both instances on socket data
if (new_socket.ext(WrappedSocket)) |ctx| {
@@ -1590,7 +1603,7 @@ pub fn NewSocket(comptime ssl: bool) type {
this.flags.is_active = false;
// will free handlers when hits 0 active connections
// the connection can be upgraded inside a handler call so we need to guarantee that it will be still alive
this.handlers.markInactive();
this.getHandlers().markInactive();
this.has_pending_activity.store(false, .release);
}

View File

@@ -11,7 +11,7 @@ pub fn getServername(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.Cal
}
pub fn setServername(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
if (this.handlers.is_server) {
if (this.isServer()) {
return globalObject.throw("Cannot issue SNI from a TLS server-side socket", .{});
}
@@ -120,7 +120,7 @@ pub fn getPeerCertificate(this: *This, globalObject: *JSC.JSGlobalObject, callfr
const ssl_ptr = this.socket.ssl() orelse return .js_undefined;
if (abbreviated) {
if (this.handlers.is_server) {
if (this.isServer()) {
const cert = BoringSSL.SSL_get_peer_certificate(ssl_ptr);
if (cert) |x509| {
return X509.toJS(x509, globalObject);
@@ -132,7 +132,7 @@ pub fn getPeerCertificate(this: *This, globalObject: *JSC.JSGlobalObject, callfr
return X509.toJS(cert, globalObject);
}
var cert: ?*BoringSSL.X509 = null;
if (this.handlers.is_server) {
if (this.isServer()) {
cert = BoringSSL.SSL_get_peer_certificate(ssl_ptr);
}
@@ -382,7 +382,7 @@ pub fn exportKeyingMaterial(this: *This, globalObject: *JSC.JSGlobalObject, call
pub fn getEphemeralKeyInfo(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue {
// only available for clients
if (this.handlers.is_server) {
if (this.isServer()) {
return JSValue.jsNull();
}
var result = JSValue.createEmptyObject(globalObject, 3);
@@ -555,7 +555,7 @@ pub fn setVerifyMode(this: *This, globalObject: *JSC.JSGlobalObject, callframe:
const request_cert = request_cert_js.toBoolean();
const reject_unauthorized = request_cert_js.toBoolean();
var verify_mode: c_int = BoringSSL.SSL_VERIFY_NONE;
if (this.handlers.is_server) {
if (this.isServer()) {
if (request_cert) {
verify_mode = BoringSSL.SSL_VERIFY_PEER;
if (reject_unauthorized)

View File

@@ -84,7 +84,7 @@ pub fn newDetachedSocket(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFr
.socket_context = null,
.ref_count = .init(),
.protos = null,
.handlers = undefined,
.handlers = null,
});
return socket.getThisValue(globalThis);
} else {
@@ -93,7 +93,7 @@ pub fn newDetachedSocket(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFr
.socket_context = null,
.ref_count = .init(),
.protos = null,
.handlers = undefined,
.handlers = null,
});
return socket.getThisValue(globalThis);
}