Initial work on cert callbacks

This commit is contained in:
Kai Tamkun
2025-06-10 15:39:10 -07:00
parent 1ec0718e05
commit af5c0bdea6

View File

@@ -1380,6 +1380,7 @@ fn NewSocket(comptime ssl: bool) type {
buffered_data_for_node_net: bun.ByteList = .{},
bytes_written: u64 = 0,
sni_callback: JSC.Strong.Optional = .empty,
cert_callback: JSC.Strong.Optional = .empty,
// TODO: switch to something that uses `visitAggregate` and have the
// `Listener` keep a list of all the sockets JSValue in there
@@ -1433,7 +1434,8 @@ fn NewSocket(comptime ssl: bool) type {
owned_protos: bool = true,
is_paused: bool = false,
allow_half_open: bool = false,
_: u7 = 0,
cert_cb_running: bool = false,
_: u6 = 0,
};
pub fn hasPendingActivity(this: *This) callconv(.C) bool {
@@ -1790,6 +1792,7 @@ fn NewSocket(comptime ssl: bool) type {
}
const ctx = BoringSSL.SSL_get_SSL_CTX(ssl_ptr);
_ = BoringSSL.SSL_set_app_data(ssl_ptr, this);
if (this.protos) |protos| {
if (this.handlers.is_server) {
@@ -1801,23 +1804,24 @@ fn NewSocket(comptime ssl: bool) type {
if (this.handlers.is_server) {
_ = BoringSSL.SSL_CTX_set_tlsext_servername_callback(ctx, struct {
fn cb(cb_ssl: ?*BoringSSL.SSL, _: [*c]c_int, arg: ?*anyopaque) callconv(.C) c_int {
const sn_type: c_int = BoringSSL.SSL_get_servername_type(cb_ssl);
if (sn_type == -1) {
return BoringSSL.SSL_TLSEXT_ERR_OK;
fn cb(cb_ssl: ?*BoringSSL.SSL, _: [*c]c_int, _: ?*anyopaque) callconv(.C) c_int {
const servername: [*c]const u8 = BoringSSL.SSL_get_servername(cb_ssl, BoringSSL.TLSEXT_NAMETYPE_host_name);
if (servername == null) {
return BoringSSL.SSL_TLSEXT_ERR_NOACK;
}
const sn: [*c]const u8 = BoringSSL.SSL_get_servername(cb_ssl, sn_type);
if (sn == null) {
return BoringSSL.SSL_TLSEXT_ERR_OK;
}
const cb_this: *This = @alignCast(@ptrCast(arg));
cb_this.onSNI(cb_this.socket, sn[0..std.mem.len(sn)]);
return BoringSSL.SSL_TLSEXT_ERR_OK;
const cb_this: *This = @alignCast(@ptrCast(BoringSSL.SSL_get_app_data(cb_ssl)));
return cb_this.onSNI(servername[0..std.mem.len(servername)]);
}
}.cb);
_ = BoringSSL.SSL_CTX_set_tlsext_servername_arg(ctx, bun.cast(*anyopaque, this));
_ = BoringSSL.SSL_set_cert_cb(ssl_ptr, struct {
fn cb(cb_ssl: ?*BoringSSL.SSL, _: ?*anyopaque) callconv(.C) c_int {
const servername: [*c]const u8 = BoringSSL.SSL_get_servername(cb_ssl, BoringSSL.TLSEXT_NAMETYPE_host_name) orelse "";
const cb_this: *This = @alignCast(@ptrCast(BoringSSL.SSL_get_app_data(cb_ssl)));
return cb_this.onCert(servername[0..std.mem.len(servername)]);
}
}.cb, bun.cast(*anyopaque, this));
}
}
}
@@ -2020,10 +2024,14 @@ fn NewSocket(comptime ssl: bool) type {
};
}
pub fn onSNI(this: *This, _: Socket, servername: []const u8) void {
pub fn onSNI(this: *This, servername: []const u8) c_int {
if (comptime ssl == false) {
return BoringSSL.SSL_TLSEXT_ERR_NOACK;
}
JSC.markBinding(@src());
log("onSNI {s} ({s})", .{ if (this.handlers.is_server) "S" else "C", servername });
if (this.socket.isDetached()) return;
if (this.socket.isDetached()) return BoringSSL.SSL_TLSEXT_ERR_NOACK;
if (this.sni_callback.get()) |callback| {
const globalObject = this.handlers.globalObject;
@@ -2036,6 +2044,59 @@ fn NewSocket(comptime ssl: bool) type {
_ = this.handlers.callErrorHandler(this_value, &.{ this_value, globalObject.takeError(err) });
};
}
return BoringSSL.SSL_TLSEXT_ERR_OK;
}
fn isWaitingCertCb(this: *This) bool {
return this.cert_callback.get() != null;
}
pub fn onCert(this: *This, servername: []const u8) c_int {
if (comptime ssl == false) {
return 1;
}
if (!this.handlers.is_server or !this.isWaitingCertCb()) {
return 1;
}
if (this.flags.cert_cb_running) {
return -1;
}
JSC.markBinding(@src());
log("onCert {s} ({s})", .{ if (this.handlers.is_server) "S" else "C", servername });
if (this.socket.isDetached()) return -1;
this.flags.cert_cb_running = true;
// Presence already verified by isWaitingCertCb
const callback = this.cert_callback.get().?;
const globalObject = this.handlers.globalObject;
const this_value = this.getThisValue(globalObject);
_ = callback.call(globalObject, this_value, &[_]JSValue{
this_value,
ZigString.init(servername).toJS(globalObject),
}) catch |err| {
_ = this.handlers.callErrorHandler(this_value, &.{ this_value, globalObject.takeError(err) });
};
return if (this.flags.cert_cb_running) -1 else 1;
}
pub fn certCallbackDone(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
bun.assert(this.isWaitingCertCb() and this.flags.cert_cb_running);
const this_value: JSC.JSValue = this.getThisValue(globalObject);
_ = this_value;
_ = callframe;
// TODO(@heimskr)
return error.JSError;
}
pub fn onData(this: *This, _: Socket, data: []const u8) void {
@@ -3229,6 +3290,10 @@ fn NewSocket(comptime ssl: bool) type {
this.sni_callback.set(globalObject, value);
}
pub fn setCertCallback(this: *This, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) void {
this.cert_callback.set(globalObject, value);
}
pub fn setMaxSendFragment(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
JSC.markBinding(@src());
if (comptime ssl == false) {