From adaee0713850243c95be6dca205a624b176a4c14 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Mon, 18 Nov 2024 19:38:23 -0800 Subject: [PATCH] [Bun.sql] Support TLS (#15217) Co-authored-by: Ciro Spaciari --- packages/bun-usockets/src/crypto/openssl.c | 37 ++- packages/bun-usockets/src/libusockets.h | 2 + src/bun.js/api/bun/socket.zig | 18 +- src/bun.js/bindings/bindings.zig | 2 +- src/deps/uws.zig | 26 ++ src/js/bun/sql.ts | 94 ++++++- src/sql/postgres.zig | 313 ++++++++++++++++++--- test/js/sql/tls-sql.test.ts | 23 ++ 8 files changed, 440 insertions(+), 75 deletions(-) create mode 100644 test/js/sql/tls-sql.test.ts diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index 5880fa35cc..4c4c2a76d5 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -201,7 +201,7 @@ struct loop_ssl_data * us_internal_set_loop_ssl_data(struct us_internal_ssl_sock struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, int is_client, char *ip, - int ip_length) { + int ip_length, const char* sni) { struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *)us_socket_context(0, &s->s); @@ -231,6 +231,10 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, if (is_client) { SSL_set_renegotiate_mode(s->ssl, ssl_renegotiate_explicit); SSL_set_connect_state(s->ssl); + + if (sni) { + SSL_set_tlsext_host_name(s->ssl, sni); + } } else { SSL_set_accept_state(s->ssl); // we do not allow renegotiation on the server side (should be the default for BoringSSL, but we set to make openssl compatible) @@ -1603,6 +1607,10 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_connect_unix( socket_ext_size); } +static void ssl_on_open_without_sni(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length) { + ssl_on_open(s, is_client, ip, ip_length, NULL); +} + void us_internal_ssl_socket_context_on_open( struct us_internal_ssl_socket_context_t *context, struct us_internal_ssl_socket_t *(*on_open)( @@ -1611,7 +1619,7 @@ void us_internal_ssl_socket_context_on_open( us_socket_context_on_open( 0, &context->sc, (struct us_socket_t * (*)(struct us_socket_t *, int, char *, int)) - ssl_on_open); + ssl_on_open_without_sni); context->on_open = on_open; } @@ -2005,7 +2013,30 @@ us_internal_ssl_socket_open(struct us_internal_ssl_socket_t *s, int is_client, return s; // start SSL open - return ssl_on_open(s, is_client, ip, ip_length); + return ssl_on_open(s, is_client, ip, ip_length, NULL); +} + +struct us_socket_t *us_socket_upgrade_to_tls(us_socket_r s, us_socket_context_r new_context, const char *sni) { + // Resize to tls + ext size + void** prev_ext_ptr = (void**)us_socket_ext(0, s); + void* prev_ext = *prev_ext_ptr; + struct us_internal_ssl_socket_t *socket = + (struct us_internal_ssl_socket_t *)us_socket_context_adopt_socket( + 0, new_context, s, + (sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t)) + sizeof(void*)); + socket->ssl = NULL; + socket->ssl_write_wants_read = 0; + socket->ssl_read_wants_write = 0; + socket->fatal_error = 0; + socket->handshake_state = HANDSHAKE_PENDING; + + void** new_ext_ptr = (void**)us_socket_ext(1, (struct us_socket_t *)socket); + *new_ext_ptr = prev_ext; + + ssl_on_open(socket, 1, NULL, 0, sni); + + + return (struct us_socket_t *)socket; } struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index 94dd58e22b..55657f8845 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -339,6 +339,8 @@ struct us_loop_t *us_socket_context_loop(int ssl, us_socket_context_r context) n * Used mainly for "socket upgrades" such as when transitioning from HTTP to WebSocket. */ struct us_socket_t *us_socket_context_adopt_socket(int ssl, us_socket_context_r context, us_socket_r s, int ext_size); +struct us_socket_t *us_socket_upgrade_to_tls(us_socket_r s, us_socket_context_r new_context, const char *sni); + /* Create a child socket context which acts much like its own socket context with its own callbacks yet still relies on the * parent socket context for some shared resources. Child socket contexts should be used together with socket adoptions and nothing else. */ struct us_socket_context_t *us_create_child_socket_context(int ssl, us_socket_context_r context, int context_ext_size); diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 050b6ded6e..ac738e40ed 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -1866,21 +1866,11 @@ fn NewSocket(comptime ssl: bool) type { } } else { // call handhsake callback with authorized and authorization error if has one - var authorization_error: JSValue = undefined; - if (ssl_error.error_no == 0) { - authorization_error = JSValue.jsNull(); - } else { - const code = if (ssl_error.code == null) "" else ssl_error.code[0..bun.len(ssl_error.code)]; + const authorization_error: JSValue = if (ssl_error.error_no == 0) + JSValue.jsNull() + else + ssl_error.toJS(globalObject); - const reason = if (ssl_error.reason == null) "" else ssl_error.reason[0..bun.len(ssl_error.reason)]; - - const fallback = JSC.SystemError{ - .code = bun.String.createUTF8(code), - .message = bun.String.createUTF8(reason), - }; - - authorization_error = fallback.toErrorInstance(globalObject); - } result = callback.call(globalObject, this_value, &[_]JSValue{ this_value, JSValue.jsBoolean(authorized), diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig index afe79db515..187e163b20 100644 --- a/src/bun.js/bindings/bindings.zig +++ b/src/bun.js/bindings/bindings.zig @@ -6803,7 +6803,7 @@ pub const CallFrame = opaque { const ptr = self.argumentsPtr(); return switch (@as(u4, @min(len, max))) { 0 => .{ .ptr = undefined, .len = 0 }, - inline 1...9 => |count| Arguments(max).init(comptime @min(count, max), ptr), + inline 1...10 => |count| Arguments(max).init(comptime @min(count, max), ptr), else => unreachable, }; } diff --git a/src/deps/uws.zig b/src/deps/uws.zig index b50e1e76db..65a58c30f5 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -2640,6 +2640,18 @@ pub const create_bun_socket_error_t = enum(i32) { load_ca_file, invalid_ca_file, invalid_ca, + + pub fn toJS(this: create_bun_socket_error_t, globalObject: *JSC.JSGlobalObject) JSC.JSValue { + return switch (this) { + .none => brk: { + bun.debugAssert(false); + break :brk .null; + }, + .load_ca_file => globalObject.ERR_BORINGSSL("Failed to load CA file", .{}).toJS(), + .invalid_ca_file => globalObject.ERR_BORINGSSL("Invalid CA file", .{}).toJS(), + .invalid_ca => globalObject.ERR_BORINGSSL("Invalid CA", .{}).toJS(), + }; + } }; pub const us_bun_verify_error_t = extern struct { @@ -2647,6 +2659,18 @@ pub const us_bun_verify_error_t = extern struct { error_no: i64 = 0, code: [*c]const u8 = null, reason: [*c]const u8 = null, + + pub fn toJS(this: *const us_bun_verify_error_t, globalObject: *JSC.JSGlobalObject) JSC.JSValue { + const code = if (this.code == null) "" else this.code[0..bun.len(this.code)]; + const reason = if (this.reason == null) "" else this.reason[0..bun.len(this.reason)]; + + const fallback = JSC.SystemError{ + .code = bun.String.createUTF8(code), + .message = bun.String.createUTF8(reason), + }; + + return fallback.toErrorInstance(globalObject); + } }; pub extern fn us_ssl_socket_verify_error_from_ssl(ssl: *BoringSSL.SSL) us_bun_verify_error_t; @@ -4522,3 +4546,5 @@ pub fn onThreadExit() void { } extern fn uws_app_clear_routes(ssl_flag: c_int, app: *uws_app_t) void; + +pub extern fn us_socket_upgrade_to_tls(s: *Socket, new_context: *SocketContext, sni: ?[*:0]const u8) ?*Socket; diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index adbe607882..f4f29f431f 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -7,6 +7,13 @@ const enum QueryStatus { const cmds = ["", "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "MOVE", "FETCH", "COPY"]; const PublicArray = globalThis.Array; +const enum SSLMode { + disable = 0, + prefer = 1, + require = 2, + verify_ca = 3, + verify_full = 4, +} class SQLResultArray extends PublicArray { static [Symbol.toStringTag] = "SQLResults"; @@ -34,6 +41,33 @@ const { init, } = $zig("postgres.zig", "createBinding"); +function normalizeSSLMode(value: string): SSLMode { + if (!value) { + return SSLMode.disable; + } + + value = (value + "").toLowerCase(); + switch (value) { + case "disable": + return SSLMode.disable; + case "prefer": + return SSLMode.prefer; + case "require": + return SSLMode.require; + case "verify-ca": + case "verify_ca": + return SSLMode.verify_ca; + case "verify-full": + case "verify_full": + return SSLMode.verify_full; + default: { + break; + } + } + + throw $ERR_INVALID_ARG_VALUE(`Invalid SSL mode: ${value}`); +} + class Query extends PublicPromise { [_resolve]; [_reject]; @@ -162,26 +196,27 @@ init( try { query.resolve(result); - } catch (e) { - console.log(e); - } + } catch (e) {} }, function (query, reject) { try { query.reject(reject); - } catch (e) { - console.log(e); - } + } catch (e) {} }, ); -function createConnection({ hostname, port, username, password, tls, query, database }, onConnected, onClose) { +function createConnection({ hostname, port, username, password, tls, query, database, sslMode }, onConnected, onClose) { return _createConnection( hostname, Number(port), username || "", password || "", database || "", + // > The default value for sslmode is prefer. As is shown in the table, this + // makes no sense from a security point of view, and it only promises + // performance overhead if possible. It is only provided as the default for + // backward compatibility, and is not recommended in secure deployments. + sslMode || SSLMode.disable, tls || null, query || "", onConnected, @@ -279,9 +314,17 @@ class SQLArrayParameter { function loadOptions(o) { var hostname, port, username, password, database, tls, url, query, adapter; const env = Bun.env; + var sslMode: SSLMode = SSLMode.disable; if (o === undefined || (typeof o === "string" && o.length === 0)) { - const urlString = env.POSTGRES_URL || env.DATABASE_URL || env.PGURL || env.PG_URL; + let urlString = env.POSTGRES_URL || env.DATABASE_URL || env.PGURL || env.PG_URL; + if (!urlString) { + urlString = env.TLS_POSTGRES_DATABASE_URL || env.TLS_DATABASE_URL; + if (urlString) { + sslMode = SSLMode.require; + } + } + if (urlString) { url = new URL(urlString); o = {}; @@ -297,6 +340,11 @@ function loadOptions(o) { url = _url; } } + + if (o?.tls) { + sslMode = SSLMode.require; + tls = o.tls; + } } else if (typeof o === "string") { url = new URL(o); } @@ -306,18 +354,19 @@ function loadOptions(o) { if (adapter[adapter.length - 1] === ":") { adapter = adapter.slice(0, -1); } + const queryObject = url.searchParams.toJSON(); query = ""; for (const key in queryObject) { - query += `${encodeURIComponent(key)}=${encodeURIComponent(queryObject[key])} `; + if (key.toLowerCase() === "sslmode") { + sslMode = normalizeSSLMode(queryObject[key]); + } else { + query += `${encodeURIComponent(key)}=${encodeURIComponent(queryObject[key])} `; + } } query = query.trim(); } - if (!o) { - o = {}; - } - hostname ||= o.hostname || o.host || env.PGHOST || "localhost"; port ||= Number(o.port || env.PGPORT || 5432); username ||= o.username || o.user || env.PGUSERNAME || env.PGUSER || env.USER || env.USERNAME || "postgres"; @@ -326,6 +375,19 @@ function loadOptions(o) { tls ||= o.tls || o.ssl; adapter ||= o.adapter || "postgres"; + if (sslMode !== SSLMode.disable && !tls?.serverName) { + if (hostname) { + tls = { + serverName: hostname, + }; + } else { + tls = true; + } + } + + if (!!tls) { + sslMode = SSLMode.prefer; + } port = Number(port); if (!Number.isSafeInteger(port) || port < 1 || port > 65535) { @@ -336,7 +398,7 @@ function loadOptions(o) { throw new Error(`Unsupported adapter: ${adapter}. Only \"postgres\" is supported for now`); } - return { hostname, port, username, password, database, tls, query }; + return { hostname, port, username, password, database, tls, query, sslMode }; } function SQL(o) { @@ -507,6 +569,10 @@ function SQL(o) { var lazyDefaultSQL; var defaultSQLObject = function sql(strings, ...values) { + if (new.target) { + return SQL(strings); + } + if (!lazyDefaultSQL) { lazyDefaultSQL = SQL(undefined); Object.assign(defaultSQLObject, lazyDefaultSQL); diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 496fef2b18..aefdd4c5de 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -12,6 +12,15 @@ pub const short = u16; pub const PostgresShort = u16; const Crypto = JSC.API.Bun.Crypto; const JSValue = JSC.JSValue; +const BoringSSL = @import("../boringssl.zig"); + +pub const SSLMode = enum(u8) { + disable = 0, + prefer = 1, + require = 2, + verify_ca = 3, + verify_full = 4, +}; pub const Data = union(enum) { owned: bun.ByteList, @@ -837,7 +846,16 @@ pub const PostgresRequest = struct { switch (try reader.int(u8)) { 'D' => try connection.on(.DataRow, Context, reader), 'd' => try connection.on(.CopyData, Context, reader), - 'S' => try connection.on(.ParameterStatus, Context, reader), + 'S' => { + if (connection.tls_status == .message_sent) { + bun.debugAssert(connection.tls_status.message_sent == 8); + connection.tls_status = .ssl_ok; + connection.setupTLS(); + return; + } + + try connection.on(.ParameterStatus, Context, reader); + }, 'Z' => try connection.on(.ReadyForQuery, Context, reader), 'C' => try connection.on(.CommandComplete, Context, reader), '2' => try connection.on(.BindComplete, Context, reader), @@ -851,7 +869,19 @@ pub const PostgresRequest = struct { 's' => try connection.on(.PortalSuspended, Context, reader), '3' => try connection.on(.CloseComplete, Context, reader), 'G' => try connection.on(.CopyInResponse, Context, reader), - 'N' => try connection.on(.NoticeResponse, Context, reader), + 'N' => { + if (connection.tls_status == .message_sent) { + connection.tls_status = .ssl_not_available; + debug("Server does not support SSL", .{}); + if (connection.ssl_mode == .require) { + connection.fail("Server does not support SSL", error.SSLNotAvailable); + return; + } + continue; + } + + try connection.on(.NoticeResponse, Context, reader); + }, 'I' => try connection.on(.EmptyQueryResponse, Context, reader), 'H' => try connection.on(.CopyOutResponse, Context, reader), 'c' => try connection.on(.CopyDone, Context, reader), @@ -904,6 +934,23 @@ pub const PostgresSQLConnection = struct { authentication_state: AuthenticationState = .{ .pending = {} }, + tls_ctx: ?*uws.SocketContext = null, + tls_config: JSC.API.ServerConfig.SSLConfig = .{}, + tls_status: TLSStatus = .none, + ssl_mode: SSLMode = .disable, + + pub const TLSStatus = union(enum) { + none, + pending, + + /// Number of bytes sent of the 8-byte SSL request message. + /// Since we may send a partial message, we need to know how many bytes were sent. + message_sent: u8, + + ssl_not_available, + ssl_ok, + }; + pub const AuthenticationState = union(enum) { pending: void, SASL: SASL, @@ -1005,12 +1052,41 @@ pub const PostgresSQLConnection = struct { pub const Status = enum { disconnected, connecting, + // Prevent sending the startup message multiple times. + // Particularly relevant for TLS connections. + sent_startup_message, connected, failed, }; pub usingnamespace JSC.Codegen.JSPostgresSQLConnection; + pub fn setupTLS(this: *PostgresSQLConnection) void { + debug("setupTLS", .{}); + const new_socket = uws.us_socket_upgrade_to_tls(this.socket.SocketTCP.socket.connected, this.tls_ctx.?, this.tls_config.server_name) orelse { + this.fail("Failed to upgrade to TLS", error.TLSUpgradeFailed); + return; + }; + this.socket = .{ + .SocketTLS = .{ + .socket = .{ + .connected = new_socket, + }, + }, + }; + + this.start(); + } + + fn start(this: *PostgresSQLConnection) void { + this.sendStartupMessage(); + + const event_loop = this.globalObject.bunVM().eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + this.flushData(); + } + pub fn hasPendingActivity(this: *PostgresSQLConnection) bool { @fence(.acquire); return this.pending_activity_count.load(.acquire) > 0; @@ -1059,48 +1135,110 @@ pub const PostgresSQLConnection = struct { } } - pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: anyerror) void { + pub fn failWithJSValue(this: *PostgresSQLConnection, value: JSValue) void { defer this.updateHasPendingActivity(); if (this.status == .failed) return; - debug("failed: {s}: {s}", .{ message, @errorName(err) }); this.status = .failed; if (!this.socket.isClosed()) this.socket.close(); const on_close = this.on_close.swap(); if (on_close == .zero) return; - const instance = this.globalObject.createErrorInstance("{s}", .{message}); - instance.put(this.globalObject, JSC.ZigString.static("code"), String.init(@errorName(err)).toJS(this.globalObject)); + _ = on_close.call( this.globalObject, this.js_value, &[_]JSValue{ - instance, + value, }, ) catch |e| this.globalObject.reportActiveExceptionAsUnhandled(e); } + pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: anyerror) void { + debug("failed: {s}: {s}", .{ message, @errorName(err) }); + const instance = this.globalObject.createErrorInstance("{s}", .{message}); + instance.put(this.globalObject, JSC.ZigString.static("code"), String.init(@errorName(err)).toJS(this.globalObject)); + this.failWithJSValue(instance); + } + pub fn onClose(this: *PostgresSQLConnection) void { var vm = this.globalObject.bunVM(); defer vm.drainMicrotasks(); this.fail("Connection closed", error.ConnectionClosed); } + fn sendStartupMessage(this: *PostgresSQLConnection) void { + if (this.status != .connecting) return; + debug("sendStartupMessage", .{}); + this.status = .sent_startup_message; + var msg = protocol.StartupMessage{ + .user = Data{ .temporary = this.user }, + .database = Data{ .temporary = this.database }, + .options = Data{ .temporary = this.options }, + }; + msg.writeInternal(Writer, this.writer()) catch |err| { + this.socket.close(); + this.fail("Failed to write startup message", err); + }; + } + + fn startTLS(this: *PostgresSQLConnection, socket: uws.AnySocket) void { + debug("startTLS", .{}); + const offset = switch (this.tls_status) { + .message_sent => |count| count, + else => 0, + }; + const ssl_request = [_]u8{ + 0x00, 0x00, 0x00, 0x08, // Length + 0x04, 0xD2, 0x16, 0x2F, // SSL request code + }; + + const written = socket.write(ssl_request[offset..], false); + if (written > 0) { + this.tls_status = .{ + .message_sent = offset + @as(u8, @intCast(written)), + }; + } else { + this.tls_status = .{ + .message_sent = offset, + }; + } + } + pub fn onOpen(this: *PostgresSQLConnection, socket: uws.AnySocket) void { this.socket = socket; this.poll_ref.ref(this.globalObject.bunVM()); this.updateHasPendingActivity(); - var msg = protocol.StartupMessage{ .user = Data{ .temporary = this.user }, .database = Data{ .temporary = this.database }, .options = Data{ .temporary = this.options } }; - msg.writeInternal(Writer, this.writer()) catch |err| { - socket.close(); - this.fail("Failed to write startup message", err); - }; + if (this.tls_status == .message_sent or this.tls_status == .pending) { + this.startTLS(socket); + return; + } - const event_loop = this.globalObject.bunVM().eventLoop(); - event_loop.enter(); - defer event_loop.exit(); - this.flushData(); + this.start(); + } + + pub fn onHandshake(this: *PostgresSQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no }); + + if (success != 1) { + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + return; + } + + if (this.tls_config.reject_unauthorized == 1) { + if (ssl_error.error_no != 0) { + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + return; + } + const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + if (BoringSSL.SSL_get_servername(ssl_ptr, 0)) |servername| { + const hostname = servername[0..bun.len(servername)]; + if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) { + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + } + } + } } pub fn onTimeout(this: *PostgresSQLConnection) void { @@ -1109,6 +1247,16 @@ pub const PostgresSQLConnection = struct { } pub fn onDrain(this: *PostgresSQLConnection) void { + + // Don't send any other messages while we're waiting for TLS. + if (this.tls_status == .message_sent) { + if (this.tls_status.message_sent < 8) { + this.startTLS(this.socket); + } + + return; + } + const event_loop = this.globalObject.bunVM().eventLoop(); event_loop.enter(); defer event_loop.exit(); @@ -1226,7 +1374,7 @@ pub const PostgresSQLConnection = struct { pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { var vm = globalObject.bunVM(); - const arguments = callframe.arguments(9).slice(); + const arguments = callframe.arguments(10).slice(); const hostname_str = arguments[0].toBunString(globalObject); defer hostname_str.deref(); const port = arguments[1].coerce(i32, globalObject); @@ -1237,13 +1385,67 @@ pub const PostgresSQLConnection = struct { defer password_str.deref(); const database_str = arguments[4].toBunString(globalObject); defer database_str.deref(); - const tls_object = arguments[5]; + const ssl_mode: SSLMode = switch (arguments[5].toInt32()) { + 0 => .disable, + 1 => .prefer, + 2 => .require, + 3 => .verify_ca, + 4 => .verify_full, + else => .disable, + }; + + const tls_object = arguments[6]; + + var tls_config: JSC.API.ServerConfig.SSLConfig = .{}; + var tls_ctx: ?*uws.SocketContext = null; + if (ssl_mode != .disable) { + tls_config = if (tls_object.isBoolean() and tls_object.toBoolean()) + .{} + else if (tls_object.isObject()) + (JSC.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{} + else { + globalObject.throwInvalidArguments("tls must be a boolean or an object", .{}); + return .zero; + }; + + if (globalObject.hasException()) { + tls_config.deinit(); + return .zero; + } + + if (tls_config.reject_unauthorized != 0) + tls_config.request_cert = 1; + + // We create it right here so we can throw errors early. + const context_options = tls_config.asUSockets(); + var err: uws.create_bun_socket_error_t = .none; + tls_ctx = uws.us_create_bun_socket_context(1, vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), context_options, &err) orelse { + if (err != .none) { + globalObject.throw("failed to create TLS context", .{}); + } else { + globalObject.throwValue(err.toJS(globalObject)); + } + return .zero; + }; + + if (err != .none) { + tls_config.deinit(); + globalObject.throwValue(err.toJS(globalObject)); + if (tls_ctx) |ctx| { + ctx.deinit(true); + } + return .zero; + } + + uws.NewSocketHandler(true).configure(tls_ctx.?, true, *PostgresSQLConnection, SocketHandler(true)); + } + var username: []const u8 = ""; var password: []const u8 = ""; var database: []const u8 = ""; var options: []const u8 = ""; - const options_str = arguments[6].toBunString(globalObject); + const options_str = arguments[7].toBunString(globalObject); defer options_str.deref(); const options_buf: []u8 = brk: { @@ -1270,10 +1472,15 @@ pub const PostgresSQLConnection = struct { break :brk b.allocatedSlice(); }; - const on_connect = arguments[7]; - const on_close = arguments[8]; + const on_connect = arguments[8]; + const on_close = arguments[9]; + var ptr = bun.default_allocator.create(PostgresSQLConnection) catch |err| { globalObject.throwError(err, "failed to allocate connection"); + tls_config.deinit(); + if (tls_ctx) |ctx| { + ctx.deinit(true); + } return .zero; }; @@ -1289,6 +1496,10 @@ pub const PostgresSQLConnection = struct { .socket = undefined, .requests = PostgresRequest.Queue.init(bun.default_allocator), .statements = PreparedStatementsMap{}, + .tls_config = tls_config, + .tls_ctx = tls_ctx, + .ssl_mode = ssl_mode, + .tls_status = if (ssl_mode != .disable) .pending else .none, }; ptr.updateHasPendingActivity(); @@ -1300,28 +1511,25 @@ pub const PostgresSQLConnection = struct { { const hostname = hostname_str.toUTF8(bun.default_allocator); defer hostname.deinit(); - if (tls_object.isEmptyOrUndefinedOrNull()) { - const ctx = vm.rareData().postgresql_context.tcp orelse brk: { - var err: uws.create_bun_socket_error_t = .none; - const ctx_ = uws.us_create_bun_socket_context(0, vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), uws.us_bun_socket_context_options_t{}, &err).?; - uws.NewSocketHandler(false).configure(ctx_, true, *PostgresSQLConnection, SocketHandler(false)); - vm.rareData().postgresql_context.tcp = ctx_; - break :brk ctx_; - }; - ptr.socket = .{ - // TODO: investigate if allowHalfOpen: true is necessary here or if brings some advantage - .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| { - globalObject.throwError(err, "failed to connect to postgresql"); - ptr.deinit(); - return .zero; - }, - }; - } else { - // TODO: - globalObject.throwTODO("TLS is not supported yet"); - ptr.deinit(); - return .zero; - } + + const ctx = vm.rareData().postgresql_context.tcp orelse brk: { + var err: uws.create_bun_socket_error_t = .none; + const ctx_ = uws.us_create_bun_socket_context(0, vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), uws.us_bun_socket_context_options_t{}, &err).?; + uws.NewSocketHandler(false).configure(ctx_, true, *PostgresSQLConnection, SocketHandler(false)); + vm.rareData().postgresql_context.tcp = ctx_; + break :brk ctx_; + }; + ptr.socket = .{ + .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| { + globalObject.throwError(err, "failed to connect to postgresql"); + tls_config.deinit(); + if (tls_ctx) |tls| { + tls.deinit(true); + } + ptr.deinit(); + return .zero; + }, + }; } return js_value; @@ -1341,6 +1549,12 @@ pub const PostgresSQLConnection = struct { this.onOpen(_socket(socket)); } + fn onHandshake_(this: *PostgresSQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + this.onHandshake(success, ssl_error); + } + + pub const onHandshake = if (ssl) onHandshake_ else null; + pub fn onClose(this: *PostgresSQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void { _ = socket; this.onClose(); @@ -1421,6 +1635,7 @@ pub const PostgresSQLConnection = struct { this.on_connect.deinit(); this.backend_parameters.deinit(); bun.default_allocator.free(this.options_buf); + this.tls_config.deinit(); bun.default_allocator.destroy(this); } @@ -2147,6 +2362,18 @@ pub const PostgresSQLConnection = struct { this.fail("Unknown authentication method", error.UNKNOWN_AUTHENTICATION_METHOD); }, + .ClearTextPassword => { + debug("ClearTextPassword", .{}); + var response = protocol.PasswordMessage{ + .password = .{ + .temporary = this.password, + }, + }; + + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + this.flushData(); + }, + else => { debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))}); }, diff --git a/test/js/sql/tls-sql.test.ts b/test/js/sql/tls-sql.test.ts new file mode 100644 index 0000000000..2bc99bd3ad --- /dev/null +++ b/test/js/sql/tls-sql.test.ts @@ -0,0 +1,23 @@ +import { test, expect } from "bun:test"; +import { getSecret } from "harness"; +import { sql as SQL } from "bun"; + +const TLS_POSTGRES_DATABASE_URL = getSecret("TLS_POSTGRES_DATABASE_URL"); + +test("tls (explicit)", async () => { + const sql = new SQL({ + url: TLS_POSTGRES_DATABASE_URL!, + tls: true, + adapter: "postgresql", + }); + + const [{ one, two }] = await sql`SELECT 1 as one, '2' as two`; + expect(one).toBe(1); + expect(two).toBe("2"); +}); + +test("tls (implicit)", async () => { + const [{ one, two }] = await SQL`SELECT 1 as one, '2' as two`; + expect(one).toBe(1); + expect(two).toBe("2"); +});