diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index 5b22659d34..668d628124 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -64,6 +64,31 @@ struct loop_ssl_data { BIO_METHOD *shared_biom; }; + +enum us_ssl_sni_result_type { + // no cert or error + US_SSL_SNI_RESULT_NONE = 0, + // we need to parse a new SSL_CTX + US_SSL_SNI_RESULT_OPTIONS = 1, + // most optimal case + US_SSL_SNI_RESULT_SSL_CONTEXT = 2, +}; +union us_ssl_sni_result { + struct us_bun_socket_context_options_t options; + SSL_CTX* ssl_context; +}; + +// tagged union for sni result +struct us_tagged_ssl_sni_result { + uint8_t tag; + union us_ssl_sni_result val; +}; + +void (*us_sni_result_cb)(struct us_internal_ssl_socket_t*, struct us_tagged_ssl_sni_result result); +void (*us_sni_callback)(struct us_internal_ssl_socket_t*, + const char *hostname, us_tagged_ssl_sni_result result_cb, void* ctx) + + struct us_internal_ssl_socket_context_t { struct us_socket_context_t sc; @@ -98,6 +123,10 @@ struct us_internal_ssl_socket_context_t { us_internal_on_handshake_t on_handshake; void *handshake_data; + + // dynamic sni callback + us_sni_callback on_sni_callback; + void *on_sni_callback_ctx; }; // same here, should or shouldn't it @@ -114,6 +143,8 @@ struct us_internal_ssl_socket_t { unsigned int ssl_read_wants_write : 1; unsigned int handshake_state : 2; unsigned int fatal_error : 1; + unsigned int sni_callback_running : 1; + unsigned int cert_cb_running : 1; }; int passphrase_cb(char *buf, int size, int rwflag, void *u) { @@ -213,6 +244,11 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, s->ssl_read_wants_write = 0; s->fatal_error = 0; s->handshake_state = HANDSHAKE_PENDING; + s->sni_callback_running = 0; + s->cert_cb_running = 0; + if(context->on_sni_callback) { + SSL_set_cert_cb(s->ssl, us_internal_ssl_cert_cb, s); + } SSL_set_bio(s->ssl, loop_ssl_data->shared_rbio, loop_ssl_data->shared_wbio); @@ -404,7 +440,8 @@ void us_internal_update_handshake(struct us_internal_ssl_socket_t *s) { if (result <= 0) { int err = SSL_get_error(s->ssl, result); // as far as I know these are the only errors we want to handle - if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { + // SSL_ERROR_WANT_X509_LOOKUP is a special case for SNI with means the promise/callback is still running + if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE && err != SSL_ERROR_WANT_X509_LOOKUP) { // clear per thread error queue if it may contain something if (err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) { ERR_clear_error(); @@ -1341,12 +1378,84 @@ us_internal_ssl_socket_get_sni_userdata(struct us_internal_ssl_socket_t *s) { return SSL_CTX_get_ex_data(SSL_get_SSL_CTX(s->ssl), 0); } + + +void us_internal_ssl_socket_context_sni_result( + struct us_internal_ssl_socket_t *s, + struct us_tagged_ssl_sni_result result) { + + s->cert_cb_running = 0; + + + switch(result.tag) { + case US_SSL_SNI_RESULT_OPTIONS: + enum create_bun_socket_error_t err = CREATE_BUN_SOCKET_ERROR_NONE; + SSL_CTX *ssl_context = create_ssl_context_from_bun_options(result.val.options, &err); + if (ssl_context) { + SSL_set_SSL_CTX(s->ssl, ssl_context); + } else { + // error in this case lets fallback to the default and continue + } + break; + case US_SSL_SNI_RESULT_SSL_CONTEXT: + SSL_CTX *ssl_context = result.val.ssl_context; + if (ssl_context) { + // set ssl context + SSL_set_SSL_CTX(s->ssl, ssl_context); + } else { + // error in this case lets fallback to the default and continue + } + break; + } + // if cert_cb_running is 1 it means we are in the middle of a handshake already so no need to update again + // if cert_cb_running is 0 it means this callback is async and we need to update the handshake + if(s->cert_cb_running == 0) { + // continue handshake + us_internal_update_handshake(s); + } +} +int us_internal_ssl_cert_cb(SSL *ssl, void *arg) { + + struct us_internal_ssl_socket_t *s = (struct us_internal_ssl_socket_t *)arg; + struct us_internal_ssl_socket_context_t *context = + (struct us_internal_ssl_socket_context_t *)us_socket_context(0, &s->s); + + if(!context) return 1; + + if(context->on_sni_callback && s->cert_cb_running == 0) { + s->cert_cb_running = 1; + s->sni_callback_running = 1; + context->on_sni_callback(s, SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name), us_internal_ssl_socket_context_sni_result, context->on_sni_callback_ctx); + s->cert_cb_running = 0; + + // if callback is done, return 1 + if(s->sni_callback_running == 0) { + return 1; + } + + // still waiting for callback + return -1; + } + + // if no callback, use default otherwise still waiting for callback + return s->sni_callback_running == 0 ? 1 : -1; +} +void us_internal_ssl_socket_context_add_sni_callback( + struct us_internal_ssl_socket_context_t *context, + us_sni_callback cb, void* ctx) { + + + context->on_sni_callback = cb; + context->on_sni_callback_ctx = ctx; +} + /* Todo: return error on failure? */ void us_internal_ssl_socket_context_add_server_name( struct us_internal_ssl_socket_context_t *context, const char *hostname_pattern, struct us_socket_context_options_t options, void *user) { + /* Try and construct an SSL_CTX from options */ SSL_CTX *ssl_context = create_ssl_context_from_options(options); diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 5b35669156..6f57d50280 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -3877,8 +3877,9 @@ pub const WindowsNamedPipeListeningContext = if (Environment.isWindows) struct { BoringSSL.load(); const ctx_opts: uws.us_bun_socket_context_options_t = JSC.API.ServerConfig.SSLConfig.asUSockets(ssl_options); + var err: uws.create_bun_socket_error_t = .none; // Create SSL context using uSockets to match behavior of node.js - const ctx = uws.create_ssl_context_from_bun_options(ctx_opts) orelse return error.InvalidOptions; // invalid options + const ctx = uws.create_ssl_context_from_bun_options(ctx_opts, &err) orelse return error.InvalidOptions; // invalid options errdefer BoringSSL.SSL_CTX_free(ctx); this.ctx = ctx; } diff --git a/src/bun.js/api/bun/ssl_wrapper.zig b/src/bun.js/api/bun/ssl_wrapper.zig index c75fba25fa..7a8a74378f 100644 --- a/src/bun.js/api/bun/ssl_wrapper.zig +++ b/src/bun.js/api/bun/ssl_wrapper.zig @@ -97,8 +97,9 @@ pub fn SSLWrapper(comptime T: type) type { BoringSSL.load(); const ctx_opts: uws.us_bun_socket_context_options_t = JSC.API.ServerConfig.SSLConfig.asUSockets(ssl_options); + var err: uws.create_bun_socket_error_t = .none; // Create SSL context using uSockets to match behavior of node.js - const ctx = uws.create_ssl_context_from_bun_options(ctx_opts) orelse return error.InvalidOptions; // invalid options + const ctx = uws.create_ssl_context_from_bun_options(ctx_opts, &err) orelse return error.InvalidOptions; // invalid options errdefer BoringSSL.SSL_CTX_free(ctx); return try This.initWithCTX(ctx, is_client, handlers); } diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 30b90a7404..8f8bddfe79 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -2641,7 +2641,15 @@ pub const us_bun_socket_context_options_t = extern struct { client_renegotiation_limit: u32 = 3, client_renegotiation_window: u32 = 600, }; -pub extern fn create_ssl_context_from_bun_options(options: us_bun_socket_context_options_t) ?*BoringSSL.SSL_CTX; + +pub const create_bun_socket_error_t = enum(c_int) { + none = 0, + load_ca_file, + invalid_ca_file, + invalid_ca, +}; + +pub extern fn create_ssl_context_from_bun_options(options: us_bun_socket_context_options_t, err: ?*create_bun_socket_error_t) ?*BoringSSL.SSL_CTX; pub const create_bun_socket_error_t = enum(i32) { none = 0,