diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 21d596fc05..3d980d1022 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -1380,6 +1380,7 @@ fn NewSocket(comptime ssl: bool) type { buffered_data_for_node_net: bun.ByteList = .{}, bytes_written: u64 = 0, sni_callback: JSC.Strong.Optional = .empty, + cert_callback: JSC.Strong.Optional = .empty, // TODO: switch to something that uses `visitAggregate` and have the // `Listener` keep a list of all the sockets JSValue in there @@ -1433,7 +1434,8 @@ fn NewSocket(comptime ssl: bool) type { owned_protos: bool = true, is_paused: bool = false, allow_half_open: bool = false, - _: u7 = 0, + cert_cb_running: bool = false, + _: u6 = 0, }; pub fn hasPendingActivity(this: *This) callconv(.C) bool { @@ -1790,6 +1792,7 @@ fn NewSocket(comptime ssl: bool) type { } const ctx = BoringSSL.SSL_get_SSL_CTX(ssl_ptr); + _ = BoringSSL.SSL_set_app_data(ssl_ptr, this); if (this.protos) |protos| { if (this.handlers.is_server) { @@ -1801,23 +1804,24 @@ fn NewSocket(comptime ssl: bool) type { if (this.handlers.is_server) { _ = BoringSSL.SSL_CTX_set_tlsext_servername_callback(ctx, struct { - fn cb(cb_ssl: ?*BoringSSL.SSL, _: [*c]c_int, arg: ?*anyopaque) callconv(.C) c_int { - const sn_type: c_int = BoringSSL.SSL_get_servername_type(cb_ssl); - if (sn_type == -1) { - return BoringSSL.SSL_TLSEXT_ERR_OK; + fn cb(cb_ssl: ?*BoringSSL.SSL, _: [*c]c_int, _: ?*anyopaque) callconv(.C) c_int { + const servername: [*c]const u8 = BoringSSL.SSL_get_servername(cb_ssl, BoringSSL.TLSEXT_NAMETYPE_host_name); + if (servername == null) { + return BoringSSL.SSL_TLSEXT_ERR_NOACK; } - const sn: [*c]const u8 = BoringSSL.SSL_get_servername(cb_ssl, sn_type); - if (sn == null) { - return BoringSSL.SSL_TLSEXT_ERR_OK; - } - - const cb_this: *This = @alignCast(@ptrCast(arg)); - cb_this.onSNI(cb_this.socket, sn[0..std.mem.len(sn)]); - return BoringSSL.SSL_TLSEXT_ERR_OK; + const cb_this: *This = @alignCast(@ptrCast(BoringSSL.SSL_get_app_data(cb_ssl))); + return cb_this.onSNI(servername[0..std.mem.len(servername)]); } }.cb); - _ = BoringSSL.SSL_CTX_set_tlsext_servername_arg(ctx, bun.cast(*anyopaque, this)); + + _ = BoringSSL.SSL_set_cert_cb(ssl_ptr, struct { + fn cb(cb_ssl: ?*BoringSSL.SSL, _: ?*anyopaque) callconv(.C) c_int { + const servername: [*c]const u8 = BoringSSL.SSL_get_servername(cb_ssl, BoringSSL.TLSEXT_NAMETYPE_host_name) orelse ""; + const cb_this: *This = @alignCast(@ptrCast(BoringSSL.SSL_get_app_data(cb_ssl))); + return cb_this.onCert(servername[0..std.mem.len(servername)]); + } + }.cb, bun.cast(*anyopaque, this)); } } } @@ -2020,10 +2024,14 @@ fn NewSocket(comptime ssl: bool) type { }; } - pub fn onSNI(this: *This, _: Socket, servername: []const u8) void { + pub fn onSNI(this: *This, servername: []const u8) c_int { + if (comptime ssl == false) { + return BoringSSL.SSL_TLSEXT_ERR_NOACK; + } + JSC.markBinding(@src()); log("onSNI {s} ({s})", .{ if (this.handlers.is_server) "S" else "C", servername }); - if (this.socket.isDetached()) return; + if (this.socket.isDetached()) return BoringSSL.SSL_TLSEXT_ERR_NOACK; if (this.sni_callback.get()) |callback| { const globalObject = this.handlers.globalObject; @@ -2036,6 +2044,59 @@ fn NewSocket(comptime ssl: bool) type { _ = this.handlers.callErrorHandler(this_value, &.{ this_value, globalObject.takeError(err) }); }; } + + return BoringSSL.SSL_TLSEXT_ERR_OK; + } + + fn isWaitingCertCb(this: *This) bool { + return this.cert_callback.get() != null; + } + + pub fn onCert(this: *This, servername: []const u8) c_int { + if (comptime ssl == false) { + return 1; + } + + if (!this.handlers.is_server or !this.isWaitingCertCb()) { + return 1; + } + + if (this.flags.cert_cb_running) { + return -1; + } + + JSC.markBinding(@src()); + log("onCert {s} ({s})", .{ if (this.handlers.is_server) "S" else "C", servername }); + if (this.socket.isDetached()) return -1; + + this.flags.cert_cb_running = true; + + // Presence already verified by isWaitingCertCb + const callback = this.cert_callback.get().?; + + const globalObject = this.handlers.globalObject; + const this_value = this.getThisValue(globalObject); + + _ = callback.call(globalObject, this_value, &[_]JSValue{ + this_value, + ZigString.init(servername).toJS(globalObject), + }) catch |err| { + _ = this.handlers.callErrorHandler(this_value, &.{ this_value, globalObject.takeError(err) }); + }; + + return if (this.flags.cert_cb_running) -1 else 1; + } + + pub fn certCallbackDone(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + bun.assert(this.isWaitingCertCb() and this.flags.cert_cb_running); + + const this_value: JSC.JSValue = this.getThisValue(globalObject); + + _ = this_value; + _ = callframe; + + // TODO(@heimskr) + return error.JSError; } pub fn onData(this: *This, _: Socket, data: []const u8) void { @@ -3229,6 +3290,10 @@ fn NewSocket(comptime ssl: bool) type { this.sni_callback.set(globalObject, value); } + pub fn setCertCallback(this: *This, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) void { + this.cert_callback.set(globalObject, value); + } + pub fn setMaxSendFragment(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { JSC.markBinding(@src()); if (comptime ssl == false) {