[Bun.sql] Support TLS (#15217)

Co-authored-by: Ciro Spaciari <ciro.spaciari@gmail.com>
This commit is contained in:
Jarred Sumner
2024-11-18 19:38:23 -08:00
committed by GitHub
parent 8a0666acd1
commit adaee07138
8 changed files with 440 additions and 75 deletions

View File

@@ -12,6 +12,15 @@ pub const short = u16;
pub const PostgresShort = u16;
const Crypto = JSC.API.Bun.Crypto;
const JSValue = JSC.JSValue;
const BoringSSL = @import("../boringssl.zig");
pub const SSLMode = enum(u8) {
disable = 0,
prefer = 1,
require = 2,
verify_ca = 3,
verify_full = 4,
};
pub const Data = union(enum) {
owned: bun.ByteList,
@@ -837,7 +846,16 @@ pub const PostgresRequest = struct {
switch (try reader.int(u8)) {
'D' => try connection.on(.DataRow, Context, reader),
'd' => try connection.on(.CopyData, Context, reader),
'S' => try connection.on(.ParameterStatus, Context, reader),
'S' => {
if (connection.tls_status == .message_sent) {
bun.debugAssert(connection.tls_status.message_sent == 8);
connection.tls_status = .ssl_ok;
connection.setupTLS();
return;
}
try connection.on(.ParameterStatus, Context, reader);
},
'Z' => try connection.on(.ReadyForQuery, Context, reader),
'C' => try connection.on(.CommandComplete, Context, reader),
'2' => try connection.on(.BindComplete, Context, reader),
@@ -851,7 +869,19 @@ pub const PostgresRequest = struct {
's' => try connection.on(.PortalSuspended, Context, reader),
'3' => try connection.on(.CloseComplete, Context, reader),
'G' => try connection.on(.CopyInResponse, Context, reader),
'N' => try connection.on(.NoticeResponse, Context, reader),
'N' => {
if (connection.tls_status == .message_sent) {
connection.tls_status = .ssl_not_available;
debug("Server does not support SSL", .{});
if (connection.ssl_mode == .require) {
connection.fail("Server does not support SSL", error.SSLNotAvailable);
return;
}
continue;
}
try connection.on(.NoticeResponse, Context, reader);
},
'I' => try connection.on(.EmptyQueryResponse, Context, reader),
'H' => try connection.on(.CopyOutResponse, Context, reader),
'c' => try connection.on(.CopyDone, Context, reader),
@@ -904,6 +934,23 @@ pub const PostgresSQLConnection = struct {
authentication_state: AuthenticationState = .{ .pending = {} },
tls_ctx: ?*uws.SocketContext = null,
tls_config: JSC.API.ServerConfig.SSLConfig = .{},
tls_status: TLSStatus = .none,
ssl_mode: SSLMode = .disable,
pub const TLSStatus = union(enum) {
none,
pending,
/// Number of bytes sent of the 8-byte SSL request message.
/// Since we may send a partial message, we need to know how many bytes were sent.
message_sent: u8,
ssl_not_available,
ssl_ok,
};
pub const AuthenticationState = union(enum) {
pending: void,
SASL: SASL,
@@ -1005,12 +1052,41 @@ pub const PostgresSQLConnection = struct {
pub const Status = enum {
disconnected,
connecting,
// Prevent sending the startup message multiple times.
// Particularly relevant for TLS connections.
sent_startup_message,
connected,
failed,
};
pub usingnamespace JSC.Codegen.JSPostgresSQLConnection;
pub fn setupTLS(this: *PostgresSQLConnection) void {
debug("setupTLS", .{});
const new_socket = uws.us_socket_upgrade_to_tls(this.socket.SocketTCP.socket.connected, this.tls_ctx.?, this.tls_config.server_name) orelse {
this.fail("Failed to upgrade to TLS", error.TLSUpgradeFailed);
return;
};
this.socket = .{
.SocketTLS = .{
.socket = .{
.connected = new_socket,
},
},
};
this.start();
}
fn start(this: *PostgresSQLConnection) void {
this.sendStartupMessage();
const event_loop = this.globalObject.bunVM().eventLoop();
event_loop.enter();
defer event_loop.exit();
this.flushData();
}
pub fn hasPendingActivity(this: *PostgresSQLConnection) bool {
@fence(.acquire);
return this.pending_activity_count.load(.acquire) > 0;
@@ -1059,48 +1135,110 @@ pub const PostgresSQLConnection = struct {
}
}
pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: anyerror) void {
pub fn failWithJSValue(this: *PostgresSQLConnection, value: JSValue) void {
defer this.updateHasPendingActivity();
if (this.status == .failed) return;
debug("failed: {s}: {s}", .{ message, @errorName(err) });
this.status = .failed;
if (!this.socket.isClosed()) this.socket.close();
const on_close = this.on_close.swap();
if (on_close == .zero) return;
const instance = this.globalObject.createErrorInstance("{s}", .{message});
instance.put(this.globalObject, JSC.ZigString.static("code"), String.init(@errorName(err)).toJS(this.globalObject));
_ = on_close.call(
this.globalObject,
this.js_value,
&[_]JSValue{
instance,
value,
},
) catch |e| this.globalObject.reportActiveExceptionAsUnhandled(e);
}
pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: anyerror) void {
debug("failed: {s}: {s}", .{ message, @errorName(err) });
const instance = this.globalObject.createErrorInstance("{s}", .{message});
instance.put(this.globalObject, JSC.ZigString.static("code"), String.init(@errorName(err)).toJS(this.globalObject));
this.failWithJSValue(instance);
}
pub fn onClose(this: *PostgresSQLConnection) void {
var vm = this.globalObject.bunVM();
defer vm.drainMicrotasks();
this.fail("Connection closed", error.ConnectionClosed);
}
fn sendStartupMessage(this: *PostgresSQLConnection) void {
if (this.status != .connecting) return;
debug("sendStartupMessage", .{});
this.status = .sent_startup_message;
var msg = protocol.StartupMessage{
.user = Data{ .temporary = this.user },
.database = Data{ .temporary = this.database },
.options = Data{ .temporary = this.options },
};
msg.writeInternal(Writer, this.writer()) catch |err| {
this.socket.close();
this.fail("Failed to write startup message", err);
};
}
fn startTLS(this: *PostgresSQLConnection, socket: uws.AnySocket) void {
debug("startTLS", .{});
const offset = switch (this.tls_status) {
.message_sent => |count| count,
else => 0,
};
const ssl_request = [_]u8{
0x00, 0x00, 0x00, 0x08, // Length
0x04, 0xD2, 0x16, 0x2F, // SSL request code
};
const written = socket.write(ssl_request[offset..], false);
if (written > 0) {
this.tls_status = .{
.message_sent = offset + @as(u8, @intCast(written)),
};
} else {
this.tls_status = .{
.message_sent = offset,
};
}
}
pub fn onOpen(this: *PostgresSQLConnection, socket: uws.AnySocket) void {
this.socket = socket;
this.poll_ref.ref(this.globalObject.bunVM());
this.updateHasPendingActivity();
var msg = protocol.StartupMessage{ .user = Data{ .temporary = this.user }, .database = Data{ .temporary = this.database }, .options = Data{ .temporary = this.options } };
msg.writeInternal(Writer, this.writer()) catch |err| {
socket.close();
this.fail("Failed to write startup message", err);
};
if (this.tls_status == .message_sent or this.tls_status == .pending) {
this.startTLS(socket);
return;
}
const event_loop = this.globalObject.bunVM().eventLoop();
event_loop.enter();
defer event_loop.exit();
this.flushData();
this.start();
}
pub fn onHandshake(this: *PostgresSQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no });
if (success != 1) {
this.failWithJSValue(ssl_error.toJS(this.globalObject));
return;
}
if (this.tls_config.reject_unauthorized == 1) {
if (ssl_error.error_no != 0) {
this.failWithJSValue(ssl_error.toJS(this.globalObject));
return;
}
const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle()));
if (BoringSSL.SSL_get_servername(ssl_ptr, 0)) |servername| {
const hostname = servername[0..bun.len(servername)];
if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) {
this.failWithJSValue(ssl_error.toJS(this.globalObject));
}
}
}
}
pub fn onTimeout(this: *PostgresSQLConnection) void {
@@ -1109,6 +1247,16 @@ pub const PostgresSQLConnection = struct {
}
pub fn onDrain(this: *PostgresSQLConnection) void {
// Don't send any other messages while we're waiting for TLS.
if (this.tls_status == .message_sent) {
if (this.tls_status.message_sent < 8) {
this.startTLS(this.socket);
}
return;
}
const event_loop = this.globalObject.bunVM().eventLoop();
event_loop.enter();
defer event_loop.exit();
@@ -1226,7 +1374,7 @@ pub const PostgresSQLConnection = struct {
pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue {
var vm = globalObject.bunVM();
const arguments = callframe.arguments(9).slice();
const arguments = callframe.arguments(10).slice();
const hostname_str = arguments[0].toBunString(globalObject);
defer hostname_str.deref();
const port = arguments[1].coerce(i32, globalObject);
@@ -1237,13 +1385,67 @@ pub const PostgresSQLConnection = struct {
defer password_str.deref();
const database_str = arguments[4].toBunString(globalObject);
defer database_str.deref();
const tls_object = arguments[5];
const ssl_mode: SSLMode = switch (arguments[5].toInt32()) {
0 => .disable,
1 => .prefer,
2 => .require,
3 => .verify_ca,
4 => .verify_full,
else => .disable,
};
const tls_object = arguments[6];
var tls_config: JSC.API.ServerConfig.SSLConfig = .{};
var tls_ctx: ?*uws.SocketContext = null;
if (ssl_mode != .disable) {
tls_config = if (tls_object.isBoolean() and tls_object.toBoolean())
.{}
else if (tls_object.isObject())
(JSC.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{}
else {
globalObject.throwInvalidArguments("tls must be a boolean or an object", .{});
return .zero;
};
if (globalObject.hasException()) {
tls_config.deinit();
return .zero;
}
if (tls_config.reject_unauthorized != 0)
tls_config.request_cert = 1;
// We create it right here so we can throw errors early.
const context_options = tls_config.asUSockets();
var err: uws.create_bun_socket_error_t = .none;
tls_ctx = uws.us_create_bun_socket_context(1, vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), context_options, &err) orelse {
if (err != .none) {
globalObject.throw("failed to create TLS context", .{});
} else {
globalObject.throwValue(err.toJS(globalObject));
}
return .zero;
};
if (err != .none) {
tls_config.deinit();
globalObject.throwValue(err.toJS(globalObject));
if (tls_ctx) |ctx| {
ctx.deinit(true);
}
return .zero;
}
uws.NewSocketHandler(true).configure(tls_ctx.?, true, *PostgresSQLConnection, SocketHandler(true));
}
var username: []const u8 = "";
var password: []const u8 = "";
var database: []const u8 = "";
var options: []const u8 = "";
const options_str = arguments[6].toBunString(globalObject);
const options_str = arguments[7].toBunString(globalObject);
defer options_str.deref();
const options_buf: []u8 = brk: {
@@ -1270,10 +1472,15 @@ pub const PostgresSQLConnection = struct {
break :brk b.allocatedSlice();
};
const on_connect = arguments[7];
const on_close = arguments[8];
const on_connect = arguments[8];
const on_close = arguments[9];
var ptr = bun.default_allocator.create(PostgresSQLConnection) catch |err| {
globalObject.throwError(err, "failed to allocate connection");
tls_config.deinit();
if (tls_ctx) |ctx| {
ctx.deinit(true);
}
return .zero;
};
@@ -1289,6 +1496,10 @@ pub const PostgresSQLConnection = struct {
.socket = undefined,
.requests = PostgresRequest.Queue.init(bun.default_allocator),
.statements = PreparedStatementsMap{},
.tls_config = tls_config,
.tls_ctx = tls_ctx,
.ssl_mode = ssl_mode,
.tls_status = if (ssl_mode != .disable) .pending else .none,
};
ptr.updateHasPendingActivity();
@@ -1300,28 +1511,25 @@ pub const PostgresSQLConnection = struct {
{
const hostname = hostname_str.toUTF8(bun.default_allocator);
defer hostname.deinit();
if (tls_object.isEmptyOrUndefinedOrNull()) {
const ctx = vm.rareData().postgresql_context.tcp orelse brk: {
var err: uws.create_bun_socket_error_t = .none;
const ctx_ = uws.us_create_bun_socket_context(0, vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), uws.us_bun_socket_context_options_t{}, &err).?;
uws.NewSocketHandler(false).configure(ctx_, true, *PostgresSQLConnection, SocketHandler(false));
vm.rareData().postgresql_context.tcp = ctx_;
break :brk ctx_;
};
ptr.socket = .{
// TODO: investigate if allowHalfOpen: true is necessary here or if brings some advantage
.SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| {
globalObject.throwError(err, "failed to connect to postgresql");
ptr.deinit();
return .zero;
},
};
} else {
// TODO:
globalObject.throwTODO("TLS is not supported yet");
ptr.deinit();
return .zero;
}
const ctx = vm.rareData().postgresql_context.tcp orelse brk: {
var err: uws.create_bun_socket_error_t = .none;
const ctx_ = uws.us_create_bun_socket_context(0, vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), uws.us_bun_socket_context_options_t{}, &err).?;
uws.NewSocketHandler(false).configure(ctx_, true, *PostgresSQLConnection, SocketHandler(false));
vm.rareData().postgresql_context.tcp = ctx_;
break :brk ctx_;
};
ptr.socket = .{
.SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| {
globalObject.throwError(err, "failed to connect to postgresql");
tls_config.deinit();
if (tls_ctx) |tls| {
tls.deinit(true);
}
ptr.deinit();
return .zero;
},
};
}
return js_value;
@@ -1341,6 +1549,12 @@ pub const PostgresSQLConnection = struct {
this.onOpen(_socket(socket));
}
fn onHandshake_(this: *PostgresSQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
this.onHandshake(success, ssl_error);
}
pub const onHandshake = if (ssl) onHandshake_ else null;
pub fn onClose(this: *PostgresSQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void {
_ = socket;
this.onClose();
@@ -1421,6 +1635,7 @@ pub const PostgresSQLConnection = struct {
this.on_connect.deinit();
this.backend_parameters.deinit();
bun.default_allocator.free(this.options_buf);
this.tls_config.deinit();
bun.default_allocator.destroy(this);
}
@@ -2147,6 +2362,18 @@ pub const PostgresSQLConnection = struct {
this.fail("Unknown authentication method", error.UNKNOWN_AUTHENTICATION_METHOD);
},
.ClearTextPassword => {
debug("ClearTextPassword", .{});
var response = protocol.PasswordMessage{
.password = .{
.temporary = this.password,
},
};
try response.writeInternal(PostgresSQLConnection.Writer, this.writer());
this.flushData();
},
else => {
debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))});
},