diff --git a/cmake/sources/ZigSources.txt b/cmake/sources/ZigSources.txt index aa592e5957..08d07b7249 100644 --- a/cmake/sources/ZigSources.txt +++ b/cmake/sources/ZigSources.txt @@ -732,10 +732,77 @@ src/sourcemap/CodeCoverage.zig src/sourcemap/LineOffsetTable.zig src/sourcemap/sourcemap.zig src/sourcemap/VLQ.zig -src/sql/DataCell.zig src/sql/postgres.zig -src/sql/postgres/postgres_protocol.zig -src/sql/postgres/postgres_types.zig +src/sql/postgres/AnyPostgresError.zig +src/sql/postgres/AuthenticationState.zig +src/sql/postgres/CommandTag.zig +src/sql/postgres/ConnectionFlags.zig +src/sql/postgres/Data.zig +src/sql/postgres/DataCell.zig +src/sql/postgres/DebugSocketMonitorReader.zig +src/sql/postgres/DebugSocketMonitorWriter.zig +src/sql/postgres/ObjectIterator.zig +src/sql/postgres/PostgresCachedStructure.zig +src/sql/postgres/PostgresProtocol.zig +src/sql/postgres/PostgresRequest.zig +src/sql/postgres/PostgresSQLConnection.zig +src/sql/postgres/PostgresSQLContext.zig +src/sql/postgres/PostgresSQLQuery.zig +src/sql/postgres/PostgresSQLQueryResultMode.zig +src/sql/postgres/PostgresSQLStatement.zig +src/sql/postgres/PostgresTypes.zig +src/sql/postgres/protocol/ArrayList.zig +src/sql/postgres/protocol/Authentication.zig +src/sql/postgres/protocol/BackendKeyData.zig +src/sql/postgres/protocol/Close.zig +src/sql/postgres/protocol/ColumnIdentifier.zig +src/sql/postgres/protocol/CommandComplete.zig +src/sql/postgres/protocol/CopyData.zig +src/sql/postgres/protocol/CopyFail.zig +src/sql/postgres/protocol/CopyInResponse.zig +src/sql/postgres/protocol/CopyOutResponse.zig +src/sql/postgres/protocol/DataRow.zig +src/sql/postgres/protocol/DecoderWrap.zig +src/sql/postgres/protocol/Describe.zig +src/sql/postgres/protocol/ErrorResponse.zig +src/sql/postgres/protocol/Execute.zig +src/sql/postgres/protocol/FieldDescription.zig +src/sql/postgres/protocol/FieldMessage.zig +src/sql/postgres/protocol/FieldType.zig +src/sql/postgres/protocol/NegotiateProtocolVersion.zig 
+src/sql/postgres/protocol/NewReader.zig +src/sql/postgres/protocol/NewWriter.zig +src/sql/postgres/protocol/NoticeResponse.zig +src/sql/postgres/protocol/NotificationResponse.zig +src/sql/postgres/protocol/ParameterDescription.zig +src/sql/postgres/protocol/ParameterStatus.zig +src/sql/postgres/protocol/Parse.zig +src/sql/postgres/protocol/PasswordMessage.zig +src/sql/postgres/protocol/PortalOrPreparedStatement.zig +src/sql/postgres/protocol/ReadyForQuery.zig +src/sql/postgres/protocol/RowDescription.zig +src/sql/postgres/protocol/SASLInitialResponse.zig +src/sql/postgres/protocol/SASLResponse.zig +src/sql/postgres/protocol/StackReader.zig +src/sql/postgres/protocol/StartupMessage.zig +src/sql/postgres/protocol/TransactionStatusIndicator.zig +src/sql/postgres/protocol/WriteWrap.zig +src/sql/postgres/protocol/zHelpers.zig +src/sql/postgres/QueryBindingIterator.zig +src/sql/postgres/SASL.zig +src/sql/postgres/Signature.zig +src/sql/postgres/SocketMonitor.zig +src/sql/postgres/SSLMode.zig +src/sql/postgres/Status.zig +src/sql/postgres/TLSStatus.zig +src/sql/postgres/types/bool.zig +src/sql/postgres/types/bytea.zig +src/sql/postgres/types/date.zig +src/sql/postgres/types/int_types.zig +src/sql/postgres/types/json.zig +src/sql/postgres/types/numeric.zig +src/sql/postgres/types/PostgresString.zig +src/sql/postgres/types/Tag.zig src/StandaloneModuleGraph.zig src/StaticHashMap.zig src/string_immutable.zig diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 18877fdfc3..728eeba420 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -1,3273 +1,3 @@ -const bun = @import("bun"); -const JSC = bun.JSC; -const String = bun.String; -const uws = bun.uws; -const std = @import("std"); -pub const debug = bun.Output.scoped(.Postgres, false); -pub const int4 = u32; -pub const PostgresInt32 = int4; -pub const int8 = i64; -pub const PostgresInt64 = int8; -pub const short = u16; -pub const PostgresShort = u16; -const Crypto = JSC.API.Bun.Crypto; -const JSValue = 
JSC.JSValue; -const BoringSSL = bun.BoringSSL; -pub const AnyPostgresError = error{ - ConnectionClosed, - ExpectedRequest, - ExpectedStatement, - InvalidBackendKeyData, - InvalidBinaryData, - InvalidByteSequence, - InvalidByteSequenceForEncoding, - InvalidCharacter, - InvalidMessage, - InvalidMessageLength, - InvalidQueryBinding, - InvalidServerKey, - InvalidServerSignature, - JSError, - MultidimensionalArrayNotSupportedYet, - NullsInArrayNotSupportedYet, - OutOfMemory, - Overflow, - PBKDFD2, - SASL_SIGNATURE_MISMATCH, - SASL_SIGNATURE_INVALID_BASE64, - ShortRead, - TLSNotAvailable, - TLSUpgradeFailed, - UnexpectedMessage, - UNKNOWN_AUTHENTICATION_METHOD, - UNSUPPORTED_AUTHENTICATION_METHOD, - UnsupportedByteaFormat, - UnsupportedIntegerSize, - UnsupportedArrayFormat, - UnsupportedNumericFormat, - UnknownFormatCode, -}; - -pub fn postgresErrorToJS(globalObject: *JSC.JSGlobalObject, message: ?[]const u8, err: AnyPostgresError) JSValue { - const error_code: JSC.Error = switch (err) { - error.ConnectionClosed => .POSTGRES_CONNECTION_CLOSED, - error.ExpectedRequest => .POSTGRES_EXPECTED_REQUEST, - error.ExpectedStatement => .POSTGRES_EXPECTED_STATEMENT, - error.InvalidBackendKeyData => .POSTGRES_INVALID_BACKEND_KEY_DATA, - error.InvalidBinaryData => .POSTGRES_INVALID_BINARY_DATA, - error.InvalidByteSequence => .POSTGRES_INVALID_BYTE_SEQUENCE, - error.InvalidByteSequenceForEncoding => .POSTGRES_INVALID_BYTE_SEQUENCE_FOR_ENCODING, - error.InvalidCharacter => .POSTGRES_INVALID_CHARACTER, - error.InvalidMessage => .POSTGRES_INVALID_MESSAGE, - error.InvalidMessageLength => .POSTGRES_INVALID_MESSAGE_LENGTH, - error.InvalidQueryBinding => .POSTGRES_INVALID_QUERY_BINDING, - error.InvalidServerKey => .POSTGRES_INVALID_SERVER_KEY, - error.InvalidServerSignature => .POSTGRES_INVALID_SERVER_SIGNATURE, - error.MultidimensionalArrayNotSupportedYet => .POSTGRES_MULTIDIMENSIONAL_ARRAY_NOT_SUPPORTED_YET, - error.NullsInArrayNotSupportedYet => .POSTGRES_NULLS_IN_ARRAY_NOT_SUPPORTED_YET, 
- error.Overflow => .POSTGRES_OVERFLOW, - error.PBKDFD2 => .POSTGRES_AUTHENTICATION_FAILED_PBKDF2, - error.SASL_SIGNATURE_MISMATCH => .POSTGRES_SASL_SIGNATURE_MISMATCH, - error.SASL_SIGNATURE_INVALID_BASE64 => .POSTGRES_SASL_SIGNATURE_INVALID_BASE64, - error.TLSNotAvailable => .POSTGRES_TLS_NOT_AVAILABLE, - error.TLSUpgradeFailed => .POSTGRES_TLS_UPGRADE_FAILED, - error.UnexpectedMessage => .POSTGRES_UNEXPECTED_MESSAGE, - error.UNKNOWN_AUTHENTICATION_METHOD => .POSTGRES_UNKNOWN_AUTHENTICATION_METHOD, - error.UNSUPPORTED_AUTHENTICATION_METHOD => .POSTGRES_UNSUPPORTED_AUTHENTICATION_METHOD, - error.UnsupportedByteaFormat => .POSTGRES_UNSUPPORTED_BYTEA_FORMAT, - error.UnsupportedArrayFormat => .POSTGRES_UNSUPPORTED_ARRAY_FORMAT, - error.UnsupportedIntegerSize => .POSTGRES_UNSUPPORTED_INTEGER_SIZE, - error.UnsupportedNumericFormat => .POSTGRES_UNSUPPORTED_NUMERIC_FORMAT, - error.UnknownFormatCode => .POSTGRES_UNKNOWN_FORMAT_CODE, - error.JSError => { - return globalObject.takeException(error.JSError); - }, - error.OutOfMemory => { - // TODO: add binding for creating an out of memory error? 
- return globalObject.takeException(globalObject.throwOutOfMemory()); - }, - error.ShortRead => { - bun.unreachablePanic("Assertion failed: ShortRead should be handled by the caller in postgres", .{}); - }, - }; - if (message) |msg| { - return error_code.fmt(globalObject, "{s}", .{msg}); - } - return error_code.fmt(globalObject, "Failed to bind query: {s}", .{@errorName(err)}); -} - -pub const SSLMode = enum(u8) { - disable = 0, - prefer = 1, - require = 2, - verify_ca = 3, - verify_full = 4, -}; - -pub const Data = union(enum) { - owned: bun.ByteList, - temporary: []const u8, - empty: void, - - pub const Empty: Data = .{ .empty = {} }; - - pub fn toOwned(this: @This()) !bun.ByteList { - return switch (this) { - .owned => this.owned, - .temporary => bun.ByteList.init(try bun.default_allocator.dupe(u8, this.temporary)), - .empty => bun.ByteList.init(&.{}), - }; - } - - pub fn deinit(this: *@This()) void { - switch (this.*) { - .owned => this.owned.deinitWithAllocator(bun.default_allocator), - .temporary => {}, - .empty => {}, - } - } - - /// Zero bytes before deinit - /// Generally, for security reasons. 
- pub fn zdeinit(this: *@This()) void { - switch (this.*) { - .owned => { - - // Zero bytes before deinit - @memset(this.owned.slice(), 0); - - this.owned.deinitWithAllocator(bun.default_allocator); - }, - .temporary => {}, - .empty => {}, - } - } - - pub fn slice(this: @This()) []const u8 { - return switch (this) { - .owned => this.owned.slice(), - .temporary => this.temporary, - .empty => "", - }; - } - - pub fn substring(this: @This(), start_index: usize, end_index: usize) Data { - return switch (this) { - .owned => .{ .temporary = this.owned.slice()[start_index..end_index] }, - .temporary => .{ .temporary = this.temporary[start_index..end_index] }, - .empty => .{ .empty = {} }, - }; - } - - pub fn sliceZ(this: @This()) [:0]const u8 { - return switch (this) { - .owned => this.owned.slice()[0..this.owned.len :0], - .temporary => this.temporary[0..this.temporary.len :0], - .empty => "", - }; - } -}; -pub const protocol = @import("./postgres/postgres_protocol.zig"); -pub const types = @import("./postgres/postgres_types.zig"); - -const Socket = uws.AnySocket; -const PreparedStatementsMap = std.HashMapUnmanaged(u64, *PostgresSQLStatement, bun.IdentityContext(u64), 80); - -const SocketMonitor = struct { - const DebugSocketMonitorWriter = struct { - var file: std.fs.File = undefined; - var enabled = false; - var check = std.once(load); - pub fn write(data: []const u8) void { - file.writeAll(data) catch {}; - } - - fn load() void { - if (bun.getenvZAnyCase("BUN_POSTGRES_SOCKET_MONITOR")) |monitor| { - enabled = true; - file = std.fs.cwd().createFile(monitor, .{ .truncate = true }) catch { - enabled = false; - return; - }; - debug("writing to {s}", .{monitor}); - } - } - }; - - const DebugSocketMonitorReader = struct { - var file: std.fs.File = undefined; - var enabled = false; - var check = std.once(load); - - fn load() void { - if (bun.getenvZAnyCase("BUN_POSTGRES_SOCKET_MONITOR_READER")) |monitor| { - enabled = true; - file = std.fs.cwd().createFile(monitor, .{ 
.truncate = true }) catch { - enabled = false; - return; - }; - debug("duplicating reads to {s}", .{monitor}); - } - } - - pub fn write(data: []const u8) void { - file.writeAll(data) catch {}; - } - }; - - pub fn write(data: []const u8) void { - if (comptime bun.Environment.isDebug) { - DebugSocketMonitorWriter.check.call(); - if (DebugSocketMonitorWriter.enabled) { - DebugSocketMonitorWriter.write(data); - } - } - } - - pub fn read(data: []const u8) void { - if (comptime bun.Environment.isDebug) { - DebugSocketMonitorReader.check.call(); - if (DebugSocketMonitorReader.enabled) { - DebugSocketMonitorReader.write(data); - } - } - } -}; - -pub const PostgresSQLContext = struct { - tcp: ?*uws.SocketContext = null, - - onQueryResolveFn: JSC.Strong.Optional = .empty, - onQueryRejectFn: JSC.Strong.Optional = .empty, - - pub fn init(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { - var ctx = &globalObject.bunVM().rareData().postgresql_context; - ctx.onQueryResolveFn.set(globalObject, callframe.argument(0)); - ctx.onQueryRejectFn.set(globalObject, callframe.argument(1)); - - return .js_undefined; - } - - comptime { - const js_init = JSC.toJSHostFn(init); - @export(&js_init, .{ .name = "PostgresSQLContext__init" }); - } -}; -pub const PostgresSQLQueryResultMode = enum(u2) { - objects = 0, - values = 1, - raw = 2, -}; - -const JSRef = JSC.JSRef; - -pub const PostgresSQLQuery = struct { - statement: ?*PostgresSQLStatement = null, - query: bun.String = bun.String.empty, - cursor_name: bun.String = bun.String.empty, - - thisValue: JSRef = JSRef.empty(), - - status: Status = Status.pending, - - ref_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(1), - - flags: packed struct(u8) { - is_done: bool = false, - binary: bool = false, - bigint: bool = false, - simple: bool = false, - result_mode: PostgresSQLQueryResultMode = .objects, - _padding: u2 = 0, - } = .{}, - - pub const js = JSC.Codegen.JSPostgresSQLQuery; - pub const toJS = 
js.toJS; - pub const fromJS = js.fromJS; - pub const fromJSDirect = js.fromJSDirect; - - pub fn getTarget(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, clean_target: bool) JSC.JSValue { - const thisValue = this.thisValue.get(); - if (thisValue == .zero) { - return .zero; - } - const target = js.targetGetCached(thisValue) orelse return .zero; - if (clean_target) { - js.targetSetCached(thisValue, globalObject, .zero); - } - return target; - } - - pub const Status = enum(u8) { - /// The query was just enqueued, statement status can be checked for more details - pending, - /// The query is being bound to the statement - binding, - /// The query is running - running, - /// The query is waiting for a partial response - partial_response, - /// The query was successful - success, - /// The query failed - fail, - - pub fn isRunning(this: Status) bool { - return @intFromEnum(this) > @intFromEnum(Status.pending) and @intFromEnum(this) < @intFromEnum(Status.success); - } - }; - - pub fn hasPendingActivity(this: *@This()) bool { - return this.ref_count.load(.monotonic) > 1; - } - - pub fn deinit(this: *@This()) void { - this.thisValue.deinit(); - if (this.statement) |statement| { - statement.deref(); - } - this.query.deref(); - this.cursor_name.deref(); - bun.default_allocator.destroy(this); - } - - pub fn finalize(this: *@This()) void { - debug("PostgresSQLQuery finalize", .{}); - if (this.thisValue == .weak) { - // clean up if is a weak reference, if is a strong reference we need to wait until the query is done - // if we are a strong reference, here is probably a bug because GC'd should not happen - this.thisValue.weak = .zero; - } - this.deref(); - } - - pub fn deref(this: *@This()) void { - const ref_count = this.ref_count.fetchSub(1, .monotonic); - - if (ref_count == 1) { - this.deinit(); - } - } - - pub fn ref(this: *@This()) void { - bun.assert(this.ref_count.fetchAdd(1, .monotonic) > 0); - } - - pub fn onWriteFail( - this: *@This(), - err: 
AnyPostgresError, - globalObject: *JSC.JSGlobalObject, - queries_array: JSValue, - ) void { - this.status = .fail; - const thisValue = this.thisValue.get(); - defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject, true); - if (thisValue == .zero or targetValue == .zero) { - return; - } - - const vm = JSC.VirtualMachine.get(); - const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; - const event_loop = vm.eventLoop(); - event_loop.runCallback(function, globalObject, thisValue, &.{ - targetValue, - postgresErrorToJS(globalObject, null, err), - queries_array, - }); - } - pub fn onJSError(this: *@This(), err: JSC.JSValue, globalObject: *JSC.JSGlobalObject) void { - this.status = .fail; - this.ref(); - defer this.deref(); - - const thisValue = this.thisValue.get(); - defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject, true); - if (thisValue == .zero or targetValue == .zero) { - return; - } - - var vm = JSC.VirtualMachine.get(); - const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; - const event_loop = vm.eventLoop(); - event_loop.runCallback(function, globalObject, thisValue, &.{ - targetValue, - err, - }); - } - pub fn onError(this: *@This(), err: PostgresSQLStatement.Error, globalObject: *JSC.JSGlobalObject) void { - this.onJSError(err.toJS(globalObject), globalObject); - } - - const CommandTag = union(enum) { - // For an INSERT command, the tag is INSERT oid rows, where rows is the - // number of rows inserted. oid used to be the object ID of the inserted - // row if rows was 1 and the target table had OIDs, but OIDs system - // columns are not supported anymore; therefore oid is always 0. - INSERT: u64, - // For a DELETE command, the tag is DELETE rows where rows is the number - // of rows deleted. - DELETE: u64, - // For an UPDATE command, the tag is UPDATE rows where rows is the - // number of rows updated. 
- UPDATE: u64, - // For a MERGE command, the tag is MERGE rows where rows is the number - // of rows inserted, updated, or deleted. - MERGE: u64, - // For a SELECT or CREATE TABLE AS command, the tag is SELECT rows where - // rows is the number of rows retrieved. - SELECT: u64, - // For a MOVE command, the tag is MOVE rows where rows is the number of - // rows the cursor's position has been changed by. - MOVE: u64, - // For a FETCH command, the tag is FETCH rows where rows is the number - // of rows that have been retrieved from the cursor. - FETCH: u64, - // For a COPY command, the tag is COPY rows where rows is the number of - // rows copied. (Note: the row count appears only in PostgreSQL 8.2 and - // later.) - COPY: u64, - - other: []const u8, - - pub fn toJSTag(this: CommandTag, globalObject: *JSC.JSGlobalObject) JSValue { - return switch (this) { - .INSERT => JSValue.jsNumber(1), - .DELETE => JSValue.jsNumber(2), - .UPDATE => JSValue.jsNumber(3), - .MERGE => JSValue.jsNumber(4), - .SELECT => JSValue.jsNumber(5), - .MOVE => JSValue.jsNumber(6), - .FETCH => JSValue.jsNumber(7), - .COPY => JSValue.jsNumber(8), - .other => |tag| JSC.ZigString.init(tag).toJS(globalObject), - }; - } - - pub fn toJSNumber(this: CommandTag) JSValue { - return switch (this) { - .other => JSValue.jsNumber(0), - inline else => |val| JSValue.jsNumber(val), - }; - } - - const KnownCommand = enum { - INSERT, - DELETE, - UPDATE, - MERGE, - SELECT, - MOVE, - FETCH, - COPY, - - pub const Map = bun.ComptimeEnumMap(KnownCommand); - }; - - pub fn init(tag: []const u8) CommandTag { - const first_space_index = bun.strings.indexOfChar(tag, ' ') orelse return .{ .other = tag }; - const cmd = KnownCommand.Map.get(tag[0..first_space_index]) orelse return .{ - .other = tag, - }; - - const number = brk: { - switch (cmd) { - .INSERT => { - var remaining = tag[@min(first_space_index + 1, tag.len)..]; - const second_space = bun.strings.indexOfChar(remaining, ' ') orelse return .{ .other = tag }; - 
remaining = remaining[@min(second_space + 1, remaining.len)..]; - break :brk std.fmt.parseInt(u64, remaining, 0) catch |err| { - debug("CommandTag failed to parse number: {s}", .{@errorName(err)}); - return .{ .other = tag }; - }; - }, - else => { - const after_tag = tag[@min(first_space_index + 1, tag.len)..]; - break :brk std.fmt.parseInt(u64, after_tag, 0) catch |err| { - debug("CommandTag failed to parse number: {s}", .{@errorName(err)}); - return .{ .other = tag }; - }; - }, - } - }; - - switch (cmd) { - inline else => |t| return @unionInit(CommandTag, @tagName(t), number), - } - } - }; - - pub fn allowGC(thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) void { - if (thisValue == .zero) { - return; - } - - defer thisValue.ensureStillAlive(); - js.bindingSetCached(thisValue, globalObject, .zero); - js.pendingValueSetCached(thisValue, globalObject, .zero); - js.targetSetCached(thisValue, globalObject, .zero); - } - - fn consumePendingValue(thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) ?JSValue { - const pending_value = js.pendingValueGetCached(thisValue) orelse return null; - js.pendingValueSetCached(thisValue, globalObject, .zero); - return pending_value; - } - - pub fn onResult(this: *@This(), command_tag_str: []const u8, globalObject: *JSC.JSGlobalObject, connection: JSC.JSValue, is_last: bool) void { - this.ref(); - defer this.deref(); - - const thisValue = this.thisValue.get(); - const targetValue = this.getTarget(globalObject, is_last); - if (is_last) { - this.status = .success; - } else { - this.status = .partial_response; - } - defer if (is_last) { - allowGC(thisValue, globalObject); - this.thisValue.deinit(); - }; - if (thisValue == .zero or targetValue == .zero) { - return; - } - - const vm = JSC.VirtualMachine.get(); - const function = vm.rareData().postgresql_context.onQueryResolveFn.get().?; - const event_loop = vm.eventLoop(); - const tag = CommandTag.init(command_tag_str); - - event_loop.runCallback(function, globalObject, 
thisValue, &.{ - targetValue, - consumePendingValue(thisValue, globalObject) orelse .js_undefined, - tag.toJSTag(globalObject), - tag.toJSNumber(), - if (connection == .zero) .js_undefined else PostgresSQLConnection.js.queriesGetCached(connection) orelse .js_undefined, - JSValue.jsBoolean(is_last), - }); - } - - pub fn constructor(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*PostgresSQLQuery { - _ = callframe; - return globalThis.throw("PostgresSQLQuery cannot be constructed directly", .{}); - } - - pub fn estimatedSize(this: *PostgresSQLQuery) usize { - _ = this; - return @sizeOf(PostgresSQLQuery); - } - - pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { - const arguments = callframe.arguments_old(6).slice(); - var args = JSC.CallFrame.ArgumentsSlice.init(globalThis.bunVM(), arguments); - defer args.deinit(); - const query = args.nextEat() orelse { - return globalThis.throw("query must be a string", .{}); - }; - const values = args.nextEat() orelse { - return globalThis.throw("values must be an array", .{}); - }; - - if (!query.isString()) { - return globalThis.throw("query must be a string", .{}); - } - - if (values.jsType() != .Array) { - return globalThis.throw("values must be an array", .{}); - } - - const pending_value: JSValue = args.nextEat() orelse .js_undefined; - const columns: JSValue = args.nextEat() orelse .js_undefined; - const js_bigint: JSValue = args.nextEat() orelse .false; - const js_simple: JSValue = args.nextEat() orelse .false; - - const bigint = js_bigint.isBoolean() and js_bigint.asBoolean(); - const simple = js_simple.isBoolean() and js_simple.asBoolean(); - if (simple) { - if (try values.getLength(globalThis) > 0) { - return globalThis.throwInvalidArguments("simple query cannot have parameters", .{}); - } - if (try query.getLength(globalThis) >= std.math.maxInt(i32)) { - return globalThis.throwInvalidArguments("query is too long", .{}); - } - } - if 
(!pending_value.jsType().isArrayLike()) { - return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); - } - - var ptr = try bun.default_allocator.create(PostgresSQLQuery); - - const this_value = ptr.toJS(globalThis); - this_value.ensureStillAlive(); - - ptr.* = .{ - .query = try query.toBunString(globalThis), - .thisValue = JSRef.initWeak(this_value), - .flags = .{ - .bigint = bigint, - .simple = simple, - }, - }; - - js.bindingSetCached(this_value, globalThis, values); - js.pendingValueSetCached(this_value, globalThis, pending_value); - if (!columns.isUndefined()) { - js.columnsSetCached(this_value, globalThis, columns); - } - - return this_value; - } - - pub fn push(this: *PostgresSQLQuery, globalThis: *JSC.JSGlobalObject, value: JSValue) void { - var pending_value = this.pending_value.get() orelse return; - pending_value.push(globalThis, value); - } - - pub fn doDone(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { - _ = globalObject; - this.flags.is_done = true; - return .js_undefined; - } - pub fn setPendingValue(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { - const result = callframe.argument(0); - js.pendingValueSetCached(this.thisValue.get(), globalObject, result); - return .js_undefined; - } - pub fn setMode(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { - const js_mode = callframe.argument(0); - if (js_mode.isEmptyOrUndefinedOrNull() or !js_mode.isNumber()) { - return globalObject.throwInvalidArgumentType("setMode", "mode", "Number"); - } - - const mode = try js_mode.coerce(i32, globalObject); - this.flags.result_mode = std.meta.intToEnum(PostgresSQLQueryResultMode, mode) catch { - return globalObject.throwInvalidArgumentTypeValue("mode", "Number", js_mode); - }; - return .js_undefined; - } - - pub fn doRun(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, 
callframe: *JSC.CallFrame) bun.JSError!JSValue { - var arguments_ = callframe.arguments_old(2); - const arguments = arguments_.slice(); - const connection: *PostgresSQLConnection = arguments[0].as(PostgresSQLConnection) orelse { - return globalObject.throw("connection must be a PostgresSQLConnection", .{}); - }; - - connection.poll_ref.ref(globalObject.bunVM()); - var query = arguments[1]; - - if (!query.isObject()) { - return globalObject.throwInvalidArgumentType("run", "query", "Query"); - } - - const this_value = callframe.this(); - const binding_value = js.bindingGetCached(this_value) orelse .zero; - var query_str = this.query.toUTF8(bun.default_allocator); - defer query_str.deinit(); - var writer = connection.writer(); - - if (this.flags.simple) { - debug("executeQuery", .{}); - - const can_execute = !connection.hasQueryRunning(); - if (can_execute) { - PostgresRequest.executeQuery(query_str.slice(), PostgresSQLConnection.Writer, writer) catch |err| { - if (!globalObject.hasException()) - return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to execute query", err)); - return error.JSError; - }; - connection.flags.is_ready_for_query = false; - this.status = .running; - } else { - this.status = .pending; - } - const stmt = bun.default_allocator.create(PostgresSQLStatement) catch { - return globalObject.throwOutOfMemory(); - }; - // Query is simple and it's the only owner of the statement - stmt.* = .{ - .signature = Signature.empty(), - .ref_count = 1, - .status = .parsing, - }; - this.statement = stmt; - // We need a strong reference to the query so that it doesn't get GC'd - connection.requests.writeItem(this) catch return globalObject.throwOutOfMemory(); - this.ref(); - this.thisValue.upgrade(globalObject); - - js.targetSetCached(this_value, globalObject, query); - if (this.status == .running) { - connection.flushDataAndResetTimeout(); - } else { - connection.resetConnectionTimeout(); - } - return .js_undefined; - } - - const columns_value: 
JSValue = js.columnsGetCached(this_value) orelse .js_undefined; - - var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value, connection.prepared_statement_id, connection.flags.use_unnamed_prepared_statements) catch |err| { - if (!globalObject.hasException()) - return globalObject.throwError(err, "failed to generate signature"); - return error.JSError; - }; - - const has_params = signature.fields.len > 0; - var did_write = false; - enqueue: { - var connection_entry_value: ?**PostgresSQLStatement = null; - if (!connection.flags.use_unnamed_prepared_statements) { - const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { - signature.deinit(); - return globalObject.throwError(err, "failed to allocate statement"); - }; - connection_entry_value = entry.value_ptr; - if (entry.found_existing) { - this.statement = connection_entry_value.?.*; - this.statement.?.ref(); - signature.deinit(); - - switch (this.statement.?.status) { - .failed => { - // If the statement failed, we need to throw the error - return globalObject.throwValue(this.statement.?.error_response.?.toJS(globalObject)); - }, - .prepared => { - if (!connection.hasQueryRunning()) { - this.flags.binary = this.statement.?.fields.len > 0; - debug("bindAndExecute", .{}); - - // bindAndExecute will bind + execute, it will change to running after binding is complete - PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { - if (!globalObject.hasException()) - return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to bind and execute query", err)); - return error.JSError; - }; - connection.flags.is_ready_for_query = false; - this.status = .binding; - - did_write = true; - } - }, - .parsing, .pending => {}, - } - - break :enqueue; - } - } - const can_execute = !connection.hasQueryRunning(); - - if (can_execute) { - // If it 
does not have params, we can write and execute immediately in one go - if (!has_params) { - debug("prepareAndQueryWithSignature", .{}); - // prepareAndQueryWithSignature will write + bind + execute, it will change to running after binding is complete - PostgresRequest.prepareAndQueryWithSignature(globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, writer, &signature) catch |err| { - signature.deinit(); - if (!globalObject.hasException()) - return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to prepare and query", err)); - return error.JSError; - }; - connection.flags.is_ready_for_query = false; - this.status = .binding; - did_write = true; - } else { - debug("writeQuery", .{}); - - PostgresRequest.writeQuery(query_str.slice(), signature.prepared_statement_name, signature.fields, PostgresSQLConnection.Writer, writer) catch |err| { - signature.deinit(); - if (!globalObject.hasException()) - return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to write query", err)); - return error.JSError; - }; - writer.write(&protocol.Sync) catch |err| { - signature.deinit(); - if (!globalObject.hasException()) - return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to flush", err)); - return error.JSError; - }; - connection.flags.is_ready_for_query = false; - did_write = true; - } - } - { - const stmt = bun.default_allocator.create(PostgresSQLStatement) catch { - return globalObject.throwOutOfMemory(); - }; - // we only have connection_entry_value if we are using named prepared statements - if (connection_entry_value) |entry_value| { - connection.prepared_statement_id += 1; - stmt.* = .{ .signature = signature, .ref_count = 2, .status = if (can_execute) .parsing else .pending }; - this.statement = stmt; - - entry_value.* = stmt; - } else { - stmt.* = .{ .signature = signature, .ref_count = 1, .status = if (can_execute) .parsing else .pending }; - this.statement = stmt; - } - } - } - // We need a strong 
reference to the query so that it doesn't get GC'd - connection.requests.writeItem(this) catch return globalObject.throwOutOfMemory(); - this.ref(); - this.thisValue.upgrade(globalObject); - - js.targetSetCached(this_value, globalObject, query); - if (did_write) { - connection.flushDataAndResetTimeout(); - } else { - connection.resetConnectionTimeout(); - } - return .js_undefined; - } - - pub fn doCancel(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { - _ = callframe; - _ = globalObject; - _ = this; - - return .js_undefined; - } - - comptime { - const jscall = JSC.toJSHostFn(call); - @export(&jscall, .{ .name = "PostgresSQLQuery__createInstance" }); - } -}; - -pub const PostgresRequest = struct { - pub fn writeBind( - name: []const u8, - cursor_name: bun.String, - globalObject: *JSC.JSGlobalObject, - values_array: JSValue, - columns_value: JSValue, - parameter_fields: []const int4, - result_fields: []const protocol.FieldDescription, - comptime Context: type, - writer: protocol.NewWriter(Context), - ) !void { - try writer.write("B"); - const length = try writer.length(); - - try writer.String(cursor_name); - try writer.string(name); - - const len: u32 = @truncate(parameter_fields.len); - - // The number of parameter format codes that follow (denoted C - // below). This can be zero to indicate that there are no - // parameters or that the parameters all use the default format - // (text); or one, in which case the specified format code is - // applied to all parameters; or it can equal the actual number - // of parameters. 
- try writer.short(len); - - var iter = try QueryBindingIterator.init(values_array, columns_value, globalObject); - for (0..len) |i| { - const parameter_field = parameter_fields[i]; - const is_custom_type = std.math.maxInt(short) < parameter_field; - const tag: types.Tag = if (is_custom_type) .text else @enumFromInt(@as(short, @intCast(parameter_field))); - - const force_text = is_custom_type or (tag.isBinaryFormatSupported() and brk: { - iter.to(@truncate(i)); - if (try iter.next()) |value| { - break :brk value.isString(); - } - if (iter.anyFailed()) { - return error.InvalidQueryBinding; - } - break :brk false; - }); - - if (force_text) { - // If they pass a value as a string, let's avoid attempting to - // convert it to the binary representation. This minimizes the room - // for mistakes on our end, such as stripping the timezone - // differently than what Postgres does when given a timestamp with - // timezone. - try writer.short(0); - continue; - } - - try writer.short( - tag.formatCode(), - ); - } - - // The number of parameter values that follow (possibly zero). This - // must match the number of parameters needed by the query. 
- try writer.short(len); - - debug("Bind: {} ({d} args)", .{ bun.fmt.quote(name), len }); - iter.to(0); - var i: usize = 0; - while (try iter.next()) |value| : (i += 1) { - const tag: types.Tag = brk: { - if (i >= len) { - // parameter in array but not in parameter_fields - // this is probably a bug a bug in bun lets return .text here so the server will send a error 08P01 - // with will describe better the error saying exactly how many parameters are missing and are expected - // Example: - // SQL error: PostgresError: bind message supplies 0 parameters, but prepared statement "PSELECT * FROM test_table WHERE id=$1 .in$0" requires 1 - // errno: "08P01", - // code: "ERR_POSTGRES_SERVER_ERROR" - break :brk .text; - } - const parameter_field = parameter_fields[i]; - const is_custom_type = std.math.maxInt(short) < parameter_field; - break :brk if (is_custom_type) .text else @enumFromInt(@as(short, @intCast(parameter_field))); - }; - if (value.isEmptyOrUndefinedOrNull()) { - debug(" -> NULL", .{}); - // As a special case, -1 indicates a - // NULL parameter value. No value bytes follow in the NULL case. - try writer.int4(@bitCast(@as(i32, -1))); - continue; - } - if (comptime bun.Environment.enable_logs) { - debug(" -> {s}", .{tag.tagName() orelse "(unknown)"}); - } - - switch ( - // If they pass a value as a string, let's avoid attempting to - // convert it to the binary representation. This minimizes the room - // for mistakes on our end, such as stripping the timezone - // differently than what Postgres does when given a timestamp with - // timezone. 
- if (tag.isBinaryFormatSupported() and value.isString()) .text else tag) { - .jsonb, .json => { - var str = bun.String.empty; - defer str.deref(); - try value.jsonStringify(globalObject, 0, &str); - const slice = str.toUTF8WithoutRef(bun.default_allocator); - defer slice.deinit(); - const l = try writer.length(); - try writer.write(slice.slice()); - try l.writeExcludingSelf(); - }, - .bool => { - const l = try writer.length(); - try writer.write(&[1]u8{@intFromBool(value.toBoolean())}); - try l.writeExcludingSelf(); - }, - .timestamp, .timestamptz => { - const l = try writer.length(); - try writer.int8(types.date.fromJS(globalObject, value)); - try l.writeExcludingSelf(); - }, - .bytea => { - var bytes: []const u8 = ""; - if (value.asArrayBuffer(globalObject)) |buf| { - bytes = buf.byteSlice(); - } - const l = try writer.length(); - debug(" {d} bytes", .{bytes.len}); - - try writer.write(bytes); - try l.writeExcludingSelf(); - }, - .int4 => { - const l = try writer.length(); - try writer.int4(@bitCast(try value.coerceToInt32(globalObject))); - try l.writeExcludingSelf(); - }, - .int4_array => { - const l = try writer.length(); - try writer.int4(@bitCast(try value.coerceToInt32(globalObject))); - try l.writeExcludingSelf(); - }, - .float8 => { - const l = try writer.length(); - try writer.f64(@bitCast(try value.toNumber(globalObject))); - try l.writeExcludingSelf(); - }, - - else => { - const str = try String.fromJS(value, globalObject); - if (str.tag == .Dead) return error.OutOfMemory; - defer str.deref(); - const slice = str.toUTF8WithoutRef(bun.default_allocator); - defer slice.deinit(); - const l = try writer.length(); - try writer.write(slice.slice()); - try l.writeExcludingSelf(); - }, - } - } - - var any_non_text_fields: bool = false; - for (result_fields) |field| { - if (field.typeTag().isBinaryFormatSupported()) { - any_non_text_fields = true; - break; - } - } - - if (any_non_text_fields) { - try writer.short(result_fields.len); - for (result_fields) 
|field| { - try writer.short( - field.typeTag().formatCode(), - ); - } - } else { - try writer.short(0); - } - - try length.write(); - } - - pub fn writeQuery( - query: []const u8, - name: []const u8, - params: []const int4, - comptime Context: type, - writer: protocol.NewWriter(Context), - ) AnyPostgresError!void { - { - var q = protocol.Parse{ - .name = name, - .params = params, - .query = query, - }; - try q.writeInternal(Context, writer); - debug("Parse: {}", .{bun.fmt.quote(query)}); - } - - { - var d = protocol.Describe{ - .p = .{ - .prepared_statement = name, - }, - }; - try d.writeInternal(Context, writer); - debug("Describe: {}", .{bun.fmt.quote(name)}); - } - } - - pub fn prepareAndQueryWithSignature( - globalObject: *JSC.JSGlobalObject, - query: []const u8, - array_value: JSValue, - comptime Context: type, - writer: protocol.NewWriter(Context), - signature: *Signature, - ) AnyPostgresError!void { - try writeQuery(query, signature.prepared_statement_name, signature.fields, Context, writer); - try writeBind(signature.prepared_statement_name, bun.String.empty, globalObject, array_value, .zero, &.{}, &.{}, Context, writer); - var exec = protocol.Execute{ - .p = .{ - .prepared_statement = signature.prepared_statement_name, - }, - }; - try exec.writeInternal(Context, writer); - - try writer.write(&protocol.Flush); - try writer.write(&protocol.Sync); - } - - pub fn bindAndExecute( - globalObject: *JSC.JSGlobalObject, - statement: *PostgresSQLStatement, - array_value: JSValue, - columns_value: JSValue, - comptime Context: type, - writer: protocol.NewWriter(Context), - ) !void { - try writeBind(statement.signature.prepared_statement_name, bun.String.empty, globalObject, array_value, columns_value, statement.parameters, statement.fields, Context, writer); - var exec = protocol.Execute{ - .p = .{ - .prepared_statement = statement.signature.prepared_statement_name, - }, - }; - try exec.writeInternal(Context, writer); - - try writer.write(&protocol.Flush); - try 
writer.write(&protocol.Sync); - } - - pub fn executeQuery( - query: []const u8, - comptime Context: type, - writer: protocol.NewWriter(Context), - ) !void { - try protocol.writeQuery(query, Context, writer); - try writer.write(&protocol.Flush); - try writer.write(&protocol.Sync); - } - - pub fn onData( - connection: *PostgresSQLConnection, - comptime Context: type, - reader: protocol.NewReader(Context), - ) !void { - while (true) { - reader.markMessageStart(); - const c = try reader.int(u8); - debug("read: {c}", .{c}); - switch (c) { - 'D' => try connection.on(.DataRow, Context, reader), - 'd' => try connection.on(.CopyData, Context, reader), - 'S' => { - if (connection.tls_status == .message_sent) { - bun.debugAssert(connection.tls_status.message_sent == 8); - connection.tls_status = .ssl_ok; - connection.setupTLS(); - return; - } - - try connection.on(.ParameterStatus, Context, reader); - }, - 'Z' => try connection.on(.ReadyForQuery, Context, reader), - 'C' => try connection.on(.CommandComplete, Context, reader), - '2' => try connection.on(.BindComplete, Context, reader), - '1' => try connection.on(.ParseComplete, Context, reader), - 't' => try connection.on(.ParameterDescription, Context, reader), - 'T' => try connection.on(.RowDescription, Context, reader), - 'R' => try connection.on(.Authentication, Context, reader), - 'n' => try connection.on(.NoData, Context, reader), - 'K' => try connection.on(.BackendKeyData, Context, reader), - 'E' => try connection.on(.ErrorResponse, Context, reader), - 's' => try connection.on(.PortalSuspended, Context, reader), - '3' => try connection.on(.CloseComplete, Context, reader), - 'G' => try connection.on(.CopyInResponse, Context, reader), - 'N' => { - if (connection.tls_status == .message_sent) { - connection.tls_status = .ssl_not_available; - debug("Server does not support SSL", .{}); - if (connection.ssl_mode == .require) { - connection.fail("Server does not support SSL", error.TLSNotAvailable); - return; - } - continue; - 
} - - try connection.on(.NoticeResponse, Context, reader); - }, - 'I' => try connection.on(.EmptyQueryResponse, Context, reader), - 'H' => try connection.on(.CopyOutResponse, Context, reader), - 'c' => try connection.on(.CopyDone, Context, reader), - 'W' => try connection.on(.CopyBothResponse, Context, reader), - - else => { - debug("Unknown message: {c}", .{c}); - const to_skip = try reader.length() -| 1; - debug("to_skip: {d}", .{to_skip}); - try reader.skip(@intCast(@max(to_skip, 0))); - }, - } - } - } - - pub const Queue = std.fifo.LinearFifo(*PostgresSQLQuery, .Dynamic); -}; - -pub const PostgresSQLConnection = struct { - socket: Socket, - status: Status = Status.connecting, - ref_count: u32 = 1, - - write_buffer: bun.OffsetByteList = .{}, - read_buffer: bun.OffsetByteList = .{}, - last_message_start: u32 = 0, - requests: PostgresRequest.Queue, - - poll_ref: bun.Async.KeepAlive = .{}, - globalObject: *JSC.JSGlobalObject, - - statements: PreparedStatementsMap, - prepared_statement_id: u64 = 0, - pending_activity_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(0), - js_value: JSValue = .js_undefined, - - backend_parameters: bun.StringMap = bun.StringMap.init(bun.default_allocator, true), - backend_key_data: protocol.BackendKeyData = .{}, - - database: []const u8 = "", - user: []const u8 = "", - password: []const u8 = "", - path: []const u8 = "", - options: []const u8 = "", - options_buf: []const u8 = "", - - authentication_state: AuthenticationState = .{ .pending = {} }, - - tls_ctx: ?*uws.SocketContext = null, - tls_config: JSC.API.ServerConfig.SSLConfig = .{}, - tls_status: TLSStatus = .none, - ssl_mode: SSLMode = .disable, - - idle_timeout_interval_ms: u32 = 0, - connection_timeout_ms: u32 = 0, - - flags: ConnectionFlags = .{}, - - /// Before being connected, this is a connection timeout timer. - /// After being connected, this is an idle timeout timer. 
- timer: bun.api.Timer.EventLoopTimer = .{ - .tag = .PostgresSQLConnectionTimeout, - .next = .{ - .sec = 0, - .nsec = 0, - }, - }, - - /// This timer controls the maximum lifetime of a connection. - /// It starts when the connection successfully starts (i.e. after handshake is complete). - /// It stops when the connection is closed. - max_lifetime_interval_ms: u32 = 0, - max_lifetime_timer: bun.api.Timer.EventLoopTimer = .{ - .tag = .PostgresSQLConnectionMaxLifetime, - .next = .{ - .sec = 0, - .nsec = 0, - }, - }, - - pub const ConnectionFlags = packed struct { - is_ready_for_query: bool = false, - is_processing_data: bool = false, - use_unnamed_prepared_statements: bool = false, - }; - - pub const TLSStatus = union(enum) { - none, - pending, - - /// Number of bytes sent of the 8-byte SSL request message. - /// Since we may send a partial message, we need to know how many bytes were sent. - message_sent: u8, - - ssl_not_available, - ssl_ok, - }; - - pub const AuthenticationState = union(enum) { - pending: void, - none: void, - ok: void, - SASL: SASL, - md5: void, - - pub fn zero(this: *AuthenticationState) void { - switch (this.*) { - .SASL => |*sasl| { - sasl.deinit(); - }, - else => {}, - } - this.* = .{ .none = {} }; - } - }; - - pub const SASL = struct { - const nonce_byte_len = 18; - const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len); - - const server_signature_byte_len = 32; - const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len); - - const salted_password_byte_len = 32; - - nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len, - nonce_len: u8 = 0, - - server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len, - server_signature_len: u8 = 0, - - salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len, - salted_password_created: bool = false, - - status: SASLStatus = .init, - - pub const SASLStatus = enum { - init, - 
@"continue", - }; - - fn hmac(password: []const u8, data: []const u8) ?[32]u8 { - var buf = std.mem.zeroes([bun.BoringSSL.c.EVP_MAX_MD_SIZE]u8); - - // TODO: I don't think this is failable. - const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null; - - assert(result.len == 32); - return buf[0..32].*; - } - - pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void { - this.salted_password_created = true; - if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) { - return error.PBKDFD2; - } - } - - pub fn saltedPassword(this: *const SASL) []const u8 { - assert(this.salted_password_created); - return this.salted_password_bytes[0..salted_password_byte_len]; - } - - pub fn serverSignature(this: *const SASL) []const u8 { - assert(this.server_signature_len > 0); - return this.server_signature_base64_bytes[0..this.server_signature_len]; - } - - pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void { - assert(this.server_signature_len == 0); - - const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey; - const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature; - this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes)); - } - - pub fn clientKey(this: *const SASL) [32]u8 { - return hmac(this.saltedPassword(), "Client Key").?; - } - - pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 { - var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest); - bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine()); - return hmac(&sha_digest, auth_string).?; - } - - pub fn nonce(this: *SASL) []const u8 { - if (this.nonce_len == 0) { - var bytes: [nonce_byte_len]u8 = .{0} ** 
nonce_byte_len; - bun.csprng(&bytes); - this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes)); - } - return this.nonce_base64_bytes[0..this.nonce_len]; - } - - pub fn deinit(this: *SASL) void { - this.nonce_len = 0; - this.salted_password_created = false; - this.server_signature_len = 0; - this.status = .init; - } - }; - - pub const Status = enum { - disconnected, - connecting, - // Prevent sending the startup message multiple times. - // Particularly relevant for TLS connections. - sent_startup_message, - connected, - failed, - }; - - pub const js = JSC.Codegen.JSPostgresSQLConnection; - pub const toJS = js.toJS; - pub const fromJS = js.fromJS; - pub const fromJSDirect = js.fromJSDirect; - - fn getTimeoutInterval(this: *const PostgresSQLConnection) u32 { - return switch (this.status) { - .connected => this.idle_timeout_interval_ms, - .failed => 0, - else => this.connection_timeout_ms, - }; - } - pub fn disableConnectionTimeout(this: *PostgresSQLConnection) void { - if (this.timer.state == .ACTIVE) { - this.globalObject.bunVM().timer.remove(&this.timer); - } - this.timer.state = .CANCELLED; - } - pub fn resetConnectionTimeout(this: *PostgresSQLConnection) void { - // if we are processing data, don't reset the timeout, wait for the data to be processed - if (this.flags.is_processing_data) return; - const interval = this.getTimeoutInterval(); - if (this.timer.state == .ACTIVE) { - this.globalObject.bunVM().timer.remove(&this.timer); - } - if (interval == 0) { - return; - } - - this.timer.next = bun.timespec.msFromNow(@intCast(interval)); - this.globalObject.bunVM().timer.insert(&this.timer); - } - - pub fn getQueries(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) bun.JSError!JSC.JSValue { - if (js.queriesGetCached(thisValue)) |value| { - return value; - } - - const array = try JSC.JSValue.createEmptyArray(globalObject, 0); - js.queriesSetCached(thisValue, globalObject, array); - - return array; - } - - 
pub fn getOnConnect(_: *PostgresSQLConnection, thisValue: JSC.JSValue, _: *JSC.JSGlobalObject) JSC.JSValue { - if (js.onconnectGetCached(thisValue)) |value| { - return value; - } - - return .js_undefined; - } - - pub fn setOnConnect(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) void { - js.onconnectSetCached(thisValue, globalObject, value); - } - - pub fn getOnClose(_: *PostgresSQLConnection, thisValue: JSC.JSValue, _: *JSC.JSGlobalObject) JSC.JSValue { - if (js.oncloseGetCached(thisValue)) |value| { - return value; - } - - return .js_undefined; - } - - pub fn setOnClose(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) void { - js.oncloseSetCached(thisValue, globalObject, value); - } - - pub fn setupTLS(this: *PostgresSQLConnection) void { - debug("setupTLS", .{}); - const new_socket = this.socket.SocketTCP.socket.connected.upgrade(this.tls_ctx.?, this.tls_config.server_name) orelse { - this.fail("Failed to upgrade to TLS", error.TLSUpgradeFailed); - return; - }; - this.socket = .{ - .SocketTLS = .{ - .socket = .{ - .connected = new_socket, - }, - }, - }; - - this.start(); - } - fn setupMaxLifetimeTimerIfNecessary(this: *PostgresSQLConnection) void { - if (this.max_lifetime_interval_ms == 0) return; - if (this.max_lifetime_timer.state == .ACTIVE) return; - - this.max_lifetime_timer.next = bun.timespec.msFromNow(@intCast(this.max_lifetime_interval_ms)); - this.globalObject.bunVM().timer.insert(&this.max_lifetime_timer); - } - - pub fn onConnectionTimeout(this: *PostgresSQLConnection) bun.api.Timer.EventLoopTimer.Arm { - debug("onConnectionTimeout", .{}); - - this.timer.state = .FIRED; - if (this.flags.is_processing_data) { - return .disarm; - } - - if (this.getTimeoutInterval() == 0) { - this.resetConnectionTimeout(); - return .disarm; - } - - switch (this.status) { - .connected => { - this.failFmt(.POSTGRES_IDLE_TIMEOUT, "Idle timeout reached after 
{}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.idle_timeout_interval_ms) *| std.time.ns_per_ms)}); - }, - else => { - this.failFmt(.POSTGRES_CONNECTION_TIMEOUT, "Connection timeout after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)}); - }, - .sent_startup_message => { - this.failFmt(.POSTGRES_CONNECTION_TIMEOUT, "Connection timed out after {} (sent startup message, but never received response)", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)}); - }, - } - return .disarm; - } - - pub fn onMaxLifetimeTimeout(this: *PostgresSQLConnection) bun.api.Timer.EventLoopTimer.Arm { - debug("onMaxLifetimeTimeout", .{}); - this.max_lifetime_timer.state = .FIRED; - if (this.status == .failed) return .disarm; - this.failFmt(.POSTGRES_LIFETIME_TIMEOUT, "Max lifetime timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.max_lifetime_interval_ms) *| std.time.ns_per_ms)}); - return .disarm; - } - - fn start(this: *PostgresSQLConnection) void { - this.setupMaxLifetimeTimerIfNecessary(); - this.resetConnectionTimeout(); - this.sendStartupMessage(); - - const event_loop = this.globalObject.bunVM().eventLoop(); - event_loop.enter(); - defer event_loop.exit(); - this.flushData(); - } - - pub fn hasPendingActivity(this: *PostgresSQLConnection) bool { - return this.pending_activity_count.load(.acquire) > 0; - } - - fn updateHasPendingActivity(this: *PostgresSQLConnection) void { - const a: u32 = if (this.requests.readableLength() > 0) 1 else 0; - const b: u32 = if (this.status != .disconnected) 1 else 0; - this.pending_activity_count.store(a + b, .release); - } - - pub fn setStatus(this: *PostgresSQLConnection, status: Status) void { - if (this.status == status) return; - defer this.updateHasPendingActivity(); - - this.status = status; - this.resetConnectionTimeout(); - - switch (status) { - .connected => { - const on_connect = 
this.consumeOnConnectCallback(this.globalObject) orelse return; - const js_value = this.js_value; - js_value.ensureStillAlive(); - this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ JSValue.jsNull(), js_value }); - this.poll_ref.unref(this.globalObject.bunVM()); - }, - else => {}, - } - } - - pub fn finalize(this: *PostgresSQLConnection) void { - debug("PostgresSQLConnection finalize", .{}); - this.stopTimers(); - this.js_value = .zero; - this.deref(); - } - - pub fn flushDataAndResetTimeout(this: *PostgresSQLConnection) void { - this.resetConnectionTimeout(); - this.flushData(); - } - - pub fn flushData(this: *PostgresSQLConnection) void { - const chunk = this.write_buffer.remaining(); - if (chunk.len == 0) return; - const wrote = this.socket.write(chunk); - if (wrote > 0) { - SocketMonitor.write(chunk[0..@intCast(wrote)]); - this.write_buffer.consume(@intCast(wrote)); - } - } - - pub fn failWithJSValue(this: *PostgresSQLConnection, value: JSValue) void { - defer this.updateHasPendingActivity(); - this.stopTimers(); - if (this.status == .failed) return; - - this.status = .failed; - - this.ref(); - defer this.deref(); - // we defer the refAndClose so the on_close will be called first before we reject the pending requests - defer this.refAndClose(value); - const on_close = this.consumeOnCloseCallback(this.globalObject) orelse return; - - const loop = this.globalObject.bunVM().eventLoop(); - loop.enter(); - defer loop.exit(); - _ = on_close.call( - this.globalObject, - this.js_value, - &[_]JSValue{ - value, - this.getQueriesArray(), - }, - ) catch |e| this.globalObject.reportActiveExceptionAsUnhandled(e); - } - - pub fn failFmt(this: *PostgresSQLConnection, comptime error_code: JSC.Error, comptime fmt: [:0]const u8, args: anytype) void { - this.failWithJSValue(error_code.fmt(this.globalObject, fmt, args)); - } - - pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: AnyPostgresError) void { - debug("failed: {s}: {s}", .{ message, @errorName(err) 
}); - - const globalObject = this.globalObject; - - this.failWithJSValue(postgresErrorToJS(globalObject, message, err)); - } - - pub fn onClose(this: *PostgresSQLConnection) void { - var vm = this.globalObject.bunVM(); - const loop = vm.eventLoop(); - loop.enter(); - defer loop.exit(); - this.poll_ref.unref(this.globalObject.bunVM()); - - this.fail("Connection closed", error.ConnectionClosed); - } - - fn sendStartupMessage(this: *PostgresSQLConnection) void { - if (this.status != .connecting) return; - debug("sendStartupMessage", .{}); - this.status = .sent_startup_message; - var msg = protocol.StartupMessage{ - .user = Data{ .temporary = this.user }, - .database = Data{ .temporary = this.database }, - .options = Data{ .temporary = this.options }, - }; - msg.writeInternal(Writer, this.writer()) catch |err| { - this.fail("Failed to write startup message", err); - }; - } - - fn startTLS(this: *PostgresSQLConnection, socket: uws.AnySocket) void { - debug("startTLS", .{}); - const offset = switch (this.tls_status) { - .message_sent => |count| count, - else => 0, - }; - const ssl_request = [_]u8{ - 0x00, 0x00, 0x00, 0x08, // Length - 0x04, 0xD2, 0x16, 0x2F, // SSL request code - }; - - const written = socket.write(ssl_request[offset..]); - if (written > 0) { - this.tls_status = .{ - .message_sent = offset + @as(u8, @intCast(written)), - }; - } else { - this.tls_status = .{ - .message_sent = offset, - }; - } - } - - pub fn onOpen(this: *PostgresSQLConnection, socket: uws.AnySocket) void { - this.socket = socket; - - this.poll_ref.ref(this.globalObject.bunVM()); - this.updateHasPendingActivity(); - - if (this.tls_status == .message_sent or this.tls_status == .pending) { - this.startTLS(socket); - return; - } - - this.start(); - } - - pub fn onHandshake(this: *PostgresSQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void { - debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no }); - const handshake_success = if (success == 1) true else false; - if 
(handshake_success) { - if (this.tls_config.reject_unauthorized != 0) { - // only reject the connection if reject_unauthorized == true - switch (this.ssl_mode) { - // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 - - .verify_ca, .verify_full => { - if (ssl_error.error_no != 0) { - this.failWithJSValue(ssl_error.toJS(this.globalObject)); - return; - } - - const ssl_ptr: *BoringSSL.c.SSL = @ptrCast(this.socket.getNativeHandle()); - if (BoringSSL.c.SSL_get_servername(ssl_ptr, 0)) |servername| { - const hostname = servername[0..bun.len(servername)]; - if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) { - this.failWithJSValue(ssl_error.toJS(this.globalObject)); - } - } - }, - else => { - return; - }, - } - } - } else { - // if we are here is because server rejected us, and the error_no is the cause of this - // no matter if reject_unauthorized is false because we are disconnected by the server - this.failWithJSValue(ssl_error.toJS(this.globalObject)); - } - } - - pub fn onTimeout(this: *PostgresSQLConnection) void { - _ = this; - debug("onTimeout", .{}); - } - - pub fn onDrain(this: *PostgresSQLConnection) void { - - // Don't send any other messages while we're waiting for TLS. - if (this.tls_status == .message_sent) { - if (this.tls_status.message_sent < 8) { - this.startTLS(this.socket); - } - - return; - } - - const event_loop = this.globalObject.bunVM().eventLoop(); - event_loop.enter(); - defer event_loop.exit(); - this.flushData(); - } - - pub fn onData(this: *PostgresSQLConnection, data: []const u8) void { - this.ref(); - this.flags.is_processing_data = true; - const vm = this.globalObject.bunVM(); - - this.disableConnectionTimeout(); - defer { - if (this.status == .connected and !this.hasQueryRunning() and this.write_buffer.remaining().len == 0) { - // Don't keep the process alive when there's nothing to do. 
- this.poll_ref.unref(vm); - } else if (this.status == .connected) { - // Keep the process alive if there's something to do. - this.poll_ref.ref(vm); - } - this.flags.is_processing_data = false; - - // reset the connection timeout after we're done processing the data - this.resetConnectionTimeout(); - this.deref(); - } - - const event_loop = vm.eventLoop(); - event_loop.enter(); - defer event_loop.exit(); - SocketMonitor.read(data); - // reset the head to the last message so remaining reflects the right amount of bytes - this.read_buffer.head = this.last_message_start; - - if (this.read_buffer.remaining().len == 0) { - var consumed: usize = 0; - var offset: usize = 0; - const reader = protocol.StackReader.init(data, &consumed, &offset); - PostgresRequest.onData(this, protocol.StackReader, reader) catch |err| { - if (err == error.ShortRead) { - if (comptime bun.Environment.allow_assert) { - debug("read_buffer: empty and received short read: last_message_start: {d}, head: {d}, len: {d}", .{ - offset, - consumed, - data.len, - }); - } - - this.read_buffer.head = 0; - this.last_message_start = 0; - this.read_buffer.byte_list.len = 0; - this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer"); - } else { - bun.handleErrorReturnTrace(err, @errorReturnTrace()); - - this.fail("Failed to read data", err); - } - }; - // no need to reset anything, its already empty - return; - } - // read buffer is not empty, so we need to write the data to the buffer and then read it - this.read_buffer.write(bun.default_allocator, data) catch @panic("failed to write to read buffer"); - PostgresRequest.onData(this, Reader, this.bufferedReader()) catch |err| { - if (err != error.ShortRead) { - bun.handleErrorReturnTrace(err, @errorReturnTrace()); - this.fail("Failed to read data", err); - return; - } - - if (comptime bun.Environment.allow_assert) { - debug("read_buffer: not empty and received short read: last_message_start: {d}, head: {d}, 
len: {d}", .{ - this.last_message_start, - this.read_buffer.head, - this.read_buffer.byte_list.len, - }); - } - return; - }; - - debug("clean read_buffer", .{}); - // success, we read everything! let's reset the last message start and the head - this.last_message_start = 0; - this.read_buffer.head = 0; - } - - pub fn constructor(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*PostgresSQLConnection { - _ = callframe; - return globalObject.throw("PostgresSQLConnection cannot be constructed directly", .{}); - } - - comptime { - const jscall = JSC.toJSHostFn(call); - @export(&jscall, .{ .name = "PostgresSQLConnection__createInstance" }); - } - - pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { - var vm = globalObject.bunVM(); - const arguments = callframe.arguments_old(15).slice(); - const hostname_str = try arguments[0].toBunString(globalObject); - defer hostname_str.deref(); - const port = try arguments[1].coerce(i32, globalObject); - - const username_str = try arguments[2].toBunString(globalObject); - defer username_str.deref(); - const password_str = try arguments[3].toBunString(globalObject); - defer password_str.deref(); - const database_str = try arguments[4].toBunString(globalObject); - defer database_str.deref(); - const ssl_mode: SSLMode = switch (arguments[5].toInt32()) { - 0 => .disable, - 1 => .prefer, - 2 => .require, - 3 => .verify_ca, - 4 => .verify_full, - else => .disable, - }; - - const tls_object = arguments[6]; - - var tls_config: JSC.API.ServerConfig.SSLConfig = .{}; - var tls_ctx: ?*uws.SocketContext = null; - if (ssl_mode != .disable) { - tls_config = if (tls_object.isBoolean() and tls_object.toBoolean()) - .{} - else if (tls_object.isObject()) - (JSC.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{} - else { - return globalObject.throwInvalidArguments("tls must be a boolean or an object", .{}); - }; - - if 
(globalObject.hasException()) { - tls_config.deinit(); - return .zero; - } - - // we always request the cert so we can verify it and also we manually abort the connection if the hostname doesn't match - const original_reject_unauthorized = tls_config.reject_unauthorized; - tls_config.reject_unauthorized = 0; - tls_config.request_cert = 1; - // We create it right here so we can throw errors early. - const context_options = tls_config.asUSockets(); - var err: uws.create_bun_socket_error_t = .none; - tls_ctx = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), context_options, &err) orelse { - if (err != .none) { - return globalObject.throw("failed to create TLS context", .{}); - } else { - return globalObject.throwValue(err.toJS(globalObject)); - } - }; - // restore the original reject_unauthorized - tls_config.reject_unauthorized = original_reject_unauthorized; - if (err != .none) { - tls_config.deinit(); - if (tls_ctx) |ctx| { - ctx.deinit(true); - } - return globalObject.throwValue(err.toJS(globalObject)); - } - - uws.NewSocketHandler(true).configure(tls_ctx.?, true, *PostgresSQLConnection, SocketHandler(true)); - } - - var username: []const u8 = ""; - var password: []const u8 = ""; - var database: []const u8 = ""; - var options: []const u8 = ""; - var path: []const u8 = ""; - - const options_str = try arguments[7].toBunString(globalObject); - defer options_str.deref(); - - const path_str = try arguments[8].toBunString(globalObject); - defer path_str.deref(); - - const options_buf: []u8 = brk: { - var b = bun.StringBuilder{}; - b.cap += username_str.utf8ByteLength() + 1 + password_str.utf8ByteLength() + 1 + database_str.utf8ByteLength() + 1 + options_str.utf8ByteLength() + 1 + path_str.utf8ByteLength() + 1; - - b.allocate(bun.default_allocator) catch {}; - var u = username_str.toUTF8WithoutRef(bun.default_allocator); - defer u.deinit(); - username = b.append(u.slice()); - - var p = 
password_str.toUTF8WithoutRef(bun.default_allocator); - defer p.deinit(); - password = b.append(p.slice()); - - var d = database_str.toUTF8WithoutRef(bun.default_allocator); - defer d.deinit(); - database = b.append(d.slice()); - - var o = options_str.toUTF8WithoutRef(bun.default_allocator); - defer o.deinit(); - options = b.append(o.slice()); - - var _path = path_str.toUTF8WithoutRef(bun.default_allocator); - defer _path.deinit(); - path = b.append(_path.slice()); - - break :brk b.allocatedSlice(); - }; - - const on_connect = arguments[9]; - const on_close = arguments[10]; - const idle_timeout = arguments[11].toInt32(); - const connection_timeout = arguments[12].toInt32(); - const max_lifetime = arguments[13].toInt32(); - const use_unnamed_prepared_statements = arguments[14].asBoolean(); - - const ptr: *PostgresSQLConnection = try bun.default_allocator.create(PostgresSQLConnection); - - ptr.* = PostgresSQLConnection{ - .globalObject = globalObject, - - .database = database, - .user = username, - .password = password, - .path = path, - .options = options, - .options_buf = options_buf, - .socket = .{ .SocketTCP = .{ .socket = .{ .detached = {} } } }, - .requests = PostgresRequest.Queue.init(bun.default_allocator), - .statements = PreparedStatementsMap{}, - .tls_config = tls_config, - .tls_ctx = tls_ctx, - .ssl_mode = ssl_mode, - .tls_status = if (ssl_mode != .disable) .pending else .none, - .idle_timeout_interval_ms = @intCast(idle_timeout), - .connection_timeout_ms = @intCast(connection_timeout), - .max_lifetime_interval_ms = @intCast(max_lifetime), - .flags = .{ - .use_unnamed_prepared_statements = use_unnamed_prepared_statements, - }, - }; - - ptr.updateHasPendingActivity(); - ptr.poll_ref.ref(vm); - const js_value = ptr.toJS(globalObject); - js_value.ensureStillAlive(); - ptr.js_value = js_value; - - js.onconnectSetCached(js_value, globalObject, on_connect); - js.oncloseSetCached(js_value, globalObject, on_close); - bun.analytics.Features.postgres_connections += 
1; - - { - const hostname = hostname_str.toUTF8(bun.default_allocator); - defer hostname.deinit(); - - const ctx = vm.rareData().postgresql_context.tcp orelse brk: { - const ctx_ = uws.SocketContext.createNoSSLContext(vm.uwsLoop(), @sizeOf(*PostgresSQLConnection)).?; - uws.NewSocketHandler(false).configure(ctx_, true, *PostgresSQLConnection, SocketHandler(false)); - vm.rareData().postgresql_context.tcp = ctx_; - break :brk ctx_; - }; - - if (path.len > 0) { - ptr.socket = .{ - .SocketTCP = uws.SocketTCP.connectUnixAnon(path, ctx, ptr, false) catch |err| { - tls_config.deinit(); - if (tls_ctx) |tls| { - tls.deinit(true); - } - ptr.deinit(); - return globalObject.throwError(err, "failed to connect to postgresql"); - }, - }; - } else { - ptr.socket = .{ - .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| { - tls_config.deinit(); - if (tls_ctx) |tls| { - tls.deinit(true); - } - ptr.deinit(); - return globalObject.throwError(err, "failed to connect to postgresql"); - }, - }; - } - ptr.resetConnectionTimeout(); - } - - return js_value; - } - - fn SocketHandler(comptime ssl: bool) type { - return struct { - const SocketType = uws.NewSocketHandler(ssl); - fn _socket(s: SocketType) Socket { - if (comptime ssl) { - return Socket{ .SocketTLS = s }; - } - - return Socket{ .SocketTCP = s }; - } - pub fn onOpen(this: *PostgresSQLConnection, socket: SocketType) void { - this.onOpen(_socket(socket)); - } - - fn onHandshake_(this: *PostgresSQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void { - this.onHandshake(success, ssl_error); - } - - pub const onHandshake = if (ssl) onHandshake_ else null; - - pub fn onClose(this: *PostgresSQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void { - _ = socket; - this.onClose(); - } - - pub fn onEnd(this: *PostgresSQLConnection, socket: SocketType) void { - _ = socket; - this.onClose(); - } - - pub fn onConnectError(this: *PostgresSQLConnection, socket: 
SocketType, _: i32) void { - _ = socket; - this.onClose(); - } - - pub fn onTimeout(this: *PostgresSQLConnection, socket: SocketType) void { - _ = socket; - this.onTimeout(); - } - - pub fn onData(this: *PostgresSQLConnection, socket: SocketType, data: []const u8) void { - _ = socket; - this.onData(data); - } - - pub fn onWritable(this: *PostgresSQLConnection, socket: SocketType) void { - _ = socket; - this.onDrain(); - } - }; - } - - pub fn ref(this: *@This()) void { - bun.assert(this.ref_count > 0); - this.ref_count += 1; - } - - pub fn doRef(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { - this.poll_ref.ref(this.globalObject.bunVM()); - this.updateHasPendingActivity(); - return .js_undefined; - } - - pub fn doUnref(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { - this.poll_ref.unref(this.globalObject.bunVM()); - this.updateHasPendingActivity(); - return .js_undefined; - } - pub fn doFlush(this: *PostgresSQLConnection, _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSC.JSValue { - this.flushData(); - return .js_undefined; - } - - pub fn deref(this: *@This()) void { - const ref_count = this.ref_count; - this.ref_count -= 1; - - if (ref_count == 1) { - this.disconnect(); - this.deinit(); - } - } - - pub fn doClose(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { - _ = globalObject; - this.disconnect(); - this.write_buffer.deinit(bun.default_allocator); - - return .js_undefined; - } - - pub fn stopTimers(this: *PostgresSQLConnection) void { - if (this.timer.state == .ACTIVE) { - this.globalObject.bunVM().timer.remove(&this.timer); - } - if (this.max_lifetime_timer.state == .ACTIVE) { - this.globalObject.bunVM().timer.remove(&this.max_lifetime_timer); - } - } - - pub fn deinit(this: *@This()) void { - this.stopTimers(); - var iter = this.statements.valueIterator(); - while (iter.next()) |stmt_ptr| { - var stmt = stmt_ptr.*; - stmt.deref(); - } - 
this.statements.deinit(bun.default_allocator); - this.write_buffer.deinit(bun.default_allocator); - this.read_buffer.deinit(bun.default_allocator); - this.backend_parameters.deinit(); - - bun.freeSensitive(bun.default_allocator, this.options_buf); - - this.tls_config.deinit(); - bun.default_allocator.destroy(this); - } - - fn refAndClose(this: *@This(), js_reason: ?JSC.JSValue) void { - // refAndClose is always called when we wanna to disconnect or when we are closed - - if (!this.socket.isClosed()) { - // event loop need to be alive to close the socket - this.poll_ref.ref(this.globalObject.bunVM()); - // will unref on socket close - this.socket.close(); - } - - // cleanup requests - while (this.current()) |request| { - switch (request.status) { - // pending we will fail the request and the stmt will be marked as error ConnectionClosed too - .pending => { - const stmt = request.statement orelse continue; - stmt.error_response = .{ .postgres_error = AnyPostgresError.ConnectionClosed }; - stmt.status = .failed; - if (js_reason) |reason| { - request.onJSError(reason, this.globalObject); - } else { - request.onError(.{ .postgres_error = AnyPostgresError.ConnectionClosed }, this.globalObject); - } - }, - // in the middle of running - .binding, - .running, - .partial_response, - => { - if (js_reason) |reason| { - request.onJSError(reason, this.globalObject); - } else { - request.onError(.{ .postgres_error = AnyPostgresError.ConnectionClosed }, this.globalObject); - } - }, - // just ignore success and fail cases - .success, .fail => {}, - } - request.deref(); - this.requests.discard(1); - } - } - - pub fn disconnect(this: *@This()) void { - this.stopTimers(); - - if (this.status == .connected) { - this.status = .disconnected; - this.refAndClose(null); - } - } - - fn current(this: *PostgresSQLConnection) ?*PostgresSQLQuery { - if (this.requests.readableLength() == 0) { - return null; - } - - return this.requests.peekItem(0); - } - - fn hasQueryRunning(this: 
*PostgresSQLConnection) bool { - return !this.flags.is_ready_for_query or this.current() != null; - } - - pub const Writer = struct { - connection: *PostgresSQLConnection, - - pub fn write(this: Writer, data: []const u8) AnyPostgresError!void { - var buffer = &this.connection.write_buffer; - try buffer.write(bun.default_allocator, data); - } - - pub fn pwrite(this: Writer, data: []const u8, index: usize) AnyPostgresError!void { - @memcpy(this.connection.write_buffer.byte_list.slice()[index..][0..data.len], data); - } - - pub fn offset(this: Writer) usize { - return this.connection.write_buffer.len(); - } - }; - - pub fn writer(this: *PostgresSQLConnection) protocol.NewWriter(Writer) { - return .{ - .wrapped = .{ - .connection = this, - }, - }; - } - - pub const Reader = struct { - connection: *PostgresSQLConnection, - - pub fn markMessageStart(this: Reader) void { - this.connection.last_message_start = this.connection.read_buffer.head; - } - - pub const ensureLength = ensureCapacity; - - pub fn peek(this: Reader) []const u8 { - return this.connection.read_buffer.remaining(); - } - pub fn skip(this: Reader, count: usize) void { - this.connection.read_buffer.head = @min(this.connection.read_buffer.head + @as(u32, @truncate(count)), this.connection.read_buffer.byte_list.len); - } - pub fn ensureCapacity(this: Reader, count: usize) bool { - return @as(usize, this.connection.read_buffer.head) + count <= @as(usize, this.connection.read_buffer.byte_list.len); - } - pub fn read(this: Reader, count: usize) AnyPostgresError!Data { - var remaining = this.connection.read_buffer.remaining(); - if (@as(usize, remaining.len) < count) { - return error.ShortRead; - } - - this.skip(count); - return Data{ - .temporary = remaining[0..count], - }; - } - pub fn readZ(this: Reader) AnyPostgresError!Data { - const remain = this.connection.read_buffer.remaining(); - - if (bun.strings.indexOfChar(remain, 0)) |zero| { - this.skip(zero + 1); - return Data{ - .temporary = remain[0..zero], - }; 
- } - - return error.ShortRead; - } - }; - - pub fn bufferedReader(this: *PostgresSQLConnection) protocol.NewReader(Reader) { - return .{ - .wrapped = .{ .connection = this }, - }; - } - - fn advance(this: *PostgresSQLConnection) !void { - while (this.requests.readableLength() > 0) { - var req: *PostgresSQLQuery = this.requests.peekItem(0); - switch (req.status) { - .pending => { - if (req.flags.simple) { - debug("executeQuery", .{}); - var query_str = req.query.toUTF8(bun.default_allocator); - defer query_str.deinit(); - PostgresRequest.executeQuery(query_str.slice(), PostgresSQLConnection.Writer, this.writer()) catch |err| { - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); - - continue; - }; - this.flags.is_ready_for_query = false; - req.status = .running; - return; - } else { - const stmt = req.statement orelse return error.ExpectedStatement; - - switch (stmt.status) { - .failed => { - bun.assert(stmt.error_response != null); - req.onError(stmt.error_response.?, this.globalObject); - req.deref(); - this.requests.discard(1); - - continue; - }, - .prepared => { - const thisValue = req.thisValue.get(); - bun.assert(thisValue != .zero); - const binding_value = PostgresSQLQuery.js.bindingGetCached(thisValue) orelse .zero; - const columns_value = PostgresSQLQuery.js.columnsGetCached(thisValue) orelse .zero; - req.flags.binary = stmt.fields.len > 0; - - PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, columns_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); - - continue; - }; - this.flags.is_ready_for_query = false; - req.status = .binding; - return; - }, - .pending => { - // statement is pending, lets write/parse it - var query_str = req.query.toUTF8(bun.default_allocator); - defer query_str.deinit(); - const has_params = stmt.signature.fields.len > 0; - // If 
it does not have params, we can write and execute immediately in one go - if (!has_params) { - const thisValue = req.thisValue.get(); - bun.assert(thisValue != .zero); - // prepareAndQueryWithSignature will write + bind + execute, it will change to running after binding is complete - const binding_value = PostgresSQLQuery.js.bindingGetCached(thisValue) orelse .zero; - PostgresRequest.prepareAndQueryWithSignature(this.globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, this.writer(), &stmt.signature) catch |err| { - stmt.status = .failed; - stmt.error_response = .{ .postgres_error = err }; - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); - - continue; - }; - this.flags.is_ready_for_query = false; - req.status = .binding; - stmt.status = .parsing; - - return; - } - const connection_writer = this.writer(); - // write query and wait for it to be prepared - PostgresRequest.writeQuery(query_str.slice(), stmt.signature.prepared_statement_name, stmt.signature.fields, PostgresSQLConnection.Writer, connection_writer) catch |err| { - stmt.error_response = .{ .postgres_error = err }; - stmt.status = .failed; - - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); - - continue; - }; - connection_writer.write(&protocol.Sync) catch |err| { - stmt.error_response = .{ .postgres_error = err }; - stmt.status = .failed; - - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); - - continue; - }; - this.flags.is_ready_for_query = false; - stmt.status = .parsing; - return; - }, - .parsing => { - // we are still parsing, lets wait for it to be prepared or failed - return; - }, - } - } - }, - - .running, .binding, .partial_response => { - // if we are binding it will switch to running immediately - // if we are running, we need to wait for it to be success or fail - return; - }, - .success, 
.fail => { - req.deref(); - this.requests.discard(1); - continue; - }, - } - } - } - - pub fn getQueriesArray(this: *const PostgresSQLConnection) JSValue { - return js.queriesGetCached(this.js_value) orelse .zero; - } - - pub const DataCell = @import("./DataCell.zig").DataCell; - - pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_literal), comptime Context: type, reader: protocol.NewReader(Context)) AnyPostgresError!void { - debug("on({s})", .{@tagName(MessageType)}); - - switch (comptime MessageType) { - .DataRow => { - const request = this.current() orelse return error.ExpectedRequest; - var statement = request.statement orelse return error.ExpectedStatement; - var structure: JSValue = .js_undefined; - var cached_structure: ?PostgresCachedStructure = null; - // explicit use switch without else so if new modes are added, we don't forget to check for duplicate fields - switch (request.flags.result_mode) { - .objects => { - cached_structure = statement.structure(this.js_value, this.globalObject); - structure = cached_structure.?.jsValue() orelse .js_undefined; - }, - .raw, .values => { - // no need to check for duplicate fields or structure - }, - } - - var putter = DataCell.Putter{ - .list = &.{}, - .fields = statement.fields, - .binary = request.flags.binary, - .bigint = request.flags.bigint, - .globalObject = this.globalObject, - }; - - var stack_buf: [70]DataCell = undefined; - var cells: []DataCell = stack_buf[0..@min(statement.fields.len, JSC.JSObject.maxInlineCapacity())]; - var free_cells = false; - defer { - for (cells[0..putter.count]) |*cell| { - cell.deinit(); - } - if (free_cells) bun.default_allocator.free(cells); - } - - if (statement.fields.len >= JSC.JSObject.maxInlineCapacity()) { - cells = try bun.default_allocator.alloc(DataCell, statement.fields.len); - free_cells = true; - } - // make sure all cells are reset if reader short breaks the fields will just be null with is better than undefined behavior - @memset(cells, 
DataCell{ .tag = .null, .value = .{ .null = 0 } }); - putter.list = cells; - - if (request.flags.result_mode == .raw) { - try protocol.DataRow.decode( - &putter, - Context, - reader, - DataCell.Putter.putRaw, - ); - } else { - try protocol.DataRow.decode( - &putter, - Context, - reader, - DataCell.Putter.put, - ); - } - const thisValue = request.thisValue.get(); - bun.assert(thisValue != .zero); - const pending_value = PostgresSQLQuery.js.pendingValueGetCached(thisValue) orelse .zero; - pending_value.ensureStillAlive(); - const result = putter.toJS(this.globalObject, pending_value, structure, statement.fields_flags, request.flags.result_mode, cached_structure); - - if (pending_value == .zero) { - PostgresSQLQuery.js.pendingValueSetCached(thisValue, this.globalObject, result); - } - }, - .CopyData => { - var copy_data: protocol.CopyData = undefined; - try copy_data.decodeInternal(Context, reader); - copy_data.data.deinit(); - }, - .ParameterStatus => { - var parameter_status: protocol.ParameterStatus = undefined; - try parameter_status.decodeInternal(Context, reader); - defer { - parameter_status.deinit(); - } - try this.backend_parameters.insert(parameter_status.name.slice(), parameter_status.value.slice()); - }, - .ReadyForQuery => { - var ready_for_query: protocol.ReadyForQuery = undefined; - try ready_for_query.decodeInternal(Context, reader); - - this.setStatus(.connected); - this.flags.is_ready_for_query = true; - this.socket.setTimeout(300); - defer this.updateRef(); - - if (this.current()) |request| { - if (request.status == .partial_response) { - // if is a partial response, just signal that the query is now complete - request.onResult("", this.globalObject, this.js_value, true); - } - } - try this.advance(); - - this.flushData(); - }, - .CommandComplete => { - var request = this.current() orelse return error.ExpectedRequest; - - var cmd: protocol.CommandComplete = undefined; - try cmd.decodeInternal(Context, reader); - defer { - cmd.deinit(); - } - 
debug("-> {s}", .{cmd.command_tag.slice()}); - defer this.updateRef(); - - if (request.flags.simple) { - // simple queries can have multiple commands - request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, false); - } else { - request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, true); - } - }, - .BindComplete => { - try reader.eatMessage(protocol.BindComplete); - var request = this.current() orelse return error.ExpectedRequest; - if (request.status == .binding) { - request.status = .running; - } - }, - .ParseComplete => { - try reader.eatMessage(protocol.ParseComplete); - const request = this.current() orelse return error.ExpectedRequest; - if (request.statement) |statement| { - // if we have params wait for parameter description - if (statement.status == .parsing and statement.signature.fields.len == 0) { - statement.status = .prepared; - } - } - }, - .ParameterDescription => { - var description: protocol.ParameterDescription = undefined; - try description.decodeInternal(Context, reader); - const request = this.current() orelse return error.ExpectedRequest; - var statement = request.statement orelse return error.ExpectedStatement; - statement.parameters = description.parameters; - if (statement.status == .parsing) { - statement.status = .prepared; - } - }, - .RowDescription => { - var description: protocol.RowDescription = undefined; - try description.decodeInternal(Context, reader); - errdefer description.deinit(); - const request = this.current() orelse return error.ExpectedRequest; - var statement = request.statement orelse return error.ExpectedStatement; - statement.fields = description.fields; - }, - .Authentication => { - var auth: protocol.Authentication = undefined; - try auth.decodeInternal(Context, reader); - defer auth.deinit(); - - switch (auth) { - .SASL => { - if (this.authentication_state != .SASL) { - this.authentication_state = .{ .SASL = .{} }; - } - - var mechanism_buf: [128]u8 = undefined; - const 
mechanism = std.fmt.bufPrintZ(&mechanism_buf, "n,,n=*,r={s}", .{this.authentication_state.SASL.nonce()}) catch unreachable; - var response = protocol.SASLInitialResponse{ - .mechanism = .{ - .temporary = "SCRAM-SHA-256", - }, - .data = .{ - .temporary = mechanism, - }, - }; - - try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); - debug("SASL", .{}); - this.flushData(); - }, - .SASLContinue => |*cont| { - if (this.authentication_state != .SASL) { - debug("Unexpected SASLContinue for authentiation state: {s}", .{@tagName(std.meta.activeTag(this.authentication_state))}); - return error.UnexpectedMessage; - } - var sasl = &this.authentication_state.SASL; - - if (sasl.status != .init) { - debug("Unexpected SASLContinue for SASL state: {s}", .{@tagName(sasl.status)}); - return error.UnexpectedMessage; - } - debug("SASLContinue", .{}); - - const iteration_count = try cont.iterationCount(); - - const server_salt_decoded_base64 = bun.base64.decodeAlloc(bun.z_allocator, cont.s) catch |err| { - return switch (err) { - error.DecodingFailed => error.SASL_SIGNATURE_INVALID_BASE64, - else => |e| e, - }; - }; - defer bun.z_allocator.free(server_salt_decoded_base64); - try sasl.computeSaltedPassword(server_salt_decoded_base64, iteration_count, this); - - const auth_string = try std.fmt.allocPrint( - bun.z_allocator, - "n=*,r={s},r={s},s={s},i={s},c=biws,r={s}", - .{ - sasl.nonce(), - cont.r, - cont.s, - cont.i, - cont.r, - }, - ); - defer bun.z_allocator.free(auth_string); - try sasl.computeServerSignature(auth_string); - - const client_key = sasl.clientKey(); - const client_key_signature = sasl.clientKeySignature(&client_key, auth_string); - var client_key_xor_buffer: [32]u8 = undefined; - for (&client_key_xor_buffer, client_key, client_key_signature) |*out, a, b| { - out.* = a ^ b; - } - - var client_key_xor_base64_buf = std.mem.zeroes([bun.base64.encodeLenFromSize(32)]u8); - const xor_base64_len = bun.base64.encode(&client_key_xor_base64_buf, 
&client_key_xor_buffer); - - const payload = try std.fmt.allocPrint( - bun.z_allocator, - "c=biws,r={s},p={s}", - .{ cont.r, client_key_xor_base64_buf[0..xor_base64_len] }, - ); - defer bun.z_allocator.free(payload); - - var response = protocol.SASLResponse{ - .data = .{ - .temporary = payload, - }, - }; - - try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); - sasl.status = .@"continue"; - this.flushData(); - }, - .SASLFinal => |final| { - if (this.authentication_state != .SASL) { - debug("SASLFinal - Unexpected SASLContinue for authentiation state: {s}", .{@tagName(std.meta.activeTag(this.authentication_state))}); - return error.UnexpectedMessage; - } - var sasl = &this.authentication_state.SASL; - - if (sasl.status != .@"continue") { - debug("SASLFinal - Unexpected SASLContinue for SASL state: {s}", .{@tagName(sasl.status)}); - return error.UnexpectedMessage; - } - - if (sasl.server_signature_len == 0) { - debug("SASLFinal - Server signature is empty", .{}); - return error.UnexpectedMessage; - } - - const server_signature = sasl.serverSignature(); - - // This will usually start with "v=" - const comparison_signature = final.data.slice(); - - if (comparison_signature.len < 2 or !bun.strings.eqlLong(server_signature, comparison_signature[2..], true)) { - debug("SASLFinal - SASL Server signature mismatch\nExpected: {s}\nActual: {s}", .{ server_signature, comparison_signature[2..] 
}); - this.fail("The server did not return the correct signature", error.SASL_SIGNATURE_MISMATCH); - } else { - debug("SASLFinal - SASL Server signature match", .{}); - this.authentication_state.zero(); - } - }, - .Ok => { - debug("Authentication OK", .{}); - this.authentication_state.zero(); - this.authentication_state = .{ .ok = {} }; - }, - - .Unknown => { - this.fail("Unknown authentication method", error.UNKNOWN_AUTHENTICATION_METHOD); - }, - - .ClearTextPassword => { - debug("ClearTextPassword", .{}); - var response = protocol.PasswordMessage{ - .password = .{ - .temporary = this.password, - }, - }; - - try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); - this.flushData(); - }, - - .MD5Password => |md5| { - debug("MD5Password", .{}); - // Format is: md5 + md5(md5(password + username) + salt) - var first_hash_buf: bun.sha.MD5.Digest = undefined; - var first_hash_str: [32]u8 = undefined; - var final_hash_buf: bun.sha.MD5.Digest = undefined; - var final_hash_str: [32]u8 = undefined; - var final_password_buf: [36]u8 = undefined; - - // First hash: md5(password + username) - var first_hasher = bun.sha.MD5.init(); - first_hasher.update(this.password); - first_hasher.update(this.user); - first_hasher.final(&first_hash_buf); - const first_hash_str_output = std.fmt.bufPrint(&first_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&first_hash_buf)}) catch unreachable; - - // Second hash: md5(first_hash + salt) - var final_hasher = bun.sha.MD5.init(); - final_hasher.update(first_hash_str_output); - final_hasher.update(&md5.salt); - final_hasher.final(&final_hash_buf); - const final_hash_str_output = std.fmt.bufPrint(&final_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&final_hash_buf)}) catch unreachable; - - // Format final password as "md5" + final_hash - const final_password = std.fmt.bufPrintZ(&final_password_buf, "md5{s}", .{final_hash_str_output}) catch unreachable; - - var response = protocol.PasswordMessage{ - .password = .{ - .temporary = 
final_password, - }, - }; - - this.authentication_state = .{ .md5 = {} }; - try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); - this.flushData(); - }, - - else => { - debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))}); - this.fail("TODO: support authentication method: {s}", error.UNSUPPORTED_AUTHENTICATION_METHOD); - }, - } - }, - .NoData => { - try reader.eatMessage(protocol.NoData); - var request = this.current() orelse return error.ExpectedRequest; - if (request.status == .binding) { - request.status = .running; - } - }, - .BackendKeyData => { - try this.backend_key_data.decodeInternal(Context, reader); - }, - .ErrorResponse => { - var err: protocol.ErrorResponse = undefined; - try err.decodeInternal(Context, reader); - - if (this.status == .connecting or this.status == .sent_startup_message) { - defer { - err.deinit(); - } - - this.failWithJSValue(err.toJS(this.globalObject)); - - // it shouldn't enqueue any requests while connecting - bun.assert(this.requests.count == 0); - return; - } - - var request = this.current() orelse { - debug("ErrorResponse: {}", .{err}); - return error.ExpectedRequest; - }; - var is_error_owned = true; - defer { - if (is_error_owned) { - err.deinit(); - } - } - if (request.statement) |stmt| { - if (stmt.status == PostgresSQLStatement.Status.parsing) { - stmt.status = PostgresSQLStatement.Status.failed; - stmt.error_response = .{ .protocol = err }; - is_error_owned = false; - if (this.statements.remove(bun.hash(stmt.signature.name))) { - stmt.deref(); - } - } - } - this.updateRef(); - - request.onError(.{ .protocol = err }, this.globalObject); - }, - .PortalSuspended => { - // try reader.eatMessage(&protocol.PortalSuspended); - // var request = this.current() orelse return error.ExpectedRequest; - // _ = request; - debug("TODO PortalSuspended", .{}); - }, - .CloseComplete => { - try reader.eatMessage(protocol.CloseComplete); - var request = this.current() orelse return error.ExpectedRequest; - defer 
this.updateRef(); - if (request.flags.simple) { - request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, false); - } else { - request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, true); - } - }, - .CopyInResponse => { - debug("TODO CopyInResponse", .{}); - }, - .NoticeResponse => { - debug("UNSUPPORTED NoticeResponse", .{}); - var resp: protocol.NoticeResponse = undefined; - - try resp.decodeInternal(Context, reader); - resp.deinit(); - }, - .EmptyQueryResponse => { - try reader.eatMessage(protocol.EmptyQueryResponse); - var request = this.current() orelse return error.ExpectedRequest; - defer this.updateRef(); - if (request.flags.simple) { - request.onResult("", this.globalObject, this.js_value, false); - } else { - request.onResult("", this.globalObject, this.js_value, true); - } - }, - .CopyOutResponse => { - debug("TODO CopyOutResponse", .{}); - }, - .CopyDone => { - debug("TODO CopyDone", .{}); - }, - .CopyBothResponse => { - debug("TODO CopyBothResponse", .{}); - }, - else => @compileError("Unknown message type: " ++ @tagName(MessageType)), - } - } - - pub fn updateRef(this: *PostgresSQLConnection) void { - this.updateHasPendingActivity(); - if (this.pending_activity_count.raw > 0) { - this.poll_ref.ref(this.globalObject.bunVM()); - } else { - this.poll_ref.unref(this.globalObject.bunVM()); - } - } - - pub fn getConnected(this: *PostgresSQLConnection, _: *JSC.JSGlobalObject) JSValue { - return JSValue.jsBoolean(this.status == Status.connected); - } - - pub fn consumeOnConnectCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { - debug("consumeOnConnectCallback", .{}); - const on_connect = js.onconnectGetCached(this.js_value) orelse return null; - debug("consumeOnConnectCallback exists", .{}); - - js.onconnectSetCached(this.js_value, globalObject, .zero); - return on_connect; - } - - pub fn consumeOnCloseCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) 
?JSC.JSValue { - debug("consumeOnCloseCallback", .{}); - const on_close = js.oncloseGetCached(this.js_value) orelse return null; - debug("consumeOnCloseCallback exists", .{}); - js.oncloseSetCached(this.js_value, globalObject, .zero); - return on_close; - } -}; - -pub const PostgresCachedStructure = struct { - structure: JSC.Strong.Optional = .empty, - // only populated if more than JSC.JSC__JSObject__maxInlineCapacity fields otherwise the structure will contain all fields inlined - fields: ?[]JSC.JSObject.ExternColumnIdentifier = null, - - pub fn has(this: *@This()) bool { - return this.structure.has() or this.fields != null; - } - - pub fn jsValue(this: *const @This()) ?JSC.JSValue { - return this.structure.get(); - } - - pub fn set(this: *@This(), globalObject: *JSC.JSGlobalObject, value: ?JSC.JSValue, fields: ?[]JSC.JSObject.ExternColumnIdentifier) void { - if (value) |v| { - this.structure.set(globalObject, v); - } - this.fields = fields; - } - - pub fn deinit(this: *@This()) void { - this.structure.deinit(); - if (this.fields) |fields| { - this.fields = null; - for (fields) |*name| { - name.deinit(); - } - bun.default_allocator.free(fields); - } - } -}; -pub const PostgresSQLStatement = struct { - cached_structure: PostgresCachedStructure = .{}, - ref_count: u32 = 1, - fields: []protocol.FieldDescription = &[_]protocol.FieldDescription{}, - parameters: []const int4 = &[_]int4{}, - signature: Signature, - status: Status = Status.pending, - error_response: ?Error = null, - needs_duplicate_check: bool = true, - fields_flags: PostgresSQLConnection.DataCell.Flags = .{}, - - pub const Error = union(enum) { - protocol: protocol.ErrorResponse, - postgres_error: AnyPostgresError, - - pub fn deinit(this: *@This()) void { - switch (this.*) { - .protocol => |*err| err.deinit(), - .postgres_error => {}, - } - } - - pub fn toJS(this: *const @This(), globalObject: *JSC.JSGlobalObject) JSValue { - return switch (this.*) { - .protocol => |err| err.toJS(globalObject), - 
.postgres_error => |err| postgresErrorToJS(globalObject, null, err), - }; - } - }; - pub const Status = enum { - pending, - parsing, - prepared, - failed, - - pub fn isRunning(this: @This()) bool { - return this == .parsing; - } - }; - pub fn ref(this: *@This()) void { - bun.assert(this.ref_count > 0); - this.ref_count += 1; - } - - pub fn deref(this: *@This()) void { - const ref_count = this.ref_count; - this.ref_count -= 1; - - if (ref_count == 1) { - this.deinit(); - } - } - - pub fn checkForDuplicateFields(this: *PostgresSQLStatement) void { - if (!this.needs_duplicate_check) return; - this.needs_duplicate_check = false; - - var seen_numbers = std.ArrayList(u32).init(bun.default_allocator); - defer seen_numbers.deinit(); - var seen_fields = bun.StringHashMap(void).init(bun.default_allocator); - seen_fields.ensureUnusedCapacity(@intCast(this.fields.len)) catch bun.outOfMemory(); - defer seen_fields.deinit(); - - // iterate backwards - var remaining = this.fields.len; - var flags: PostgresSQLConnection.DataCell.Flags = .{}; - while (remaining > 0) { - remaining -= 1; - const field: *protocol.FieldDescription = &this.fields[remaining]; - switch (field.name_or_index) { - .name => |*name| { - const seen = seen_fields.getOrPut(name.slice()) catch unreachable; - if (seen.found_existing) { - field.name_or_index = .duplicate; - flags.has_duplicate_columns = true; - } - - flags.has_named_columns = true; - }, - .index => |index| { - if (std.mem.indexOfScalar(u32, seen_numbers.items, index) != null) { - field.name_or_index = .duplicate; - flags.has_duplicate_columns = true; - } else { - seen_numbers.append(index) catch bun.outOfMemory(); - } - - flags.has_indexed_columns = true; - }, - .duplicate => { - flags.has_duplicate_columns = true; - }, - } - } - - this.fields_flags = flags; - } - - pub fn deinit(this: *PostgresSQLStatement) void { - debug("PostgresSQLStatement deinit", .{}); - - bun.assert(this.ref_count == 0); - - for (this.fields) |*field| { - field.deinit(); - } 
- bun.default_allocator.free(this.fields); - bun.default_allocator.free(this.parameters); - this.cached_structure.deinit(); - if (this.error_response) |err| { - this.error_response = null; - var _error = err; - _error.deinit(); - } - this.signature.deinit(); - bun.default_allocator.destroy(this); - } - - pub fn structure(this: *PostgresSQLStatement, owner: JSValue, globalObject: *JSC.JSGlobalObject) PostgresCachedStructure { - if (this.cached_structure.has()) { - return this.cached_structure; - } - this.checkForDuplicateFields(); - - // lets avoid most allocations - var stack_ids: [70]JSC.JSObject.ExternColumnIdentifier = undefined; - // lets de duplicate the fields early - var nonDuplicatedCount = this.fields.len; - for (this.fields) |*field| { - if (field.name_or_index == .duplicate) { - nonDuplicatedCount -= 1; - } - } - const ids = if (nonDuplicatedCount <= JSC.JSObject.maxInlineCapacity()) stack_ids[0..nonDuplicatedCount] else bun.default_allocator.alloc(JSC.JSObject.ExternColumnIdentifier, nonDuplicatedCount) catch bun.outOfMemory(); - - var i: usize = 0; - for (this.fields) |*field| { - if (field.name_or_index == .duplicate) continue; - - var id: *JSC.JSObject.ExternColumnIdentifier = &ids[i]; - switch (field.name_or_index) { - .name => |name| { - id.value.name = String.createAtomIfPossible(name.slice()); - }, - .index => |index| { - id.value.index = index; - }, - .duplicate => unreachable, - } - id.tag = switch (field.name_or_index) { - .name => 2, - .index => 1, - .duplicate => 0, - }; - i += 1; - } - - if (nonDuplicatedCount > JSC.JSObject.maxInlineCapacity()) { - this.cached_structure.set(globalObject, null, ids); - } else { - this.cached_structure.set(globalObject, JSC.JSObject.createStructure( - globalObject, - owner, - @truncate(ids.len), - ids.ptr, - ), null); - } - - return this.cached_structure; - } -}; - -const QueryBindingIterator = union(enum) { - array: JSC.JSArrayIterator, - objects: ObjectIterator, - - pub fn init(array: JSValue, columns: 
JSValue, globalObject: *JSC.JSGlobalObject) bun.JSError!QueryBindingIterator { - if (columns.isEmptyOrUndefinedOrNull()) { - return .{ .array = try JSC.JSArrayIterator.init(array, globalObject) }; - } - - return .{ - .objects = .{ - .array = array, - .columns = columns, - .globalObject = globalObject, - .columns_count = try columns.getLength(globalObject), - .array_length = try array.getLength(globalObject), - }, - }; - } - - pub const ObjectIterator = struct { - array: JSValue, - columns: JSValue = .zero, - globalObject: *JSC.JSGlobalObject, - cell_i: usize = 0, - row_i: usize = 0, - current_row: JSC.JSValue = .zero, - columns_count: usize = 0, - array_length: usize = 0, - any_failed: bool = false, - - pub fn next(this: *ObjectIterator) ?JSC.JSValue { - if (this.row_i >= this.array_length) { - return null; - } - - const cell_i = this.cell_i; - this.cell_i += 1; - const row_i = this.row_i; - - const globalObject = this.globalObject; - - if (this.current_row == .zero) { - this.current_row = JSC.JSObject.getIndex(this.array, globalObject, @intCast(row_i)) catch { - this.any_failed = true; - return null; - }; - if (this.current_row.isEmptyOrUndefinedOrNull()) { - return globalObject.throw("Expected a row to be returned at index {d}", .{row_i}) catch null; - } - } - - defer { - if (this.cell_i >= this.columns_count) { - this.cell_i = 0; - this.current_row = .zero; - this.row_i += 1; - } - } - - const property = JSC.JSObject.getIndex(this.columns, globalObject, @intCast(cell_i)) catch { - this.any_failed = true; - return null; - }; - if (property.isUndefined()) { - return globalObject.throw("Expected a column at index {d} in row {d}", .{ cell_i, row_i }) catch null; - } - - const value = this.current_row.getOwnByValue(globalObject, property); - if (value == .zero or (value != null and value.?.isUndefined())) { - if (!globalObject.hasException()) - return globalObject.throw("Expected a value at index {d} in row {d}", .{ cell_i, row_i }) catch null; - this.any_failed = 
true; - return null; - } - return value; - } - }; - - pub fn next(this: *QueryBindingIterator) bun.JSError!?JSC.JSValue { - return switch (this.*) { - .array => |*iter| iter.next(), - .objects => |*iter| iter.next(), - }; - } - - pub fn anyFailed(this: *const QueryBindingIterator) bool { - return switch (this.*) { - .array => false, - .objects => |*iter| iter.any_failed, - }; - } - - pub fn to(this: *QueryBindingIterator, index: u32) void { - switch (this.*) { - .array => |*iter| iter.i = index, - .objects => |*iter| { - iter.cell_i = index % iter.columns_count; - iter.row_i = index / iter.columns_count; - iter.current_row = .zero; - }, - } - } - - pub fn reset(this: *QueryBindingIterator) void { - switch (this.*) { - .array => |*iter| { - iter.i = 0; - }, - .objects => |*iter| { - iter.cell_i = 0; - iter.row_i = 0; - iter.current_row = .zero; - }, - } - } -}; - -const Signature = struct { - fields: []const int4, - name: []const u8, - query: []const u8, - prepared_statement_name: []const u8, - - pub fn empty() Signature { - return Signature{ - .fields = &[_]int4{}, - .name = &[_]u8{}, - .query = &[_]u8{}, - .prepared_statement_name = &[_]u8{}, - }; - } - - const log = bun.Output.scoped(.PostgresSignature, false); - pub fn deinit(this: *Signature) void { - if (this.prepared_statement_name.len > 0) { - bun.default_allocator.free(this.prepared_statement_name); - } - if (this.name.len > 0) { - bun.default_allocator.free(this.name); - } - if (this.fields.len > 0) { - bun.default_allocator.free(this.fields); - } - if (this.query.len > 0) { - bun.default_allocator.free(this.query); - } - } - - pub fn hash(this: *const Signature) u64 { - var hasher = std.hash.Wyhash.init(0); - hasher.update(this.name); - hasher.update(std.mem.sliceAsBytes(this.fields)); - return hasher.final(); - } - - pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue, prepared_statement_id: u64, unnamed: bool) !Signature { - var fields = 
std.ArrayList(int4).init(bun.default_allocator); - var name = try std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); - - name.appendSliceAssumeCapacity(query); - - errdefer { - fields.deinit(); - name.deinit(); - } - - var iter = try QueryBindingIterator.init(array_value, columns, globalObject); - - while (try iter.next()) |value| { - if (value.isEmptyOrUndefinedOrNull()) { - // Allow postgres to decide the type - try fields.append(0); - try name.appendSlice(".null"); - continue; - } - - const tag = try types.Tag.fromJS(globalObject, value); - - switch (tag) { - .int8 => try name.appendSlice(".int8"), - .int4 => try name.appendSlice(".int4"), - // .int4_array => try name.appendSlice(".int4_array"), - .int2 => try name.appendSlice(".int2"), - .float8 => try name.appendSlice(".float8"), - .float4 => try name.appendSlice(".float4"), - .numeric => try name.appendSlice(".numeric"), - .json, .jsonb => try name.appendSlice(".json"), - .bool => try name.appendSlice(".bool"), - .timestamp => try name.appendSlice(".timestamp"), - .timestamptz => try name.appendSlice(".timestamptz"), - .bytea => try name.appendSlice(".bytea"), - else => try name.appendSlice(".string"), - } - - switch (tag) { - .bool, .int4, .int8, .float8, .int2, .numeric, .float4, .bytea => { - // We decide the type - try fields.append(@intFromEnum(tag)); - }, - else => { - // Allow postgres to decide the type - try fields.append(0); - }, - } - } - - if (iter.anyFailed()) { - return error.InvalidQueryBinding; - } - // max u64 length is 20, max prepared_statement_name length is 63 - const prepared_statement_name = if (unnamed) "" else try std.fmt.allocPrint(bun.default_allocator, "P{s}${d}", .{ name.items[0..@min(40, name.items.len)], prepared_statement_id }); - - return Signature{ - .prepared_statement_name = prepared_statement_name, - .name = name.items, - .fields = fields.items, - .query = try bun.default_allocator.dupe(u8, query), - }; - } -}; - pub fn createBinding(globalObject: 
*JSC.JSGlobalObject) JSValue { const binding = JSValue.createEmptyObjectWithNullPrototype(globalObject); binding.put(globalObject, ZigString.static("PostgresSQLConnection"), PostgresSQLConnection.js.getConstructor(globalObject)); @@ -3287,6 +17,15 @@ pub fn createBinding(globalObject: *JSC.JSGlobalObject) JSValue { return binding; } -const ZigString = JSC.ZigString; +// @sortImports -const assert = bun.assert; +pub const PostgresSQLConnection = @import("./postgres/PostgresSQLConnection.zig"); +pub const PostgresSQLContext = @import("./postgres/PostgresSQLContext.zig"); +pub const PostgresSQLQuery = @import("./postgres/PostgresSQLQuery.zig"); +const bun = @import("bun"); +pub const protocol = @import("./postgres/PostgresProtocol.zig"); +pub const types = @import("./postgres/PostgresTypes.zig"); + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; +const ZigString = JSC.ZigString; diff --git a/src/sql/postgres/AnyPostgresError.zig b/src/sql/postgres/AnyPostgresError.zig new file mode 100644 index 0000000000..04167ec521 --- /dev/null +++ b/src/sql/postgres/AnyPostgresError.zig @@ -0,0 +1,89 @@ +pub const AnyPostgresError = error{ + ConnectionClosed, + ExpectedRequest, + ExpectedStatement, + InvalidBackendKeyData, + InvalidBinaryData, + InvalidByteSequence, + InvalidByteSequenceForEncoding, + InvalidCharacter, + InvalidMessage, + InvalidMessageLength, + InvalidQueryBinding, + InvalidServerKey, + InvalidServerSignature, + JSError, + MultidimensionalArrayNotSupportedYet, + NullsInArrayNotSupportedYet, + OutOfMemory, + Overflow, + PBKDFD2, + SASL_SIGNATURE_MISMATCH, + SASL_SIGNATURE_INVALID_BASE64, + ShortRead, + TLSNotAvailable, + TLSUpgradeFailed, + UnexpectedMessage, + UNKNOWN_AUTHENTICATION_METHOD, + UNSUPPORTED_AUTHENTICATION_METHOD, + UnsupportedByteaFormat, + UnsupportedIntegerSize, + UnsupportedArrayFormat, + UnsupportedNumericFormat, + UnknownFormatCode, +}; + +pub fn postgresErrorToJS(globalObject: *JSC.JSGlobalObject, message: ?[]const u8, err: 
AnyPostgresError) JSValue { + const error_code: JSC.Error = switch (err) { + error.ConnectionClosed => .POSTGRES_CONNECTION_CLOSED, + error.ExpectedRequest => .POSTGRES_EXPECTED_REQUEST, + error.ExpectedStatement => .POSTGRES_EXPECTED_STATEMENT, + error.InvalidBackendKeyData => .POSTGRES_INVALID_BACKEND_KEY_DATA, + error.InvalidBinaryData => .POSTGRES_INVALID_BINARY_DATA, + error.InvalidByteSequence => .POSTGRES_INVALID_BYTE_SEQUENCE, + error.InvalidByteSequenceForEncoding => .POSTGRES_INVALID_BYTE_SEQUENCE_FOR_ENCODING, + error.InvalidCharacter => .POSTGRES_INVALID_CHARACTER, + error.InvalidMessage => .POSTGRES_INVALID_MESSAGE, + error.InvalidMessageLength => .POSTGRES_INVALID_MESSAGE_LENGTH, + error.InvalidQueryBinding => .POSTGRES_INVALID_QUERY_BINDING, + error.InvalidServerKey => .POSTGRES_INVALID_SERVER_KEY, + error.InvalidServerSignature => .POSTGRES_INVALID_SERVER_SIGNATURE, + error.MultidimensionalArrayNotSupportedYet => .POSTGRES_MULTIDIMENSIONAL_ARRAY_NOT_SUPPORTED_YET, + error.NullsInArrayNotSupportedYet => .POSTGRES_NULLS_IN_ARRAY_NOT_SUPPORTED_YET, + error.Overflow => .POSTGRES_OVERFLOW, + error.PBKDFD2 => .POSTGRES_AUTHENTICATION_FAILED_PBKDF2, + error.SASL_SIGNATURE_MISMATCH => .POSTGRES_SASL_SIGNATURE_MISMATCH, + error.SASL_SIGNATURE_INVALID_BASE64 => .POSTGRES_SASL_SIGNATURE_INVALID_BASE64, + error.TLSNotAvailable => .POSTGRES_TLS_NOT_AVAILABLE, + error.TLSUpgradeFailed => .POSTGRES_TLS_UPGRADE_FAILED, + error.UnexpectedMessage => .POSTGRES_UNEXPECTED_MESSAGE, + error.UNKNOWN_AUTHENTICATION_METHOD => .POSTGRES_UNKNOWN_AUTHENTICATION_METHOD, + error.UNSUPPORTED_AUTHENTICATION_METHOD => .POSTGRES_UNSUPPORTED_AUTHENTICATION_METHOD, + error.UnsupportedByteaFormat => .POSTGRES_UNSUPPORTED_BYTEA_FORMAT, + error.UnsupportedArrayFormat => .POSTGRES_UNSUPPORTED_ARRAY_FORMAT, + error.UnsupportedIntegerSize => .POSTGRES_UNSUPPORTED_INTEGER_SIZE, + error.UnsupportedNumericFormat => .POSTGRES_UNSUPPORTED_NUMERIC_FORMAT, + error.UnknownFormatCode => 
.POSTGRES_UNKNOWN_FORMAT_CODE, + error.JSError => { + return globalObject.takeException(error.JSError); + }, + error.OutOfMemory => { + // TODO: add binding for creating an out of memory error? + return globalObject.takeException(globalObject.throwOutOfMemory()); + }, + error.ShortRead => { + bun.unreachablePanic("Assertion failed: ShortRead should be handled by the caller in postgres", .{}); + }, + }; + if (message) |msg| { + return error_code.fmt(globalObject, "{s}", .{msg}); + } + return error_code.fmt(globalObject, "Failed to bind query: {s}", .{@errorName(err)}); +} + +// @sortImports + +const bun = @import("bun"); + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/AuthenticationState.zig b/src/sql/postgres/AuthenticationState.zig new file mode 100644 index 0000000000..97d19c0893 --- /dev/null +++ b/src/sql/postgres/AuthenticationState.zig @@ -0,0 +1,21 @@ +pub const AuthenticationState = union(enum) { + pending: void, + none: void, + ok: void, + SASL: SASL, + md5: void, + + pub fn zero(this: *AuthenticationState) void { + switch (this.*) { + .SASL => |*sasl| { + sasl.deinit(); + }, + else => {}, + } + this.* = .{ .none = {} }; + } +}; + +// @sortImports + +const SASL = @import("./SASL.zig"); diff --git a/src/sql/postgres/CommandTag.zig b/src/sql/postgres/CommandTag.zig new file mode 100644 index 0000000000..5c89426eb3 --- /dev/null +++ b/src/sql/postgres/CommandTag.zig @@ -0,0 +1,107 @@ +pub const CommandTag = union(enum) { + // For an INSERT command, the tag is INSERT oid rows, where rows is the + // number of rows inserted. oid used to be the object ID of the inserted + // row if rows was 1 and the target table had OIDs, but OIDs system + // columns are not supported anymore; therefore oid is always 0. + INSERT: u64, + // For a DELETE command, the tag is DELETE rows where rows is the number + // of rows deleted. + DELETE: u64, + // For an UPDATE command, the tag is UPDATE rows where rows is the + // number of rows updated. 
+ UPDATE: u64, + // For a MERGE command, the tag is MERGE rows where rows is the number + // of rows inserted, updated, or deleted. + MERGE: u64, + // For a SELECT or CREATE TABLE AS command, the tag is SELECT rows where + // rows is the number of rows retrieved. + SELECT: u64, + // For a MOVE command, the tag is MOVE rows where rows is the number of + // rows the cursor's position has been changed by. + MOVE: u64, + // For a FETCH command, the tag is FETCH rows where rows is the number + // of rows that have been retrieved from the cursor. + FETCH: u64, + // For a COPY command, the tag is COPY rows where rows is the number of + // rows copied. (Note: the row count appears only in PostgreSQL 8.2 and + // later.) + COPY: u64, + + other: []const u8, + + pub fn toJSTag(this: CommandTag, globalObject: *JSC.JSGlobalObject) JSValue { + return switch (this) { + .INSERT => JSValue.jsNumber(1), + .DELETE => JSValue.jsNumber(2), + .UPDATE => JSValue.jsNumber(3), + .MERGE => JSValue.jsNumber(4), + .SELECT => JSValue.jsNumber(5), + .MOVE => JSValue.jsNumber(6), + .FETCH => JSValue.jsNumber(7), + .COPY => JSValue.jsNumber(8), + .other => |tag| JSC.ZigString.init(tag).toJS(globalObject), + }; + } + + pub fn toJSNumber(this: CommandTag) JSValue { + return switch (this) { + .other => JSValue.jsNumber(0), + inline else => |val| JSValue.jsNumber(val), + }; + } + + const KnownCommand = enum { + INSERT, + DELETE, + UPDATE, + MERGE, + SELECT, + MOVE, + FETCH, + COPY, + + pub const Map = bun.ComptimeEnumMap(KnownCommand); + }; + + pub fn init(tag: []const u8) CommandTag { + const first_space_index = bun.strings.indexOfChar(tag, ' ') orelse return .{ .other = tag }; + const cmd = KnownCommand.Map.get(tag[0..first_space_index]) orelse return .{ + .other = tag, + }; + + const number = brk: { + switch (cmd) { + .INSERT => { + var remaining = tag[@min(first_space_index + 1, tag.len)..]; + const second_space = bun.strings.indexOfChar(remaining, ' ') orelse return .{ .other = tag }; + 
remaining = remaining[@min(second_space + 1, remaining.len)..]; + break :brk std.fmt.parseInt(u64, remaining, 0) catch |err| { + debug("CommandTag failed to parse number: {s}", .{@errorName(err)}); + return .{ .other = tag }; + }; + }, + else => { + const after_tag = tag[@min(first_space_index + 1, tag.len)..]; + break :brk std.fmt.parseInt(u64, after_tag, 0) catch |err| { + debug("CommandTag failed to parse number: {s}", .{@errorName(err)}); + return .{ .other = tag }; + }; + }, + } + }; + + switch (cmd) { + inline else => |t| return @unionInit(CommandTag, @tagName(t), number), + } + } +}; + +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const bun = @import("bun"); +const std = @import("std"); + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/ConnectionFlags.zig b/src/sql/postgres/ConnectionFlags.zig new file mode 100644 index 0000000000..49ad9d6f90 --- /dev/null +++ b/src/sql/postgres/ConnectionFlags.zig @@ -0,0 +1,7 @@ +pub const ConnectionFlags = packed struct { + is_ready_for_query: bool = false, + is_processing_data: bool = false, + use_unnamed_prepared_statements: bool = false, +}; + +// @sortImports diff --git a/src/sql/postgres/Data.zig b/src/sql/postgres/Data.zig new file mode 100644 index 0000000000..557d00fe49 --- /dev/null +++ b/src/sql/postgres/Data.zig @@ -0,0 +1,67 @@ +pub const Data = union(enum) { + owned: bun.ByteList, + temporary: []const u8, + empty: void, + + pub const Empty: Data = .{ .empty = {} }; + + pub fn toOwned(this: @This()) !bun.ByteList { + return switch (this) { + .owned => this.owned, + .temporary => bun.ByteList.init(try bun.default_allocator.dupe(u8, this.temporary)), + .empty => bun.ByteList.init(&.{}), + }; + } + + pub fn deinit(this: *@This()) void { + switch (this.*) { + .owned => this.owned.deinitWithAllocator(bun.default_allocator), + .temporary => {}, + .empty => {}, + } + } + + /// Zero bytes before deinit + /// Generally, for security reasons. 
+ pub fn zdeinit(this: *@This()) void { + switch (this.*) { + .owned => { + + // Zero bytes before deinit + @memset(this.owned.slice(), 0); + + this.owned.deinitWithAllocator(bun.default_allocator); + }, + .temporary => {}, + .empty => {}, + } + } + + pub fn slice(this: @This()) []const u8 { + return switch (this) { + .owned => this.owned.slice(), + .temporary => this.temporary, + .empty => "", + }; + } + + pub fn substring(this: @This(), start_index: usize, end_index: usize) Data { + return switch (this) { + .owned => .{ .temporary = this.owned.slice()[start_index..end_index] }, + .temporary => .{ .temporary = this.temporary[start_index..end_index] }, + .empty => .{ .empty = {} }, + }; + } + + pub fn sliceZ(this: @This()) [:0]const u8 { + return switch (this) { + .owned => this.owned.slice()[0..this.owned.len :0], + .temporary => this.temporary[0..this.temporary.len :0], + .empty => "", + }; + } +}; + +// @sortImports + +const bun = @import("bun"); diff --git a/src/sql/DataCell.zig b/src/sql/postgres/DataCell.zig similarity index 99% rename from src/sql/DataCell.zig rename to src/sql/postgres/DataCell.zig index 3d841657f4..f2f8dd5f00 100644 --- a/src/sql/DataCell.zig +++ b/src/sql/postgres/DataCell.zig @@ -1085,19 +1085,23 @@ pub const DataCell = extern struct { }; }; +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const PostgresCachedStructure = @import("./PostgresCachedStructure.zig"); +const protocol = @import("./PostgresProtocol.zig"); +const std = @import("std"); +const Data = @import("./Data.zig").Data; +const PostgresSQLQueryResultMode = @import("./PostgresSQLQueryResultMode.zig").PostgresSQLQueryResultMode; + +const types = @import("./PostgresTypes.zig"); +const AnyPostgresError = types.AnyPostgresError; +const int4 = types.int4; +const short = types.short; + const bun = @import("bun"); +const String = bun.String; const JSC = bun.JSC; -const std = @import("std"); const JSValue = JSC.JSValue; -const postgres = 
@import("./postgres.zig"); -const Data = postgres.Data; -const types = postgres.types; -const String = bun.String; -const int4 = postgres.int4; -const AnyPostgresError = postgres.AnyPostgresError; -const protocol = postgres.protocol; -const PostgresSQLQueryResultMode = postgres.PostgresSQLQueryResultMode; -const PostgresCachedStructure = postgres.PostgresCachedStructure; -const debug = postgres.debug; -const short = postgres.short; diff --git a/src/sql/postgres/DebugSocketMonitorReader.zig b/src/sql/postgres/DebugSocketMonitorReader.zig new file mode 100644 index 0000000000..19a95c58cd --- /dev/null +++ b/src/sql/postgres/DebugSocketMonitorReader.zig @@ -0,0 +1,25 @@ +var file: std.fs.File = undefined; +pub var enabled = false; +pub var check = std.once(load); + +pub fn load() void { + if (bun.getenvZAnyCase("BUN_POSTGRES_SOCKET_MONITOR_READER")) |monitor| { + enabled = true; + file = std.fs.cwd().createFile(monitor, .{ .truncate = true }) catch { + enabled = false; + return; + }; + debug("duplicating reads to {s}", .{monitor}); + } +} + +pub fn write(data: []const u8) void { + file.writeAll(data) catch {}; +} + +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const bun = @import("bun"); +const std = @import("std"); diff --git a/src/sql/postgres/DebugSocketMonitorWriter.zig b/src/sql/postgres/DebugSocketMonitorWriter.zig new file mode 100644 index 0000000000..5dd43cdf79 --- /dev/null +++ b/src/sql/postgres/DebugSocketMonitorWriter.zig @@ -0,0 +1,25 @@ +var file: std.fs.File = undefined; +pub var enabled = false; +pub var check = std.once(load); + +pub fn write(data: []const u8) void { + file.writeAll(data) catch {}; +} + +pub fn load() void { + if (bun.getenvZAnyCase("BUN_POSTGRES_SOCKET_MONITOR")) |monitor| { + enabled = true; + file = std.fs.cwd().createFile(monitor, .{ .truncate = true }) catch { + enabled = false; + return; + }; + debug("writing to {s}", .{monitor}); + } +} + +const debug = bun.Output.scoped(.Postgres, false); + +// 
@sortImports + +const bun = @import("bun"); +const std = @import("std"); diff --git a/src/sql/postgres/ObjectIterator.zig b/src/sql/postgres/ObjectIterator.zig new file mode 100644 index 0000000000..4c8c6be7e9 --- /dev/null +++ b/src/sql/postgres/ObjectIterator.zig @@ -0,0 +1,64 @@ +array: JSValue, +columns: JSValue = .zero, +globalObject: *JSC.JSGlobalObject, +cell_i: usize = 0, +row_i: usize = 0, +current_row: JSC.JSValue = .zero, +columns_count: usize = 0, +array_length: usize = 0, +any_failed: bool = false, + +pub fn next(this: *ObjectIterator) ?JSC.JSValue { + if (this.row_i >= this.array_length) { + return null; + } + + const cell_i = this.cell_i; + this.cell_i += 1; + const row_i = this.row_i; + + const globalObject = this.globalObject; + + if (this.current_row == .zero) { + this.current_row = JSC.JSObject.getIndex(this.array, globalObject, @intCast(row_i)) catch { + this.any_failed = true; + return null; + }; + if (this.current_row.isEmptyOrUndefinedOrNull()) { + return globalObject.throw("Expected a row to be returned at index {d}", .{row_i}) catch null; + } + } + + defer { + if (this.cell_i >= this.columns_count) { + this.cell_i = 0; + this.current_row = .zero; + this.row_i += 1; + } + } + + const property = JSC.JSObject.getIndex(this.columns, globalObject, @intCast(cell_i)) catch { + this.any_failed = true; + return null; + }; + if (property.isUndefined()) { + return globalObject.throw("Expected a column at index {d} in row {d}", .{ cell_i, row_i }) catch null; + } + + const value = this.current_row.getOwnByValue(globalObject, property); + if (value == .zero or (value != null and value.?.isUndefined())) { + if (!globalObject.hasException()) + return globalObject.throw("Expected a value at index {d} in row {d}", .{ cell_i, row_i }) catch null; + this.any_failed = true; + return null; + } + return value; +} + +// @sortImports + +const ObjectIterator = @This(); +const bun = @import("bun"); + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git 
a/src/sql/postgres/PostgresCachedStructure.zig b/src/sql/postgres/PostgresCachedStructure.zig new file mode 100644 index 0000000000..367e4c19c8 --- /dev/null +++ b/src/sql/postgres/PostgresCachedStructure.zig @@ -0,0 +1,34 @@ +structure: JSC.Strong.Optional = .empty, +// only populated if more than JSC.JSC__JSObject__maxInlineCapacity fields otherwise the structure will contain all fields inlined +fields: ?[]JSC.JSObject.ExternColumnIdentifier = null, + +pub fn has(this: *@This()) bool { + return this.structure.has() or this.fields != null; +} + +pub fn jsValue(this: *const @This()) ?JSC.JSValue { + return this.structure.get(); +} + +pub fn set(this: *@This(), globalObject: *JSC.JSGlobalObject, value: ?JSC.JSValue, fields: ?[]JSC.JSObject.ExternColumnIdentifier) void { + if (value) |v| { + this.structure.set(globalObject, v); + } + this.fields = fields; +} + +pub fn deinit(this: *@This()) void { + this.structure.deinit(); + if (this.fields) |fields| { + this.fields = null; + for (fields) |*name| { + name.deinit(); + } + bun.default_allocator.free(fields); + } +} + +// @sortImports + +const bun = @import("bun"); +const JSC = bun.JSC; diff --git a/src/sql/postgres/PostgresProtocol.zig b/src/sql/postgres/PostgresProtocol.zig new file mode 100644 index 0000000000..8f4ac063aa --- /dev/null +++ b/src/sql/postgres/PostgresProtocol.zig @@ -0,0 +1,63 @@ +pub const CloseComplete = [_]u8{'3'} ++ toBytes(Int32(4)); +pub const EmptyQueryResponse = [_]u8{'I'} ++ toBytes(Int32(4)); +pub const Terminate = [_]u8{'X'} ++ toBytes(Int32(4)); + +pub const BindComplete = [_]u8{'2'} ++ toBytes(Int32(4)); + +pub const ParseComplete = [_]u8{'1'} ++ toBytes(Int32(4)); + +pub const CopyDone = [_]u8{'c'} ++ toBytes(Int32(4)); +pub const Sync = [_]u8{'S'} ++ toBytes(Int32(4)); +pub const Flush = [_]u8{'H'} ++ toBytes(Int32(4)); +pub const SSLRequest = toBytes(Int32(8)) ++ toBytes(Int32(80877103)); +pub const NoData = [_]u8{'n'} ++ toBytes(Int32(4)); + +pub fn writeQuery(query: []const u8, 
comptime Context: type, writer: NewWriter(Context)) !void { + const count: u32 = @sizeOf((u32)) + @as(u32, @intCast(query.len)) + 1; + const header = [_]u8{ + 'Q', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(query); +} + +// @sortImports + +pub const ArrayList = @import("./protocol/ArrayList.zig"); +pub const BackendKeyData = @import("./protocol/BackendKeyData.zig"); +pub const CommandComplete = @import("./protocol/CommandComplete.zig"); +pub const CopyData = @import("./protocol/CopyData.zig"); +pub const CopyFail = @import("./protocol/CopyFail.zig"); +pub const DataRow = @import("./protocol/DataRow.zig"); +pub const Describe = @import("./protocol/Describe.zig"); +pub const ErrorResponse = @import("./protocol/ErrorResponse.zig"); +pub const Execute = @import("./protocol/Execute.zig"); +pub const FieldDescription = @import("./protocol/FieldDescription.zig"); +pub const NegotiateProtocolVersion = @import("./protocol/NegotiateProtocolVersion.zig"); +pub const NoticeResponse = @import("./protocol/NoticeResponse.zig"); +pub const NotificationResponse = @import("./protocol/NotificationResponse.zig"); +pub const ParameterDescription = @import("./protocol/ParameterDescription.zig"); +pub const ParameterStatus = @import("./protocol/ParameterStatus.zig"); +pub const Parse = @import("./protocol/Parse.zig"); +pub const PasswordMessage = @import("./protocol/PasswordMessage.zig"); +pub const ReadyForQuery = @import("./protocol/ReadyForQuery.zig"); +pub const RowDescription = @import("./protocol/RowDescription.zig"); +pub const SASLInitialResponse = @import("./protocol/SASLInitialResponse.zig"); +pub const SASLResponse = @import("./protocol/SASLResponse.zig"); +pub const StackReader = @import("./protocol/StackReader.zig"); +pub const StartupMessage = @import("./protocol/StartupMessage.zig"); +const std = @import("std"); +const types = @import("./PostgresTypes.zig"); +pub const Authentication = 
@import("./protocol/Authentication.zig").Authentication; +pub const ColumnIdentifier = @import("./protocol/ColumnIdentifier.zig").ColumnIdentifier; +pub const DecoderWrap = @import("./protocol/DecoderWrap.zig").DecoderWrap; +pub const FieldMessage = @import("./protocol/FieldMessage.zig").FieldMessage; +pub const FieldType = @import("./protocol/FieldType.zig").FieldType; +pub const NewReader = @import("./protocol/NewReader.zig").NewReader; +pub const NewWriter = @import("./protocol/NewWriter.zig").NewWriter; +pub const PortalOrPreparedStatement = @import("./protocol/PortalOrPreparedStatement.zig").PortalOrPreparedStatement; +pub const WriteWrap = @import("./protocol/WriteWrap.zig").WriteWrap; +const toBytes = std.mem.toBytes; + +const int_types = @import("./types/int_types.zig"); +const Int32 = int_types.Int32; diff --git a/src/sql/postgres/PostgresRequest.zig b/src/sql/postgres/PostgresRequest.zig new file mode 100644 index 0000000000..c8769f4047 --- /dev/null +++ b/src/sql/postgres/PostgresRequest.zig @@ -0,0 +1,348 @@ +pub fn writeBind( + name: []const u8, + cursor_name: bun.String, + globalObject: *JSC.JSGlobalObject, + values_array: JSValue, + columns_value: JSValue, + parameter_fields: []const int4, + result_fields: []const protocol.FieldDescription, + comptime Context: type, + writer: protocol.NewWriter(Context), +) !void { + try writer.write("B"); + const length = try writer.length(); + + try writer.String(cursor_name); + try writer.string(name); + + const len: u32 = @truncate(parameter_fields.len); + + // The number of parameter format codes that follow (denoted C + // below). This can be zero to indicate that there are no + // parameters or that the parameters all use the default format + // (text); or one, in which case the specified format code is + // applied to all parameters; or it can equal the actual number + // of parameters. 
+ try writer.short(len); + + var iter = try QueryBindingIterator.init(values_array, columns_value, globalObject); + for (0..len) |i| { + const parameter_field = parameter_fields[i]; + const is_custom_type = std.math.maxInt(short) < parameter_field; + const tag: types.Tag = if (is_custom_type) .text else @enumFromInt(@as(short, @intCast(parameter_field))); + + const force_text = is_custom_type or (tag.isBinaryFormatSupported() and brk: { + iter.to(@truncate(i)); + if (try iter.next()) |value| { + break :brk value.isString(); + } + if (iter.anyFailed()) { + return error.InvalidQueryBinding; + } + break :brk false; + }); + + if (force_text) { + // If they pass a value as a string, let's avoid attempting to + // convert it to the binary representation. This minimizes the room + // for mistakes on our end, such as stripping the timezone + // differently than what Postgres does when given a timestamp with + // timezone. + try writer.short(0); + continue; + } + + try writer.short( + tag.formatCode(), + ); + } + + // The number of parameter values that follow (possibly zero). This + // must match the number of parameters needed by the query. 
+ try writer.short(len); + + debug("Bind: {} ({d} args)", .{ bun.fmt.quote(name), len }); + iter.to(0); + var i: usize = 0; + while (try iter.next()) |value| : (i += 1) { + const tag: types.Tag = brk: { + if (i >= len) { + // parameter in array but not in parameter_fields + // this is probably a bug a bug in bun lets return .text here so the server will send a error 08P01 + // with will describe better the error saying exactly how many parameters are missing and are expected + // Example: + // SQL error: PostgresError: bind message supplies 0 parameters, but prepared statement "PSELECT * FROM test_table WHERE id=$1 .in$0" requires 1 + // errno: "08P01", + // code: "ERR_POSTGRES_SERVER_ERROR" + break :brk .text; + } + const parameter_field = parameter_fields[i]; + const is_custom_type = std.math.maxInt(short) < parameter_field; + break :brk if (is_custom_type) .text else @enumFromInt(@as(short, @intCast(parameter_field))); + }; + if (value.isEmptyOrUndefinedOrNull()) { + debug(" -> NULL", .{}); + // As a special case, -1 indicates a + // NULL parameter value. No value bytes follow in the NULL case. + try writer.int4(@bitCast(@as(i32, -1))); + continue; + } + if (comptime bun.Environment.enable_logs) { + debug(" -> {s}", .{tag.tagName() orelse "(unknown)"}); + } + + switch ( + // If they pass a value as a string, let's avoid attempting to + // convert it to the binary representation. This minimizes the room + // for mistakes on our end, such as stripping the timezone + // differently than what Postgres does when given a timestamp with + // timezone. 
+ if (tag.isBinaryFormatSupported() and value.isString()) .text else tag) { + .jsonb, .json => { + var str = bun.String.empty; + defer str.deref(); + try value.jsonStringify(globalObject, 0, &str); + const slice = str.toUTF8WithoutRef(bun.default_allocator); + defer slice.deinit(); + const l = try writer.length(); + try writer.write(slice.slice()); + try l.writeExcludingSelf(); + }, + .bool => { + const l = try writer.length(); + try writer.write(&[1]u8{@intFromBool(value.toBoolean())}); + try l.writeExcludingSelf(); + }, + .timestamp, .timestamptz => { + const l = try writer.length(); + try writer.int8(types.date.fromJS(globalObject, value)); + try l.writeExcludingSelf(); + }, + .bytea => { + var bytes: []const u8 = ""; + if (value.asArrayBuffer(globalObject)) |buf| { + bytes = buf.byteSlice(); + } + const l = try writer.length(); + debug(" {d} bytes", .{bytes.len}); + + try writer.write(bytes); + try l.writeExcludingSelf(); + }, + .int4 => { + const l = try writer.length(); + try writer.int4(@bitCast(try value.coerceToInt32(globalObject))); + try l.writeExcludingSelf(); + }, + .int4_array => { + const l = try writer.length(); + try writer.int4(@bitCast(try value.coerceToInt32(globalObject))); + try l.writeExcludingSelf(); + }, + .float8 => { + const l = try writer.length(); + try writer.f64(@bitCast(try value.toNumber(globalObject))); + try l.writeExcludingSelf(); + }, + + else => { + const str = try String.fromJS(value, globalObject); + if (str.tag == .Dead) return error.OutOfMemory; + defer str.deref(); + const slice = str.toUTF8WithoutRef(bun.default_allocator); + defer slice.deinit(); + const l = try writer.length(); + try writer.write(slice.slice()); + try l.writeExcludingSelf(); + }, + } + } + + var any_non_text_fields: bool = false; + for (result_fields) |field| { + if (field.typeTag().isBinaryFormatSupported()) { + any_non_text_fields = true; + break; + } + } + + if (any_non_text_fields) { + try writer.short(result_fields.len); + for (result_fields) 
|field| { + try writer.short( + field.typeTag().formatCode(), + ); + } + } else { + try writer.short(0); + } + + try length.write(); +} + +pub fn writeQuery( + query: []const u8, + name: []const u8, + params: []const int4, + comptime Context: type, + writer: protocol.NewWriter(Context), +) AnyPostgresError!void { + { + var q = protocol.Parse{ + .name = name, + .params = params, + .query = query, + }; + try q.writeInternal(Context, writer); + debug("Parse: {}", .{bun.fmt.quote(query)}); + } + + { + var d = protocol.Describe{ + .p = .{ + .prepared_statement = name, + }, + }; + try d.writeInternal(Context, writer); + debug("Describe: {}", .{bun.fmt.quote(name)}); + } +} + +pub fn prepareAndQueryWithSignature( + globalObject: *JSC.JSGlobalObject, + query: []const u8, + array_value: JSValue, + comptime Context: type, + writer: protocol.NewWriter(Context), + signature: *Signature, +) AnyPostgresError!void { + try writeQuery(query, signature.prepared_statement_name, signature.fields, Context, writer); + try writeBind(signature.prepared_statement_name, bun.String.empty, globalObject, array_value, .zero, &.{}, &.{}, Context, writer); + var exec = protocol.Execute{ + .p = .{ + .prepared_statement = signature.prepared_statement_name, + }, + }; + try exec.writeInternal(Context, writer); + + try writer.write(&protocol.Flush); + try writer.write(&protocol.Sync); +} + +pub fn bindAndExecute( + globalObject: *JSC.JSGlobalObject, + statement: *PostgresSQLStatement, + array_value: JSValue, + columns_value: JSValue, + comptime Context: type, + writer: protocol.NewWriter(Context), +) !void { + try writeBind(statement.signature.prepared_statement_name, bun.String.empty, globalObject, array_value, columns_value, statement.parameters, statement.fields, Context, writer); + var exec = protocol.Execute{ + .p = .{ + .prepared_statement = statement.signature.prepared_statement_name, + }, + }; + try exec.writeInternal(Context, writer); + + try writer.write(&protocol.Flush); + try 
writer.write(&protocol.Sync); +} + +pub fn executeQuery( + query: []const u8, + comptime Context: type, + writer: protocol.NewWriter(Context), +) !void { + try protocol.writeQuery(query, Context, writer); + try writer.write(&protocol.Flush); + try writer.write(&protocol.Sync); +} + +pub fn onData( + connection: *PostgresSQLConnection, + comptime Context: type, + reader: protocol.NewReader(Context), +) !void { + while (true) { + reader.markMessageStart(); + const c = try reader.int(u8); + debug("read: {c}", .{c}); + switch (c) { + 'D' => try connection.on(.DataRow, Context, reader), + 'd' => try connection.on(.CopyData, Context, reader), + 'S' => { + if (connection.tls_status == .message_sent) { + bun.debugAssert(connection.tls_status.message_sent == 8); + connection.tls_status = .ssl_ok; + connection.setupTLS(); + return; + } + + try connection.on(.ParameterStatus, Context, reader); + }, + 'Z' => try connection.on(.ReadyForQuery, Context, reader), + 'C' => try connection.on(.CommandComplete, Context, reader), + '2' => try connection.on(.BindComplete, Context, reader), + '1' => try connection.on(.ParseComplete, Context, reader), + 't' => try connection.on(.ParameterDescription, Context, reader), + 'T' => try connection.on(.RowDescription, Context, reader), + 'R' => try connection.on(.Authentication, Context, reader), + 'n' => try connection.on(.NoData, Context, reader), + 'K' => try connection.on(.BackendKeyData, Context, reader), + 'E' => try connection.on(.ErrorResponse, Context, reader), + 's' => try connection.on(.PortalSuspended, Context, reader), + '3' => try connection.on(.CloseComplete, Context, reader), + 'G' => try connection.on(.CopyInResponse, Context, reader), + 'N' => { + if (connection.tls_status == .message_sent) { + connection.tls_status = .ssl_not_available; + debug("Server does not support SSL", .{}); + if (connection.ssl_mode == .require) { + connection.fail("Server does not support SSL", error.TLSNotAvailable); + return; + } + continue; + } + + 
try connection.on(.NoticeResponse, Context, reader); + }, + 'I' => try connection.on(.EmptyQueryResponse, Context, reader), + 'H' => try connection.on(.CopyOutResponse, Context, reader), + 'c' => try connection.on(.CopyDone, Context, reader), + 'W' => try connection.on(.CopyBothResponse, Context, reader), + + else => { + debug("Unknown message: {c}", .{c}); + const to_skip = try reader.length() -| 1; + debug("to_skip: {d}", .{to_skip}); + try reader.skip(@intCast(@max(to_skip, 0))); + }, + } + } +} + +pub const Queue = std.fifo.LinearFifo(*PostgresSQLQuery, .Dynamic); + +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const PostgresSQLConnection = @import("./PostgresSQLConnection.zig"); +const PostgresSQLQuery = @import("./PostgresSQLQuery.zig"); +const PostgresSQLStatement = @import("./PostgresSQLStatement.zig"); +const Signature = @import("./Signature.zig"); +const protocol = @import("./PostgresProtocol.zig"); +const std = @import("std"); +const QueryBindingIterator = @import("./QueryBindingIterator.zig").QueryBindingIterator; + +const types = @import("./PostgresTypes.zig"); +const AnyPostgresError = @import("./PostgresTypes.zig").AnyPostgresError; +const int4 = types.int4; +const short = types.short; + +const bun = @import("bun"); +const String = bun.String; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/PostgresSQLConnection.zig b/src/sql/postgres/PostgresSQLConnection.zig new file mode 100644 index 0000000000..4ee0e3c8f2 --- /dev/null +++ b/src/sql/postgres/PostgresSQLConnection.zig @@ -0,0 +1,1574 @@ +socket: Socket, +status: Status = Status.connecting, +ref_count: u32 = 1, + +write_buffer: bun.OffsetByteList = .{}, +read_buffer: bun.OffsetByteList = .{}, +last_message_start: u32 = 0, +requests: PostgresRequest.Queue, + +poll_ref: bun.Async.KeepAlive = .{}, +globalObject: *JSC.JSGlobalObject, + +statements: PreparedStatementsMap, +prepared_statement_id: u64 = 0, +pending_activity_count: 
std.atomic.Value(u32) = std.atomic.Value(u32).init(0), +js_value: JSValue = .js_undefined, + +backend_parameters: bun.StringMap = bun.StringMap.init(bun.default_allocator, true), +backend_key_data: protocol.BackendKeyData = .{}, + +database: []const u8 = "", +user: []const u8 = "", +password: []const u8 = "", +path: []const u8 = "", +options: []const u8 = "", +options_buf: []const u8 = "", + +authentication_state: AuthenticationState = .{ .pending = {} }, + +tls_ctx: ?*uws.SocketContext = null, +tls_config: JSC.API.ServerConfig.SSLConfig = .{}, +tls_status: TLSStatus = .none, +ssl_mode: SSLMode = .disable, + +idle_timeout_interval_ms: u32 = 0, +connection_timeout_ms: u32 = 0, + +flags: ConnectionFlags = .{}, + +/// Before being connected, this is a connection timeout timer. +/// After being connected, this is an idle timeout timer. +timer: bun.api.Timer.EventLoopTimer = .{ + .tag = .PostgresSQLConnectionTimeout, + .next = .{ + .sec = 0, + .nsec = 0, + }, +}, + +/// This timer controls the maximum lifetime of a connection. +/// It starts when the connection successfully starts (i.e. after handshake is complete). +/// It stops when the connection is closed. 
+max_lifetime_interval_ms: u32 = 0, +max_lifetime_timer: bun.api.Timer.EventLoopTimer = .{ + .tag = .PostgresSQLConnectionMaxLifetime, + .next = .{ + .sec = 0, + .nsec = 0, + }, +}, + +fn getTimeoutInterval(this: *const PostgresSQLConnection) u32 { + return switch (this.status) { + .connected => this.idle_timeout_interval_ms, + .failed => 0, + else => this.connection_timeout_ms, + }; +} +pub fn disableConnectionTimeout(this: *PostgresSQLConnection) void { + if (this.timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.timer); + } + this.timer.state = .CANCELLED; +} +pub fn resetConnectionTimeout(this: *PostgresSQLConnection) void { + // if we are processing data, don't reset the timeout, wait for the data to be processed + if (this.flags.is_processing_data) return; + const interval = this.getTimeoutInterval(); + if (this.timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.timer); + } + if (interval == 0) { + return; + } + + this.timer.next = bun.timespec.msFromNow(@intCast(interval)); + this.globalObject.bunVM().timer.insert(&this.timer); +} + +pub fn getQueries(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) bun.JSError!JSC.JSValue { + if (js.queriesGetCached(thisValue)) |value| { + return value; + } + + const array = try JSC.JSValue.createEmptyArray(globalObject, 0); + js.queriesSetCached(thisValue, globalObject, array); + + return array; +} + +pub fn getOnConnect(_: *PostgresSQLConnection, thisValue: JSC.JSValue, _: *JSC.JSGlobalObject) JSC.JSValue { + if (js.onconnectGetCached(thisValue)) |value| { + return value; + } + + return .js_undefined; +} + +pub fn setOnConnect(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) void { + js.onconnectSetCached(thisValue, globalObject, value); +} + +pub fn getOnClose(_: *PostgresSQLConnection, thisValue: JSC.JSValue, _: *JSC.JSGlobalObject) JSC.JSValue { + if (js.oncloseGetCached(thisValue)) 
|value| { + return value; + } + + return .js_undefined; +} + +pub fn setOnClose(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) void { + js.oncloseSetCached(thisValue, globalObject, value); +} + +pub fn setupTLS(this: *PostgresSQLConnection) void { + debug("setupTLS", .{}); + const new_socket = this.socket.SocketTCP.socket.connected.upgrade(this.tls_ctx.?, this.tls_config.server_name) orelse { + this.fail("Failed to upgrade to TLS", error.TLSUpgradeFailed); + return; + }; + this.socket = .{ + .SocketTLS = .{ + .socket = .{ + .connected = new_socket, + }, + }, + }; + + this.start(); +} +fn setupMaxLifetimeTimerIfNecessary(this: *PostgresSQLConnection) void { + if (this.max_lifetime_interval_ms == 0) return; + if (this.max_lifetime_timer.state == .ACTIVE) return; + + this.max_lifetime_timer.next = bun.timespec.msFromNow(@intCast(this.max_lifetime_interval_ms)); + this.globalObject.bunVM().timer.insert(&this.max_lifetime_timer); +} + +pub fn onConnectionTimeout(this: *PostgresSQLConnection) bun.api.Timer.EventLoopTimer.Arm { + debug("onConnectionTimeout", .{}); + + this.timer.state = .FIRED; + if (this.flags.is_processing_data) { + return .disarm; + } + + if (this.getTimeoutInterval() == 0) { + this.resetConnectionTimeout(); + return .disarm; + } + + switch (this.status) { + .connected => { + this.failFmt(.POSTGRES_IDLE_TIMEOUT, "Idle timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.idle_timeout_interval_ms) *| std.time.ns_per_ms)}); + }, + else => { + this.failFmt(.POSTGRES_CONNECTION_TIMEOUT, "Connection timeout after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)}); + }, + .sent_startup_message => { + this.failFmt(.POSTGRES_CONNECTION_TIMEOUT, "Connection timed out after {} (sent startup message, but never received response)", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)}); + }, + } + 
return .disarm; +} + +pub fn onMaxLifetimeTimeout(this: *PostgresSQLConnection) bun.api.Timer.EventLoopTimer.Arm { + debug("onMaxLifetimeTimeout", .{}); + this.max_lifetime_timer.state = .FIRED; + if (this.status == .failed) return .disarm; + this.failFmt(.POSTGRES_LIFETIME_TIMEOUT, "Max lifetime timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.max_lifetime_interval_ms) *| std.time.ns_per_ms)}); + return .disarm; +} + +fn start(this: *PostgresSQLConnection) void { + this.setupMaxLifetimeTimerIfNecessary(); + this.resetConnectionTimeout(); + this.sendStartupMessage(); + + const event_loop = this.globalObject.bunVM().eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + this.flushData(); +} + +pub fn hasPendingActivity(this: *PostgresSQLConnection) bool { + return this.pending_activity_count.load(.acquire) > 0; +} + +fn updateHasPendingActivity(this: *PostgresSQLConnection) void { + const a: u32 = if (this.requests.readableLength() > 0) 1 else 0; + const b: u32 = if (this.status != .disconnected) 1 else 0; + this.pending_activity_count.store(a + b, .release); +} + +pub fn setStatus(this: *PostgresSQLConnection, status: Status) void { + if (this.status == status) return; + defer this.updateHasPendingActivity(); + + this.status = status; + this.resetConnectionTimeout(); + + switch (status) { + .connected => { + const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return; + const js_value = this.js_value; + js_value.ensureStillAlive(); + this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ JSValue.jsNull(), js_value }); + this.poll_ref.unref(this.globalObject.bunVM()); + }, + else => {}, + } +} + +pub fn finalize(this: *PostgresSQLConnection) void { + debug("PostgresSQLConnection finalize", .{}); + this.stopTimers(); + this.js_value = .zero; + this.deref(); +} + +pub fn flushDataAndResetTimeout(this: *PostgresSQLConnection) void { + this.resetConnectionTimeout(); + this.flushData(); +} + +pub fn 
flushData(this: *PostgresSQLConnection) void { + const chunk = this.write_buffer.remaining(); + if (chunk.len == 0) return; + const wrote = this.socket.write(chunk); + if (wrote > 0) { + SocketMonitor.write(chunk[0..@intCast(wrote)]); + this.write_buffer.consume(@intCast(wrote)); + } +} + +pub fn failWithJSValue(this: *PostgresSQLConnection, value: JSValue) void { + defer this.updateHasPendingActivity(); + this.stopTimers(); + if (this.status == .failed) return; + + this.status = .failed; + + this.ref(); + defer this.deref(); + // we defer the refAndClose so the on_close will be called first before we reject the pending requests + defer this.refAndClose(value); + const on_close = this.consumeOnCloseCallback(this.globalObject) orelse return; + + const loop = this.globalObject.bunVM().eventLoop(); + loop.enter(); + defer loop.exit(); + _ = on_close.call( + this.globalObject, + this.js_value, + &[_]JSValue{ + value, + this.getQueriesArray(), + }, + ) catch |e| this.globalObject.reportActiveExceptionAsUnhandled(e); +} + +pub fn failFmt(this: *PostgresSQLConnection, comptime error_code: JSC.Error, comptime fmt: [:0]const u8, args: anytype) void { + this.failWithJSValue(error_code.fmt(this.globalObject, fmt, args)); +} + +pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: AnyPostgresError) void { + debug("failed: {s}: {s}", .{ message, @errorName(err) }); + + const globalObject = this.globalObject; + + this.failWithJSValue(postgresErrorToJS(globalObject, message, err)); +} + +pub fn onClose(this: *PostgresSQLConnection) void { + var vm = this.globalObject.bunVM(); + const loop = vm.eventLoop(); + loop.enter(); + defer loop.exit(); + this.poll_ref.unref(this.globalObject.bunVM()); + + this.fail("Connection closed", error.ConnectionClosed); +} + +fn sendStartupMessage(this: *PostgresSQLConnection) void { + if (this.status != .connecting) return; + debug("sendStartupMessage", .{}); + this.status = .sent_startup_message; + var msg = protocol.StartupMessage{ 
+ .user = Data{ .temporary = this.user }, + .database = Data{ .temporary = this.database }, + .options = Data{ .temporary = this.options }, + }; + msg.writeInternal(Writer, this.writer()) catch |err| { + this.fail("Failed to write startup message", err); + }; +} + +fn startTLS(this: *PostgresSQLConnection, socket: uws.AnySocket) void { + debug("startTLS", .{}); + const offset = switch (this.tls_status) { + .message_sent => |count| count, + else => 0, + }; + const ssl_request = [_]u8{ + 0x00, 0x00, 0x00, 0x08, // Length + 0x04, 0xD2, 0x16, 0x2F, // SSL request code + }; + + const written = socket.write(ssl_request[offset..]); + if (written > 0) { + this.tls_status = .{ + .message_sent = offset + @as(u8, @intCast(written)), + }; + } else { + this.tls_status = .{ + .message_sent = offset, + }; + } +} + +pub fn onOpen(this: *PostgresSQLConnection, socket: uws.AnySocket) void { + this.socket = socket; + + this.poll_ref.ref(this.globalObject.bunVM()); + this.updateHasPendingActivity(); + + if (this.tls_status == .message_sent or this.tls_status == .pending) { + this.startTLS(socket); + return; + } + + this.start(); +} + +pub fn onHandshake(this: *PostgresSQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no }); + const handshake_success = if (success == 1) true else false; + if (handshake_success) { + if (this.tls_config.reject_unauthorized != 0) { + // only reject the connection if reject_unauthorized == true + switch (this.ssl_mode) { + // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 + + .verify_ca, .verify_full => { + if (ssl_error.error_no != 0) { + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + return; + } + + const ssl_ptr: *BoringSSL.c.SSL = @ptrCast(this.socket.getNativeHandle()); + if (BoringSSL.c.SSL_get_servername(ssl_ptr, 0)) |servername| { + const hostname = servername[0..bun.len(servername)]; + if 
(!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) { + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + } + } + }, + else => { + return; + }, + } + } + } else { + // if we are here is because server rejected us, and the error_no is the cause of this + // no matter if reject_unauthorized is false because we are disconnected by the server + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + } +} + +pub fn onTimeout(this: *PostgresSQLConnection) void { + _ = this; + debug("onTimeout", .{}); +} + +pub fn onDrain(this: *PostgresSQLConnection) void { + + // Don't send any other messages while we're waiting for TLS. + if (this.tls_status == .message_sent) { + if (this.tls_status.message_sent < 8) { + this.startTLS(this.socket); + } + + return; + } + + const event_loop = this.globalObject.bunVM().eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + this.flushData(); +} + +pub fn onData(this: *PostgresSQLConnection, data: []const u8) void { + this.ref(); + this.flags.is_processing_data = true; + const vm = this.globalObject.bunVM(); + + this.disableConnectionTimeout(); + defer { + if (this.status == .connected and !this.hasQueryRunning() and this.write_buffer.remaining().len == 0) { + // Don't keep the process alive when there's nothing to do. + this.poll_ref.unref(vm); + } else if (this.status == .connected) { + // Keep the process alive if there's something to do. 
+ this.poll_ref.ref(vm); + } + this.flags.is_processing_data = false; + + // reset the connection timeout after we're done processing the data + this.resetConnectionTimeout(); + this.deref(); + } + + const event_loop = vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + SocketMonitor.read(data); + // reset the head to the last message so remaining reflects the right amount of bytes + this.read_buffer.head = this.last_message_start; + + if (this.read_buffer.remaining().len == 0) { + var consumed: usize = 0; + var offset: usize = 0; + const reader = protocol.StackReader.init(data, &consumed, &offset); + PostgresRequest.onData(this, protocol.StackReader, reader) catch |err| { + if (err == error.ShortRead) { + if (comptime bun.Environment.allow_assert) { + debug("read_buffer: empty and received short read: last_message_start: {d}, head: {d}, len: {d}", .{ + offset, + consumed, + data.len, + }); + } + + this.read_buffer.head = 0; + this.last_message_start = 0; + this.read_buffer.byte_list.len = 0; + this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer"); + } else { + bun.handleErrorReturnTrace(err, @errorReturnTrace()); + + this.fail("Failed to read data", err); + } + }; + // no need to reset anything, its already empty + return; + } + // read buffer is not empty, so we need to write the data to the buffer and then read it + this.read_buffer.write(bun.default_allocator, data) catch @panic("failed to write to read buffer"); + PostgresRequest.onData(this, Reader, this.bufferedReader()) catch |err| { + if (err != error.ShortRead) { + bun.handleErrorReturnTrace(err, @errorReturnTrace()); + this.fail("Failed to read data", err); + return; + } + + if (comptime bun.Environment.allow_assert) { + debug("read_buffer: not empty and received short read: last_message_start: {d}, head: {d}, len: {d}", .{ + this.last_message_start, + this.read_buffer.head, + this.read_buffer.byte_list.len, + }); + } + return; + }; + 
+ debug("clean read_buffer", .{}); + // success, we read everything! let's reset the last message start and the head + this.last_message_start = 0; + this.read_buffer.head = 0; +} + +pub fn constructor(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*PostgresSQLConnection { + _ = callframe; + return globalObject.throw("PostgresSQLConnection cannot be constructed directly", .{}); +} + +comptime { + const jscall = JSC.toJSHostFn(call); + @export(&jscall, .{ .name = "PostgresSQLConnection__createInstance" }); +} + +pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { + var vm = globalObject.bunVM(); + const arguments = callframe.arguments_old(15).slice(); + const hostname_str = try arguments[0].toBunString(globalObject); + defer hostname_str.deref(); + const port = try arguments[1].coerce(i32, globalObject); + + const username_str = try arguments[2].toBunString(globalObject); + defer username_str.deref(); + const password_str = try arguments[3].toBunString(globalObject); + defer password_str.deref(); + const database_str = try arguments[4].toBunString(globalObject); + defer database_str.deref(); + const ssl_mode: SSLMode = switch (arguments[5].toInt32()) { + 0 => .disable, + 1 => .prefer, + 2 => .require, + 3 => .verify_ca, + 4 => .verify_full, + else => .disable, + }; + + const tls_object = arguments[6]; + + var tls_config: JSC.API.ServerConfig.SSLConfig = .{}; + var tls_ctx: ?*uws.SocketContext = null; + if (ssl_mode != .disable) { + tls_config = if (tls_object.isBoolean() and tls_object.toBoolean()) + .{} + else if (tls_object.isObject()) + (JSC.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{} + else { + return globalObject.throwInvalidArguments("tls must be a boolean or an object", .{}); + }; + + if (globalObject.hasException()) { + tls_config.deinit(); + return .zero; + } + + // we always request the cert so we can verify it and also we manually 
abort the connection if the hostname doesn't match + const original_reject_unauthorized = tls_config.reject_unauthorized; + tls_config.reject_unauthorized = 0; + tls_config.request_cert = 1; + // We create it right here so we can throw errors early. + const context_options = tls_config.asUSockets(); + var err: uws.create_bun_socket_error_t = .none; + tls_ctx = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*PostgresSQLConnection), context_options, &err) orelse { + if (err != .none) { + return globalObject.throw("failed to create TLS context", .{}); + } else { + return globalObject.throwValue(err.toJS(globalObject)); + } + }; + // restore the original reject_unauthorized + tls_config.reject_unauthorized = original_reject_unauthorized; + if (err != .none) { + tls_config.deinit(); + if (tls_ctx) |ctx| { + ctx.deinit(true); + } + return globalObject.throwValue(err.toJS(globalObject)); + } + + uws.NewSocketHandler(true).configure(tls_ctx.?, true, *PostgresSQLConnection, SocketHandler(true)); + } + + var username: []const u8 = ""; + var password: []const u8 = ""; + var database: []const u8 = ""; + var options: []const u8 = ""; + var path: []const u8 = ""; + + const options_str = try arguments[7].toBunString(globalObject); + defer options_str.deref(); + + const path_str = try arguments[8].toBunString(globalObject); + defer path_str.deref(); + + const options_buf: []u8 = brk: { + var b = bun.StringBuilder{}; + b.cap += username_str.utf8ByteLength() + 1 + password_str.utf8ByteLength() + 1 + database_str.utf8ByteLength() + 1 + options_str.utf8ByteLength() + 1 + path_str.utf8ByteLength() + 1; + + b.allocate(bun.default_allocator) catch {}; + var u = username_str.toUTF8WithoutRef(bun.default_allocator); + defer u.deinit(); + username = b.append(u.slice()); + + var p = password_str.toUTF8WithoutRef(bun.default_allocator); + defer p.deinit(); + password = b.append(p.slice()); + + var d = database_str.toUTF8WithoutRef(bun.default_allocator); + defer d.deinit(); + 
database = b.append(d.slice()); + + var o = options_str.toUTF8WithoutRef(bun.default_allocator); + defer o.deinit(); + options = b.append(o.slice()); + + var _path = path_str.toUTF8WithoutRef(bun.default_allocator); + defer _path.deinit(); + path = b.append(_path.slice()); + + break :brk b.allocatedSlice(); + }; + + const on_connect = arguments[9]; + const on_close = arguments[10]; + const idle_timeout = arguments[11].toInt32(); + const connection_timeout = arguments[12].toInt32(); + const max_lifetime = arguments[13].toInt32(); + const use_unnamed_prepared_statements = arguments[14].asBoolean(); + + const ptr: *PostgresSQLConnection = try bun.default_allocator.create(PostgresSQLConnection); + + ptr.* = PostgresSQLConnection{ + .globalObject = globalObject, + + .database = database, + .user = username, + .password = password, + .path = path, + .options = options, + .options_buf = options_buf, + .socket = .{ .SocketTCP = .{ .socket = .{ .detached = {} } } }, + .requests = PostgresRequest.Queue.init(bun.default_allocator), + .statements = PreparedStatementsMap{}, + .tls_config = tls_config, + .tls_ctx = tls_ctx, + .ssl_mode = ssl_mode, + .tls_status = if (ssl_mode != .disable) .pending else .none, + .idle_timeout_interval_ms = @intCast(idle_timeout), + .connection_timeout_ms = @intCast(connection_timeout), + .max_lifetime_interval_ms = @intCast(max_lifetime), + .flags = .{ + .use_unnamed_prepared_statements = use_unnamed_prepared_statements, + }, + }; + + ptr.updateHasPendingActivity(); + ptr.poll_ref.ref(vm); + const js_value = ptr.toJS(globalObject); + js_value.ensureStillAlive(); + ptr.js_value = js_value; + + js.onconnectSetCached(js_value, globalObject, on_connect); + js.oncloseSetCached(js_value, globalObject, on_close); + bun.analytics.Features.postgres_connections += 1; + + { + const hostname = hostname_str.toUTF8(bun.default_allocator); + defer hostname.deinit(); + + const ctx = vm.rareData().postgresql_context.tcp orelse brk: { + const ctx_ = 
uws.SocketContext.createNoSSLContext(vm.uwsLoop(), @sizeOf(*PostgresSQLConnection)).?; + uws.NewSocketHandler(false).configure(ctx_, true, *PostgresSQLConnection, SocketHandler(false)); + vm.rareData().postgresql_context.tcp = ctx_; + break :brk ctx_; + }; + + if (path.len > 0) { + ptr.socket = .{ + .SocketTCP = uws.SocketTCP.connectUnixAnon(path, ctx, ptr, false) catch |err| { + tls_config.deinit(); + if (tls_ctx) |tls| { + tls.deinit(true); + } + ptr.deinit(); + return globalObject.throwError(err, "failed to connect to postgresql"); + }, + }; + } else { + ptr.socket = .{ + .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| { + tls_config.deinit(); + if (tls_ctx) |tls| { + tls.deinit(true); + } + ptr.deinit(); + return globalObject.throwError(err, "failed to connect to postgresql"); + }, + }; + } + ptr.resetConnectionTimeout(); + } + + return js_value; +} + +fn SocketHandler(comptime ssl: bool) type { + return struct { + const SocketType = uws.NewSocketHandler(ssl); + fn _socket(s: SocketType) Socket { + if (comptime ssl) { + return Socket{ .SocketTLS = s }; + } + + return Socket{ .SocketTCP = s }; + } + pub fn onOpen(this: *PostgresSQLConnection, socket: SocketType) void { + this.onOpen(_socket(socket)); + } + + fn onHandshake_(this: *PostgresSQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + this.onHandshake(success, ssl_error); + } + + pub const onHandshake = if (ssl) onHandshake_ else null; + + pub fn onClose(this: *PostgresSQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void { + _ = socket; + this.onClose(); + } + + pub fn onEnd(this: *PostgresSQLConnection, socket: SocketType) void { + _ = socket; + this.onClose(); + } + + pub fn onConnectError(this: *PostgresSQLConnection, socket: SocketType, _: i32) void { + _ = socket; + this.onClose(); + } + + pub fn onTimeout(this: *PostgresSQLConnection, socket: SocketType) void { + _ = socket; + this.onTimeout(); + } + + pub fn 
onData(this: *PostgresSQLConnection, socket: SocketType, data: []const u8) void { + _ = socket; + this.onData(data); + } + + pub fn onWritable(this: *PostgresSQLConnection, socket: SocketType) void { + _ = socket; + this.onDrain(); + } + }; +} + +pub fn ref(this: *@This()) void { + bun.assert(this.ref_count > 0); + this.ref_count += 1; +} + +pub fn doRef(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + this.poll_ref.ref(this.globalObject.bunVM()); + this.updateHasPendingActivity(); + return .js_undefined; +} + +pub fn doUnref(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + this.poll_ref.unref(this.globalObject.bunVM()); + this.updateHasPendingActivity(); + return .js_undefined; +} +pub fn doFlush(this: *PostgresSQLConnection, _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSC.JSValue { + this.flushData(); + return .js_undefined; +} + +pub fn deref(this: *@This()) void { + const ref_count = this.ref_count; + this.ref_count -= 1; + + if (ref_count == 1) { + this.disconnect(); + this.deinit(); + } +} + +pub fn doClose(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + _ = globalObject; + this.disconnect(); + this.write_buffer.deinit(bun.default_allocator); + + return .js_undefined; +} + +pub fn stopTimers(this: *PostgresSQLConnection) void { + if (this.timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.timer); + } + if (this.max_lifetime_timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.max_lifetime_timer); + } +} + +pub fn deinit(this: *@This()) void { + this.stopTimers(); + var iter = this.statements.valueIterator(); + while (iter.next()) |stmt_ptr| { + var stmt = stmt_ptr.*; + stmt.deref(); + } + this.statements.deinit(bun.default_allocator); + this.write_buffer.deinit(bun.default_allocator); + this.read_buffer.deinit(bun.default_allocator); + this.backend_parameters.deinit(); + + 
bun.freeSensitive(bun.default_allocator, this.options_buf); + + this.tls_config.deinit(); + bun.default_allocator.destroy(this); +} + +fn refAndClose(this: *@This(), js_reason: ?JSC.JSValue) void { + // refAndClose is always called when we want to disconnect or when we are closed + + if (!this.socket.isClosed()) { + // event loop needs to be alive to close the socket + this.poll_ref.ref(this.globalObject.bunVM()); + // will unref on socket close + this.socket.close(); + } + + // cleanup requests + while (this.current()) |request| { + switch (request.status) { + // pending: we will fail the request and the stmt will be marked as error ConnectionClosed too + .pending => { + const stmt = request.statement orelse continue; + stmt.error_response = .{ .postgres_error = AnyPostgresError.ConnectionClosed }; + stmt.status = .failed; + if (js_reason) |reason| { + request.onJSError(reason, this.globalObject); + } else { + request.onError(.{ .postgres_error = AnyPostgresError.ConnectionClosed }, this.globalObject); + } + }, + // in the middle of running + .binding, + .running, + .partial_response, + => { + if (js_reason) |reason| { + request.onJSError(reason, this.globalObject); + } else { + request.onError(.{ .postgres_error = AnyPostgresError.ConnectionClosed }, this.globalObject); + } + }, + // just ignore success and fail cases + .success, .fail => {}, + } + request.deref(); + this.requests.discard(1); + } +} + +pub fn disconnect(this: *@This()) void { + this.stopTimers(); + + if (this.status == .connected) { + this.status = .disconnected; + this.refAndClose(null); + } +} + +fn current(this: *PostgresSQLConnection) ?*PostgresSQLQuery { + if (this.requests.readableLength() == 0) { + return null; + } + + return this.requests.peekItem(0); +} + +pub fn hasQueryRunning(this: *PostgresSQLConnection) bool { + return !this.flags.is_ready_for_query or this.current() != null; +} + +pub const Writer = struct { + connection: *PostgresSQLConnection, + + pub fn write(this: Writer, data: 
[]const u8) AnyPostgresError!void { + var buffer = &this.connection.write_buffer; + try buffer.write(bun.default_allocator, data); + } + + pub fn pwrite(this: Writer, data: []const u8, index: usize) AnyPostgresError!void { + @memcpy(this.connection.write_buffer.byte_list.slice()[index..][0..data.len], data); + } + + pub fn offset(this: Writer) usize { + return this.connection.write_buffer.len(); + } +}; + +pub fn writer(this: *PostgresSQLConnection) protocol.NewWriter(Writer) { + return .{ + .wrapped = .{ + .connection = this, + }, + }; +} + +pub const Reader = struct { + connection: *PostgresSQLConnection, + + pub fn markMessageStart(this: Reader) void { + this.connection.last_message_start = this.connection.read_buffer.head; + } + + pub const ensureLength = ensureCapacity; + + pub fn peek(this: Reader) []const u8 { + return this.connection.read_buffer.remaining(); + } + pub fn skip(this: Reader, count: usize) void { + this.connection.read_buffer.head = @min(this.connection.read_buffer.head + @as(u32, @truncate(count)), this.connection.read_buffer.byte_list.len); + } + pub fn ensureCapacity(this: Reader, count: usize) bool { + return @as(usize, this.connection.read_buffer.head) + count <= @as(usize, this.connection.read_buffer.byte_list.len); + } + pub fn read(this: Reader, count: usize) AnyPostgresError!Data { + var remaining = this.connection.read_buffer.remaining(); + if (@as(usize, remaining.len) < count) { + return error.ShortRead; + } + + this.skip(count); + return Data{ + .temporary = remaining[0..count], + }; + } + pub fn readZ(this: Reader) AnyPostgresError!Data { + const remain = this.connection.read_buffer.remaining(); + + if (bun.strings.indexOfChar(remain, 0)) |zero| { + this.skip(zero + 1); + return Data{ + .temporary = remain[0..zero], + }; + } + + return error.ShortRead; + } +}; + +pub fn bufferedReader(this: *PostgresSQLConnection) protocol.NewReader(Reader) { + return .{ + .wrapped = .{ .connection = this }, + }; +} + +fn advance(this: 
*PostgresSQLConnection) !void { + while (this.requests.readableLength() > 0) { + var req: *PostgresSQLQuery = this.requests.peekItem(0); + switch (req.status) { + .pending => { + if (req.flags.simple) { + debug("executeQuery", .{}); + var query_str = req.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + PostgresRequest.executeQuery(query_str.slice(), PostgresSQLConnection.Writer, this.writer()) catch |err| { + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); + + continue; + }; + this.flags.is_ready_for_query = false; + req.status = .running; + return; + } else { + const stmt = req.statement orelse return error.ExpectedStatement; + + switch (stmt.status) { + .failed => { + bun.assert(stmt.error_response != null); + req.onError(stmt.error_response.?, this.globalObject); + req.deref(); + this.requests.discard(1); + + continue; + }, + .prepared => { + const thisValue = req.thisValue.get(); + bun.assert(thisValue != .zero); + const binding_value = PostgresSQLQuery.js.bindingGetCached(thisValue) orelse .zero; + const columns_value = PostgresSQLQuery.js.columnsGetCached(thisValue) orelse .zero; + req.flags.binary = stmt.fields.len > 0; + + PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, columns_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); + + continue; + }; + this.flags.is_ready_for_query = false; + req.status = .binding; + return; + }, + .pending => { + // statement is pending, lets write/parse it + var query_str = req.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + const has_params = stmt.signature.fields.len > 0; + // If it does not have params, we can write and execute immediately in one go + if (!has_params) { + const thisValue = req.thisValue.get(); + bun.assert(thisValue != .zero); + // prepareAndQueryWithSignature will 
write + bind + execute, it will change to running after binding is complete + const binding_value = PostgresSQLQuery.js.bindingGetCached(thisValue) orelse .zero; + PostgresRequest.prepareAndQueryWithSignature(this.globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, this.writer(), &stmt.signature) catch |err| { + stmt.status = .failed; + stmt.error_response = .{ .postgres_error = err }; + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); + + continue; + }; + this.flags.is_ready_for_query = false; + req.status = .binding; + stmt.status = .parsing; + + return; + } + const connection_writer = this.writer(); + // write query and wait for it to be prepared + PostgresRequest.writeQuery(query_str.slice(), stmt.signature.prepared_statement_name, stmt.signature.fields, PostgresSQLConnection.Writer, connection_writer) catch |err| { + stmt.error_response = .{ .postgres_error = err }; + stmt.status = .failed; + + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); + + continue; + }; + connection_writer.write(&protocol.Sync) catch |err| { + stmt.error_response = .{ .postgres_error = err }; + stmt.status = .failed; + + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); + + continue; + }; + this.flags.is_ready_for_query = false; + stmt.status = .parsing; + return; + }, + .parsing => { + // we are still parsing, lets wait for it to be prepared or failed + return; + }, + } + } + }, + + .running, .binding, .partial_response => { + // if we are binding it will switch to running immediately + // if we are running, we need to wait for it to be success or fail + return; + }, + .success, .fail => { + req.deref(); + this.requests.discard(1); + continue; + }, + } + } +} + +pub fn getQueriesArray(this: *const PostgresSQLConnection) JSValue { + return js.queriesGetCached(this.js_value) orelse .zero; 
+} + +pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_literal), comptime Context: type, reader: protocol.NewReader(Context)) AnyPostgresError!void { + debug("on({s})", .{@tagName(MessageType)}); + + switch (comptime MessageType) { + .DataRow => { + const request = this.current() orelse return error.ExpectedRequest; + var statement = request.statement orelse return error.ExpectedStatement; + var structure: JSValue = .js_undefined; + var cached_structure: ?PostgresCachedStructure = null; + // explicit use switch without else so if new modes are added, we don't forget to check for duplicate fields + switch (request.flags.result_mode) { + .objects => { + cached_structure = statement.structure(this.js_value, this.globalObject); + structure = cached_structure.?.jsValue() orelse .js_undefined; + }, + .raw, .values => { + // no need to check for duplicate fields or structure + }, + } + + var putter = DataCell.Putter{ + .list = &.{}, + .fields = statement.fields, + .binary = request.flags.binary, + .bigint = request.flags.bigint, + .globalObject = this.globalObject, + }; + + var stack_buf: [70]DataCell = undefined; + var cells: []DataCell = stack_buf[0..@min(statement.fields.len, JSC.JSObject.maxInlineCapacity())]; + var free_cells = false; + defer { + for (cells[0..putter.count]) |*cell| { + cell.deinit(); + } + if (free_cells) bun.default_allocator.free(cells); + } + + if (statement.fields.len >= JSC.JSObject.maxInlineCapacity()) { + cells = try bun.default_allocator.alloc(DataCell, statement.fields.len); + free_cells = true; + } + // make sure all cells are reset if reader short breaks the fields will just be null with is better than undefined behavior + @memset(cells, DataCell{ .tag = .null, .value = .{ .null = 0 } }); + putter.list = cells; + + if (request.flags.result_mode == .raw) { + try protocol.DataRow.decode( + &putter, + Context, + reader, + DataCell.Putter.putRaw, + ); + } else { + try protocol.DataRow.decode( + &putter, + Context, + 
reader, + DataCell.Putter.put, + ); + } + const thisValue = request.thisValue.get(); + bun.assert(thisValue != .zero); + const pending_value = PostgresSQLQuery.js.pendingValueGetCached(thisValue) orelse .zero; + pending_value.ensureStillAlive(); + const result = putter.toJS(this.globalObject, pending_value, structure, statement.fields_flags, request.flags.result_mode, cached_structure); + + if (pending_value == .zero) { + PostgresSQLQuery.js.pendingValueSetCached(thisValue, this.globalObject, result); + } + }, + .CopyData => { + var copy_data: protocol.CopyData = undefined; + try copy_data.decodeInternal(Context, reader); + copy_data.data.deinit(); + }, + .ParameterStatus => { + var parameter_status: protocol.ParameterStatus = undefined; + try parameter_status.decodeInternal(Context, reader); + defer { + parameter_status.deinit(); + } + try this.backend_parameters.insert(parameter_status.name.slice(), parameter_status.value.slice()); + }, + .ReadyForQuery => { + var ready_for_query: protocol.ReadyForQuery = undefined; + try ready_for_query.decodeInternal(Context, reader); + + this.setStatus(.connected); + this.flags.is_ready_for_query = true; + this.socket.setTimeout(300); + defer this.updateRef(); + + if (this.current()) |request| { + if (request.status == .partial_response) { + // if is a partial response, just signal that the query is now complete + request.onResult("", this.globalObject, this.js_value, true); + } + } + try this.advance(); + + this.flushData(); + }, + .CommandComplete => { + var request = this.current() orelse return error.ExpectedRequest; + + var cmd: protocol.CommandComplete = undefined; + try cmd.decodeInternal(Context, reader); + defer { + cmd.deinit(); + } + debug("-> {s}", .{cmd.command_tag.slice()}); + defer this.updateRef(); + + if (request.flags.simple) { + // simple queries can have multiple commands + request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, false); + } else { + 
request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, true); + } + }, + .BindComplete => { + try reader.eatMessage(protocol.BindComplete); + var request = this.current() orelse return error.ExpectedRequest; + if (request.status == .binding) { + request.status = .running; + } + }, + .ParseComplete => { + try reader.eatMessage(protocol.ParseComplete); + const request = this.current() orelse return error.ExpectedRequest; + if (request.statement) |statement| { + // if we have params wait for parameter description + if (statement.status == .parsing and statement.signature.fields.len == 0) { + statement.status = .prepared; + } + } + }, + .ParameterDescription => { + var description: protocol.ParameterDescription = undefined; + try description.decodeInternal(Context, reader); + const request = this.current() orelse return error.ExpectedRequest; + var statement = request.statement orelse return error.ExpectedStatement; + statement.parameters = description.parameters; + if (statement.status == .parsing) { + statement.status = .prepared; + } + }, + .RowDescription => { + var description: protocol.RowDescription = undefined; + try description.decodeInternal(Context, reader); + errdefer description.deinit(); + const request = this.current() orelse return error.ExpectedRequest; + var statement = request.statement orelse return error.ExpectedStatement; + statement.fields = description.fields; + }, + .Authentication => { + var auth: protocol.Authentication = undefined; + try auth.decodeInternal(Context, reader); + defer auth.deinit(); + + switch (auth) { + .SASL => { + if (this.authentication_state != .SASL) { + this.authentication_state = .{ .SASL = .{} }; + } + + var mechanism_buf: [128]u8 = undefined; + const mechanism = std.fmt.bufPrintZ(&mechanism_buf, "n,,n=*,r={s}", .{this.authentication_state.SASL.nonce()}) catch unreachable; + var response = protocol.SASLInitialResponse{ + .mechanism = .{ + .temporary = "SCRAM-SHA-256", + }, + .data = .{ + 
.temporary = mechanism, + }, + }; + + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + debug("SASL", .{}); + this.flushData(); + }, + .SASLContinue => |*cont| { + if (this.authentication_state != .SASL) { + debug("Unexpected SASLContinue for authentiation state: {s}", .{@tagName(std.meta.activeTag(this.authentication_state))}); + return error.UnexpectedMessage; + } + var sasl = &this.authentication_state.SASL; + + if (sasl.status != .init) { + debug("Unexpected SASLContinue for SASL state: {s}", .{@tagName(sasl.status)}); + return error.UnexpectedMessage; + } + debug("SASLContinue", .{}); + + const iteration_count = try cont.iterationCount(); + + const server_salt_decoded_base64 = bun.base64.decodeAlloc(bun.z_allocator, cont.s) catch |err| { + return switch (err) { + error.DecodingFailed => error.SASL_SIGNATURE_INVALID_BASE64, + else => |e| e, + }; + }; + defer bun.z_allocator.free(server_salt_decoded_base64); + try sasl.computeSaltedPassword(server_salt_decoded_base64, iteration_count, this); + + const auth_string = try std.fmt.allocPrint( + bun.z_allocator, + "n=*,r={s},r={s},s={s},i={s},c=biws,r={s}", + .{ + sasl.nonce(), + cont.r, + cont.s, + cont.i, + cont.r, + }, + ); + defer bun.z_allocator.free(auth_string); + try sasl.computeServerSignature(auth_string); + + const client_key = sasl.clientKey(); + const client_key_signature = sasl.clientKeySignature(&client_key, auth_string); + var client_key_xor_buffer: [32]u8 = undefined; + for (&client_key_xor_buffer, client_key, client_key_signature) |*out, a, b| { + out.* = a ^ b; + } + + var client_key_xor_base64_buf = std.mem.zeroes([bun.base64.encodeLenFromSize(32)]u8); + const xor_base64_len = bun.base64.encode(&client_key_xor_base64_buf, &client_key_xor_buffer); + + const payload = try std.fmt.allocPrint( + bun.z_allocator, + "c=biws,r={s},p={s}", + .{ cont.r, client_key_xor_base64_buf[0..xor_base64_len] }, + ); + defer bun.z_allocator.free(payload); + + var response = 
protocol.SASLResponse{ + .data = .{ + .temporary = payload, + }, + }; + + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + sasl.status = .@"continue"; + this.flushData(); + }, + .SASLFinal => |final| { + if (this.authentication_state != .SASL) { + debug("SASLFinal - Unexpected SASLContinue for authentiation state: {s}", .{@tagName(std.meta.activeTag(this.authentication_state))}); + return error.UnexpectedMessage; + } + var sasl = &this.authentication_state.SASL; + + if (sasl.status != .@"continue") { + debug("SASLFinal - Unexpected SASLContinue for SASL state: {s}", .{@tagName(sasl.status)}); + return error.UnexpectedMessage; + } + + if (sasl.server_signature_len == 0) { + debug("SASLFinal - Server signature is empty", .{}); + return error.UnexpectedMessage; + } + + const server_signature = sasl.serverSignature(); + + // This will usually start with "v=" + const comparison_signature = final.data.slice(); + + if (comparison_signature.len < 2 or !bun.strings.eqlLong(server_signature, comparison_signature[2..], true)) { + debug("SASLFinal - SASL Server signature mismatch\nExpected: {s}\nActual: {s}", .{ server_signature, comparison_signature[2..] 
}); + this.fail("The server did not return the correct signature", error.SASL_SIGNATURE_MISMATCH); + } else { + debug("SASLFinal - SASL Server signature match", .{}); + this.authentication_state.zero(); + } + }, + .Ok => { + debug("Authentication OK", .{}); + this.authentication_state.zero(); + this.authentication_state = .{ .ok = {} }; + }, + + .Unknown => { + this.fail("Unknown authentication method", error.UNKNOWN_AUTHENTICATION_METHOD); + }, + + .ClearTextPassword => { + debug("ClearTextPassword", .{}); + var response = protocol.PasswordMessage{ + .password = .{ + .temporary = this.password, + }, + }; + + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + this.flushData(); + }, + + .MD5Password => |md5| { + debug("MD5Password", .{}); + // Format is: md5 + md5(md5(password + username) + salt) + var first_hash_buf: bun.sha.MD5.Digest = undefined; + var first_hash_str: [32]u8 = undefined; + var final_hash_buf: bun.sha.MD5.Digest = undefined; + var final_hash_str: [32]u8 = undefined; + var final_password_buf: [36]u8 = undefined; + + // First hash: md5(password + username) + var first_hasher = bun.sha.MD5.init(); + first_hasher.update(this.password); + first_hasher.update(this.user); + first_hasher.final(&first_hash_buf); + const first_hash_str_output = std.fmt.bufPrint(&first_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&first_hash_buf)}) catch unreachable; + + // Second hash: md5(first_hash + salt) + var final_hasher = bun.sha.MD5.init(); + final_hasher.update(first_hash_str_output); + final_hasher.update(&md5.salt); + final_hasher.final(&final_hash_buf); + const final_hash_str_output = std.fmt.bufPrint(&final_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&final_hash_buf)}) catch unreachable; + + // Format final password as "md5" + final_hash + const final_password = std.fmt.bufPrintZ(&final_password_buf, "md5{s}", .{final_hash_str_output}) catch unreachable; + + var response = protocol.PasswordMessage{ + .password = .{ + .temporary = 
final_password, + }, + }; + + this.authentication_state = .{ .md5 = {} }; + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + this.flushData(); + }, + + else => { + debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))}); + this.fail("TODO: support authentication method: {s}", error.UNSUPPORTED_AUTHENTICATION_METHOD); + }, + } + }, + .NoData => { + try reader.eatMessage(protocol.NoData); + var request = this.current() orelse return error.ExpectedRequest; + if (request.status == .binding) { + request.status = .running; + } + }, + .BackendKeyData => { + try this.backend_key_data.decodeInternal(Context, reader); + }, + .ErrorResponse => { + var err: protocol.ErrorResponse = undefined; + try err.decodeInternal(Context, reader); + + if (this.status == .connecting or this.status == .sent_startup_message) { + defer { + err.deinit(); + } + + this.failWithJSValue(err.toJS(this.globalObject)); + + // it shouldn't enqueue any requests while connecting + bun.assert(this.requests.count == 0); + return; + } + + var request = this.current() orelse { + debug("ErrorResponse: {}", .{err}); + return error.ExpectedRequest; + }; + var is_error_owned = true; + defer { + if (is_error_owned) { + err.deinit(); + } + } + if (request.statement) |stmt| { + if (stmt.status == PostgresSQLStatement.Status.parsing) { + stmt.status = PostgresSQLStatement.Status.failed; + stmt.error_response = .{ .protocol = err }; + is_error_owned = false; + if (this.statements.remove(bun.hash(stmt.signature.name))) { + stmt.deref(); + } + } + } + this.updateRef(); + + request.onError(.{ .protocol = err }, this.globalObject); + }, + .PortalSuspended => { + // try reader.eatMessage(&protocol.PortalSuspended); + // var request = this.current() orelse return error.ExpectedRequest; + // _ = request; + debug("TODO PortalSuspended", .{}); + }, + .CloseComplete => { + try reader.eatMessage(protocol.CloseComplete); + var request = this.current() orelse return error.ExpectedRequest; + defer 
this.updateRef(); + if (request.flags.simple) { + request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, false); + } else { + request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, true); + } + }, + .CopyInResponse => { + debug("TODO CopyInResponse", .{}); + }, + .NoticeResponse => { + debug("UNSUPPORTED NoticeResponse", .{}); + var resp: protocol.NoticeResponse = undefined; + + try resp.decodeInternal(Context, reader); + resp.deinit(); + }, + .EmptyQueryResponse => { + try reader.eatMessage(protocol.EmptyQueryResponse); + var request = this.current() orelse return error.ExpectedRequest; + defer this.updateRef(); + if (request.flags.simple) { + request.onResult("", this.globalObject, this.js_value, false); + } else { + request.onResult("", this.globalObject, this.js_value, true); + } + }, + .CopyOutResponse => { + debug("TODO CopyOutResponse", .{}); + }, + .CopyDone => { + debug("TODO CopyDone", .{}); + }, + .CopyBothResponse => { + debug("TODO CopyBothResponse", .{}); + }, + else => @compileError("Unknown message type: " ++ @tagName(MessageType)), + } +} + +pub fn updateRef(this: *PostgresSQLConnection) void { + this.updateHasPendingActivity(); + if (this.pending_activity_count.raw > 0) { + this.poll_ref.ref(this.globalObject.bunVM()); + } else { + this.poll_ref.unref(this.globalObject.bunVM()); + } +} + +pub fn getConnected(this: *PostgresSQLConnection, _: *JSC.JSGlobalObject) JSValue { + return JSValue.jsBoolean(this.status == Status.connected); +} + +pub fn consumeOnConnectCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { + debug("consumeOnConnectCallback", .{}); + const on_connect = js.onconnectGetCached(this.js_value) orelse return null; + debug("consumeOnConnectCallback exists", .{}); + + js.onconnectSetCached(this.js_value, globalObject, .zero); + return on_connect; +} + +pub fn consumeOnCloseCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { + 
debug("consumeOnCloseCallback", .{}); + const on_close = js.oncloseGetCached(this.js_value) orelse return null; + debug("consumeOnCloseCallback exists", .{}); + js.oncloseSetCached(this.js_value, globalObject, .zero); + return on_close; +} + +const PreparedStatementsMap = std.HashMapUnmanaged(u64, *PostgresSQLStatement, bun.IdentityContext(u64), 80); + +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const PostgresCachedStructure = @import("./PostgresCachedStructure.zig"); +const PostgresRequest = @import("./PostgresRequest.zig"); +const PostgresSQLConnection = @This(); +const PostgresSQLQuery = @import("./PostgresSQLQuery.zig"); +const PostgresSQLStatement = @import("./PostgresSQLStatement.zig"); +const SocketMonitor = @import("./SocketMonitor.zig"); +const protocol = @import("./PostgresProtocol.zig"); +const std = @import("std"); +const AuthenticationState = @import("./AuthenticationState.zig").AuthenticationState; +const ConnectionFlags = @import("./ConnectionFlags.zig").ConnectionFlags; +const Data = @import("./Data.zig").Data; +const DataCell = @import("./DataCell.zig").DataCell; +const SSLMode = @import("./SSLMode.zig").SSLMode; +const Status = @import("./Status.zig").Status; +const TLSStatus = @import("./TLSStatus.zig").TLSStatus; + +const AnyPostgresError = @import("./AnyPostgresError.zig").AnyPostgresError; +const postgresErrorToJS = @import("./AnyPostgresError.zig").postgresErrorToJS; + +const bun = @import("bun"); +const BoringSSL = bun.BoringSSL; +const assert = bun.assert; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; + +pub const js = JSC.Codegen.JSPostgresSQLConnection; +pub const fromJS = js.fromJS; +pub const fromJSDirect = js.fromJSDirect; +pub const toJS = js.toJS; + +const uws = bun.uws; +const Socket = uws.AnySocket; diff --git a/src/sql/postgres/PostgresSQLContext.zig b/src/sql/postgres/PostgresSQLContext.zig new file mode 100644 index 0000000000..35ecc7f46e --- /dev/null +++ 
b/src/sql/postgres/PostgresSQLContext.zig @@ -0,0 +1,23 @@ +tcp: ?*uws.SocketContext = null, + +onQueryResolveFn: JSC.Strong.Optional = .empty, +onQueryRejectFn: JSC.Strong.Optional = .empty, + +pub fn init(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { + var ctx = &globalObject.bunVM().rareData().postgresql_context; + ctx.onQueryResolveFn.set(globalObject, callframe.argument(0)); + ctx.onQueryRejectFn.set(globalObject, callframe.argument(1)); + + return .js_undefined; +} + +comptime { + const js_init = JSC.toJSHostFn(init); + @export(&js_init, .{ .name = "PostgresSQLContext__init" }); +} + +// @sortImports + +const bun = @import("bun"); +const JSC = bun.JSC; +const uws = bun.uws; diff --git a/src/sql/postgres/PostgresSQLQuery.zig b/src/sql/postgres/PostgresSQLQuery.zig new file mode 100644 index 0000000000..3aaa4d3920 --- /dev/null +++ b/src/sql/postgres/PostgresSQLQuery.zig @@ -0,0 +1,499 @@ +statement: ?*PostgresSQLStatement = null, +query: bun.String = bun.String.empty, +cursor_name: bun.String = bun.String.empty, + +thisValue: JSRef = JSRef.empty(), + +status: Status = Status.pending, + +ref_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(1), + +flags: packed struct(u8) { + is_done: bool = false, + binary: bool = false, + bigint: bool = false, + simple: bool = false, + result_mode: PostgresSQLQueryResultMode = .objects, + _padding: u2 = 0, +} = .{}, + +pub fn getTarget(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, clean_target: bool) JSC.JSValue { + const thisValue = this.thisValue.get(); + if (thisValue == .zero) { + return .zero; + } + const target = js.targetGetCached(thisValue) orelse return .zero; + if (clean_target) { + js.targetSetCached(thisValue, globalObject, .zero); + } + return target; +} + +pub const Status = enum(u8) { + /// The query was just enqueued, statement status can be checked for more details + pending, + /// The query is being bound to the statement + binding, + /// The 
query is running + running, + /// The query is waiting for a partial response + partial_response, + /// The query was successful + success, + /// The query failed + fail, + + pub fn isRunning(this: Status) bool { + return @intFromEnum(this) > @intFromEnum(Status.pending) and @intFromEnum(this) < @intFromEnum(Status.success); + } +}; + +pub fn hasPendingActivity(this: *@This()) bool { + return this.ref_count.load(.monotonic) > 1; +} + +pub fn deinit(this: *@This()) void { + this.thisValue.deinit(); + if (this.statement) |statement| { + statement.deref(); + } + this.query.deref(); + this.cursor_name.deref(); + bun.default_allocator.destroy(this); +} + +pub fn finalize(this: *@This()) void { + debug("PostgresSQLQuery finalize", .{}); + if (this.thisValue == .weak) { + // clean up if is a weak reference, if is a strong reference we need to wait until the query is done + // if we are a strong reference, here is probably a bug because GC'd should not happen + this.thisValue.weak = .zero; + } + this.deref(); +} + +pub fn deref(this: *@This()) void { + const ref_count = this.ref_count.fetchSub(1, .monotonic); + + if (ref_count == 1) { + this.deinit(); + } +} + +pub fn ref(this: *@This()) void { + bun.assert(this.ref_count.fetchAdd(1, .monotonic) > 0); +} + +pub fn onWriteFail( + this: *@This(), + err: AnyPostgresError, + globalObject: *JSC.JSGlobalObject, + queries_array: JSValue, +) void { + this.status = .fail; + const thisValue = this.thisValue.get(); + defer this.thisValue.deinit(); + const targetValue = this.getTarget(globalObject, true); + if (thisValue == .zero or targetValue == .zero) { + return; + } + + const vm = JSC.VirtualMachine.get(); + const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; + const event_loop = vm.eventLoop(); + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + postgresErrorToJS(globalObject, null, err), + queries_array, + }); +} +pub fn onJSError(this: *@This(), err: JSC.JSValue, globalObject: 
*JSC.JSGlobalObject) void { + this.status = .fail; + this.ref(); + defer this.deref(); + + const thisValue = this.thisValue.get(); + defer this.thisValue.deinit(); + const targetValue = this.getTarget(globalObject, true); + if (thisValue == .zero or targetValue == .zero) { + return; + } + + var vm = JSC.VirtualMachine.get(); + const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; + const event_loop = vm.eventLoop(); + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + err, + }); +} +pub fn onError(this: *@This(), err: PostgresSQLStatement.Error, globalObject: *JSC.JSGlobalObject) void { + this.onJSError(err.toJS(globalObject), globalObject); +} + +pub fn allowGC(thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) void { + if (thisValue == .zero) { + return; + } + + defer thisValue.ensureStillAlive(); + js.bindingSetCached(thisValue, globalObject, .zero); + js.pendingValueSetCached(thisValue, globalObject, .zero); + js.targetSetCached(thisValue, globalObject, .zero); +} + +fn consumePendingValue(thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) ?JSValue { + const pending_value = js.pendingValueGetCached(thisValue) orelse return null; + js.pendingValueSetCached(thisValue, globalObject, .zero); + return pending_value; +} + +pub fn onResult(this: *@This(), command_tag_str: []const u8, globalObject: *JSC.JSGlobalObject, connection: JSC.JSValue, is_last: bool) void { + this.ref(); + defer this.deref(); + + const thisValue = this.thisValue.get(); + const targetValue = this.getTarget(globalObject, is_last); + if (is_last) { + this.status = .success; + } else { + this.status = .partial_response; + } + defer if (is_last) { + allowGC(thisValue, globalObject); + this.thisValue.deinit(); + }; + if (thisValue == .zero or targetValue == .zero) { + return; + } + + const vm = JSC.VirtualMachine.get(); + const function = vm.rareData().postgresql_context.onQueryResolveFn.get().?; + const event_loop = vm.eventLoop(); + 
const tag = CommandTag.init(command_tag_str); + + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + consumePendingValue(thisValue, globalObject) orelse .js_undefined, + tag.toJSTag(globalObject), + tag.toJSNumber(), + if (connection == .zero) .js_undefined else PostgresSQLConnection.js.queriesGetCached(connection) orelse .js_undefined, + JSValue.jsBoolean(is_last), + }); +} + +pub fn constructor(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*PostgresSQLQuery { + _ = callframe; + return globalThis.throw("PostgresSQLQuery cannot be constructed directly", .{}); +} + +pub fn estimatedSize(this: *PostgresSQLQuery) usize { + _ = this; + return @sizeOf(PostgresSQLQuery); +} + +pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { + const arguments = callframe.arguments_old(6).slice(); + var args = JSC.CallFrame.ArgumentsSlice.init(globalThis.bunVM(), arguments); + defer args.deinit(); + const query = args.nextEat() orelse { + return globalThis.throw("query must be a string", .{}); + }; + const values = args.nextEat() orelse { + return globalThis.throw("values must be an array", .{}); + }; + + if (!query.isString()) { + return globalThis.throw("query must be a string", .{}); + } + + if (values.jsType() != .Array) { + return globalThis.throw("values must be an array", .{}); + } + + const pending_value: JSValue = args.nextEat() orelse .js_undefined; + const columns: JSValue = args.nextEat() orelse .js_undefined; + const js_bigint: JSValue = args.nextEat() orelse .false; + const js_simple: JSValue = args.nextEat() orelse .false; + + const bigint = js_bigint.isBoolean() and js_bigint.asBoolean(); + const simple = js_simple.isBoolean() and js_simple.asBoolean(); + if (simple) { + if (try values.getLength(globalThis) > 0) { + return globalThis.throwInvalidArguments("simple query cannot have parameters", .{}); + } + if (try query.getLength(globalThis) >= std.math.maxInt(i32)) { + 
return globalThis.throwInvalidArguments("query is too long", .{}); + } + } + if (!pending_value.jsType().isArrayLike()) { + return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); + } + + var ptr = try bun.default_allocator.create(PostgresSQLQuery); + + const this_value = ptr.toJS(globalThis); + this_value.ensureStillAlive(); + + ptr.* = .{ + .query = try query.toBunString(globalThis), + .thisValue = JSRef.initWeak(this_value), + .flags = .{ + .bigint = bigint, + .simple = simple, + }, + }; + + js.bindingSetCached(this_value, globalThis, values); + js.pendingValueSetCached(this_value, globalThis, pending_value); + if (!columns.isUndefined()) { + js.columnsSetCached(this_value, globalThis, columns); + } + + return this_value; +} + +pub fn push(this: *PostgresSQLQuery, globalThis: *JSC.JSGlobalObject, value: JSValue) void { + var pending_value = this.pending_value.get() orelse return; + pending_value.push(globalThis, value); +} + +pub fn doDone(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + _ = globalObject; + this.flags.is_done = true; + return .js_undefined; +} +pub fn setPendingValue(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + const result = callframe.argument(0); + js.pendingValueSetCached(this.thisValue.get(), globalObject, result); + return .js_undefined; +} +pub fn setMode(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + const js_mode = callframe.argument(0); + if (js_mode.isEmptyOrUndefinedOrNull() or !js_mode.isNumber()) { + return globalObject.throwInvalidArgumentType("setMode", "mode", "Number"); + } + + const mode = try js_mode.coerce(i32, globalObject); + this.flags.result_mode = std.meta.intToEnum(PostgresSQLQueryResultMode, mode) catch { + return globalObject.throwInvalidArgumentTypeValue("mode", "Number", js_mode); + }; + return .js_undefined; +} + +pub fn 
doRun(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + var arguments_ = callframe.arguments_old(2); + const arguments = arguments_.slice(); + const connection: *PostgresSQLConnection = arguments[0].as(PostgresSQLConnection) orelse { + return globalObject.throw("connection must be a PostgresSQLConnection", .{}); + }; + + connection.poll_ref.ref(globalObject.bunVM()); + var query = arguments[1]; + + if (!query.isObject()) { + return globalObject.throwInvalidArgumentType("run", "query", "Query"); + } + + const this_value = callframe.this(); + const binding_value = js.bindingGetCached(this_value) orelse .zero; + var query_str = this.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + var writer = connection.writer(); + + if (this.flags.simple) { + debug("executeQuery", .{}); + + const can_execute = !connection.hasQueryRunning(); + if (can_execute) { + PostgresRequest.executeQuery(query_str.slice(), PostgresSQLConnection.Writer, writer) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to execute query", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + this.status = .running; + } else { + this.status = .pending; + } + const stmt = bun.default_allocator.create(PostgresSQLStatement) catch { + return globalObject.throwOutOfMemory(); + }; + // Query is simple and it's the only owner of the statement + stmt.* = .{ + .signature = Signature.empty(), + .ref_count = 1, + .status = .parsing, + }; + this.statement = stmt; + // We need a strong reference to the query so that it doesn't get GC'd + connection.requests.writeItem(this) catch return globalObject.throwOutOfMemory(); + this.ref(); + this.thisValue.upgrade(globalObject); + + js.targetSetCached(this_value, globalObject, query); + if (this.status == .running) { + connection.flushDataAndResetTimeout(); + } else { + 
connection.resetConnectionTimeout(); + } + return .js_undefined; + } + + const columns_value: JSValue = js.columnsGetCached(this_value) orelse .js_undefined; + + var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value, connection.prepared_statement_id, connection.flags.use_unnamed_prepared_statements) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwError(err, "failed to generate signature"); + return error.JSError; + }; + + const has_params = signature.fields.len > 0; + var did_write = false; + enqueue: { + var connection_entry_value: ?**PostgresSQLStatement = null; + if (!connection.flags.use_unnamed_prepared_statements) { + const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { + signature.deinit(); + return globalObject.throwError(err, "failed to allocate statement"); + }; + connection_entry_value = entry.value_ptr; + if (entry.found_existing) { + this.statement = connection_entry_value.?.*; + this.statement.?.ref(); + signature.deinit(); + + switch (this.statement.?.status) { + .failed => { + // If the statement failed, we need to throw the error + return globalObject.throwValue(this.statement.?.error_response.?.toJS(globalObject)); + }, + .prepared => { + if (!connection.hasQueryRunning()) { + this.flags.binary = this.statement.?.fields.len > 0; + debug("bindAndExecute", .{}); + + // bindAndExecute will bind + execute, it will change to running after binding is complete + PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to bind and execute query", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + this.status = .binding; + + did_write = true; + } + }, + .parsing, .pending => {}, + } + + break :enqueue; + 
} + } + const can_execute = !connection.hasQueryRunning(); + + if (can_execute) { + // If it does not have params, we can write and execute immediately in one go + if (!has_params) { + debug("prepareAndQueryWithSignature", .{}); + // prepareAndQueryWithSignature will write + bind + execute, it will change to running after binding is complete + PostgresRequest.prepareAndQueryWithSignature(globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, writer, &signature) catch |err| { + signature.deinit(); + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to prepare and query", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + this.status = .binding; + did_write = true; + } else { + debug("writeQuery", .{}); + + PostgresRequest.writeQuery(query_str.slice(), signature.prepared_statement_name, signature.fields, PostgresSQLConnection.Writer, writer) catch |err| { + signature.deinit(); + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to write query", err)); + return error.JSError; + }; + writer.write(&protocol.Sync) catch |err| { + signature.deinit(); + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to flush", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + did_write = true; + } + } + { + const stmt = bun.default_allocator.create(PostgresSQLStatement) catch { + return globalObject.throwOutOfMemory(); + }; + // we only have connection_entry_value if we are using named prepared statements + if (connection_entry_value) |entry_value| { + connection.prepared_statement_id += 1; + stmt.* = .{ .signature = signature, .ref_count = 2, .status = if (can_execute) .parsing else .pending }; + this.statement = stmt; + + entry_value.* = stmt; + } else { + stmt.* = .{ .signature = signature, .ref_count = 1, .status = if 
(can_execute) .parsing else .pending }; + this.statement = stmt; + } + } + } + // We need a strong reference to the query so that it doesn't get GC'd + connection.requests.writeItem(this) catch return globalObject.throwOutOfMemory(); + this.ref(); + this.thisValue.upgrade(globalObject); + + js.targetSetCached(this_value, globalObject, query); + if (did_write) { + connection.flushDataAndResetTimeout(); + } else { + connection.resetConnectionTimeout(); + } + return .js_undefined; +} + +pub fn doCancel(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + _ = callframe; + _ = globalObject; + _ = this; + + return .js_undefined; +} + +comptime { + const jscall = JSC.toJSHostFn(call); + @export(&jscall, .{ .name = "PostgresSQLQuery__createInstance" }); +} + +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const PostgresRequest = @import("./PostgresRequest.zig"); +const PostgresSQLConnection = @import("./PostgresSQLConnection.zig"); +const PostgresSQLQuery = @This(); +const PostgresSQLStatement = @import("./PostgresSQLStatement.zig"); +const Signature = @import("./Signature.zig"); +const bun = @import("bun"); +const protocol = @import("./PostgresProtocol.zig"); +const std = @import("std"); +const CommandTag = @import("./CommandTag.zig").CommandTag; +const PostgresSQLQueryResultMode = @import("./PostgresSQLQueryResultMode.zig").PostgresSQLQueryResultMode; + +const AnyPostgresError = @import("./AnyPostgresError.zig").AnyPostgresError; +const postgresErrorToJS = @import("./AnyPostgresError.zig").postgresErrorToJS; + +const JSC = bun.JSC; +const JSGlobalObject = JSC.JSGlobalObject; +const JSRef = JSC.JSRef; +const JSValue = JSC.JSValue; + +pub const js = JSC.Codegen.JSPostgresSQLQuery; +pub const fromJS = js.fromJS; +pub const fromJSDirect = js.fromJSDirect; +pub const toJS = js.toJS; diff --git a/src/sql/postgres/PostgresSQLQueryResultMode.zig b/src/sql/postgres/PostgresSQLQueryResultMode.zig 
new file mode 100644 index 0000000000..d8f7c9c444 --- /dev/null +++ b/src/sql/postgres/PostgresSQLQueryResultMode.zig @@ -0,0 +1,7 @@ +pub const PostgresSQLQueryResultMode = enum(u2) { + objects = 0, + values = 1, + raw = 2, +}; + +// @sortImports diff --git a/src/sql/postgres/PostgresSQLStatement.zig b/src/sql/postgres/PostgresSQLStatement.zig new file mode 100644 index 0000000000..cc832e6909 --- /dev/null +++ b/src/sql/postgres/PostgresSQLStatement.zig @@ -0,0 +1,192 @@ +cached_structure: PostgresCachedStructure = .{}, +ref_count: u32 = 1, +fields: []protocol.FieldDescription = &[_]protocol.FieldDescription{}, +parameters: []const int4 = &[_]int4{}, +signature: Signature, +status: Status = Status.pending, +error_response: ?Error = null, +needs_duplicate_check: bool = true, +fields_flags: DataCell.Flags = .{}, + +pub const Error = union(enum) { + protocol: protocol.ErrorResponse, + postgres_error: AnyPostgresError, + + pub fn deinit(this: *@This()) void { + switch (this.*) { + .protocol => |*err| err.deinit(), + .postgres_error => {}, + } + } + + pub fn toJS(this: *const @This(), globalObject: *JSC.JSGlobalObject) JSValue { + return switch (this.*) { + .protocol => |err| err.toJS(globalObject), + .postgres_error => |err| postgresErrorToJS(globalObject, null, err), + }; + } +}; +pub const Status = enum { + pending, + parsing, + prepared, + failed, + + pub fn isRunning(this: @This()) bool { + return this == .parsing; + } +}; +pub fn ref(this: *@This()) void { + bun.assert(this.ref_count > 0); + this.ref_count += 1; +} + +pub fn deref(this: *@This()) void { + const ref_count = this.ref_count; + this.ref_count -= 1; + + if (ref_count == 1) { + this.deinit(); + } +} + +pub fn checkForDuplicateFields(this: *PostgresSQLStatement) void { + if (!this.needs_duplicate_check) return; + this.needs_duplicate_check = false; + + var seen_numbers = std.ArrayList(u32).init(bun.default_allocator); + defer seen_numbers.deinit(); + var seen_fields = 
bun.StringHashMap(void).init(bun.default_allocator); + seen_fields.ensureUnusedCapacity(@intCast(this.fields.len)) catch bun.outOfMemory(); + defer seen_fields.deinit(); + + // iterate backwards + var remaining = this.fields.len; + var flags: DataCell.Flags = .{}; + while (remaining > 0) { + remaining -= 1; + const field: *protocol.FieldDescription = &this.fields[remaining]; + switch (field.name_or_index) { + .name => |*name| { + const seen = seen_fields.getOrPut(name.slice()) catch unreachable; + if (seen.found_existing) { + field.name_or_index = .duplicate; + flags.has_duplicate_columns = true; + } + + flags.has_named_columns = true; + }, + .index => |index| { + if (std.mem.indexOfScalar(u32, seen_numbers.items, index) != null) { + field.name_or_index = .duplicate; + flags.has_duplicate_columns = true; + } else { + seen_numbers.append(index) catch bun.outOfMemory(); + } + + flags.has_indexed_columns = true; + }, + .duplicate => { + flags.has_duplicate_columns = true; + }, + } + } + + this.fields_flags = flags; +} + +pub fn deinit(this: *PostgresSQLStatement) void { + debug("PostgresSQLStatement deinit", .{}); + + bun.assert(this.ref_count == 0); + + for (this.fields) |*field| { + field.deinit(); + } + bun.default_allocator.free(this.fields); + bun.default_allocator.free(this.parameters); + this.cached_structure.deinit(); + if (this.error_response) |err| { + this.error_response = null; + var _error = err; + _error.deinit(); + } + this.signature.deinit(); + bun.default_allocator.destroy(this); +} + +pub fn structure(this: *PostgresSQLStatement, owner: JSValue, globalObject: *JSC.JSGlobalObject) PostgresCachedStructure { + if (this.cached_structure.has()) { + return this.cached_structure; + } + this.checkForDuplicateFields(); + + // lets avoid most allocations + var stack_ids: [70]JSC.JSObject.ExternColumnIdentifier = undefined; + // lets de duplicate the fields early + var nonDuplicatedCount = this.fields.len; + for (this.fields) |*field| { + if 
(field.name_or_index == .duplicate) { + nonDuplicatedCount -= 1; + } + } + const ids = if (nonDuplicatedCount <= JSC.JSObject.maxInlineCapacity()) stack_ids[0..nonDuplicatedCount] else bun.default_allocator.alloc(JSC.JSObject.ExternColumnIdentifier, nonDuplicatedCount) catch bun.outOfMemory(); + + var i: usize = 0; + for (this.fields) |*field| { + if (field.name_or_index == .duplicate) continue; + + var id: *JSC.JSObject.ExternColumnIdentifier = &ids[i]; + switch (field.name_or_index) { + .name => |name| { + id.value.name = String.createAtomIfPossible(name.slice()); + }, + .index => |index| { + id.value.index = index; + }, + .duplicate => unreachable, + } + id.tag = switch (field.name_or_index) { + .name => 2, + .index => 1, + .duplicate => 0, + }; + i += 1; + } + + if (nonDuplicatedCount > JSC.JSObject.maxInlineCapacity()) { + this.cached_structure.set(globalObject, null, ids); + } else { + this.cached_structure.set(globalObject, JSC.JSObject.createStructure( + globalObject, + owner, + @truncate(ids.len), + ids.ptr, + ), null); + } + + return this.cached_structure; +} + +const debug = bun.Output.scoped(.Postgres, false); + +// @sortImports + +const PostgresCachedStructure = @import("./PostgresCachedStructure.zig"); +const PostgresSQLStatement = @This(); +const Signature = @import("./Signature.zig"); +const protocol = @import("./PostgresProtocol.zig"); +const std = @import("std"); +const DataCell = @import("./DataCell.zig").DataCell; + +const AnyPostgresError = @import("./AnyPostgresError.zig").AnyPostgresError; +const postgresErrorToJS = @import("./AnyPostgresError.zig").postgresErrorToJS; + +const types = @import("./PostgresTypes.zig"); +const int4 = types.int4; + +const bun = @import("bun"); +const String = bun.String; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/PostgresTypes.zig b/src/sql/postgres/PostgresTypes.zig new file mode 100644 index 0000000000..d21d7ce483 --- /dev/null +++ b/src/sql/postgres/PostgresTypes.zig @@ 
-0,0 +1,20 @@ +pub const @"bool" = @import("./types/bool.zig"); + +// @sortImports + +pub const bytea = @import("./types/bytea.zig"); +pub const date = @import("./types/date.zig"); +pub const json = @import("./types/json.zig"); +pub const numeric = @import("./types/numeric.zig"); +pub const string = @import("./types/PostgresString.zig"); +pub const AnyPostgresError = @import("./AnyPostgresError.zig").AnyPostgresError; +pub const Tag = @import("./types/Tag.zig").Tag; + +const int_types = @import("./types/int_types.zig"); +pub const Int32 = int_types.Int32; +pub const PostgresInt32 = int_types.int4; +pub const PostgresInt64 = int_types.int8; +pub const PostgresShort = int_types.short; +pub const int4 = int_types.int4; +pub const int8 = int_types.int8; +pub const short = int_types.short; diff --git a/src/sql/postgres/QueryBindingIterator.zig b/src/sql/postgres/QueryBindingIterator.zig new file mode 100644 index 0000000000..af2e3e78fb --- /dev/null +++ b/src/sql/postgres/QueryBindingIterator.zig @@ -0,0 +1,66 @@ +pub const QueryBindingIterator = union(enum) { + array: JSC.JSArrayIterator, + objects: ObjectIterator, + + pub fn init(array: JSValue, columns: JSValue, globalObject: *JSC.JSGlobalObject) bun.JSError!QueryBindingIterator { + if (columns.isEmptyOrUndefinedOrNull()) { + return .{ .array = try JSC.JSArrayIterator.init(array, globalObject) }; + } + + return .{ + .objects = .{ + .array = array, + .columns = columns, + .globalObject = globalObject, + .columns_count = try columns.getLength(globalObject), + .array_length = try array.getLength(globalObject), + }, + }; + } + + pub fn next(this: *QueryBindingIterator) bun.JSError!?JSC.JSValue { + return switch (this.*) { + .array => |*iter| iter.next(), + .objects => |*iter| iter.next(), + }; + } + + pub fn anyFailed(this: *const QueryBindingIterator) bool { + return switch (this.*) { + .array => false, + .objects => |*iter| iter.any_failed, + }; + } + + pub fn to(this: *QueryBindingIterator, index: u32) void { + switch 
(this.*) { + .array => |*iter| iter.i = index, + .objects => |*iter| { + iter.cell_i = index % iter.columns_count; + iter.row_i = index / iter.columns_count; + iter.current_row = .zero; + }, + } + } + + pub fn reset(this: *QueryBindingIterator) void { + switch (this.*) { + .array => |*iter| { + iter.i = 0; + }, + .objects => |*iter| { + iter.cell_i = 0; + iter.row_i = 0; + iter.current_row = .zero; + }, + } + } +}; + +// @sortImports + +const ObjectIterator = @import("./ObjectIterator.zig"); +const bun = @import("bun"); + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/SASL.zig b/src/sql/postgres/SASL.zig new file mode 100644 index 0000000000..6fb482f40a --- /dev/null +++ b/src/sql/postgres/SASL.zig @@ -0,0 +1,95 @@ +const nonce_byte_len = 18; +const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len); + +const server_signature_byte_len = 32; +const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len); + +const salted_password_byte_len = 32; + +nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len, +nonce_len: u8 = 0, + +server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len, +server_signature_len: u8 = 0, + +salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len, +salted_password_created: bool = false, + +status: SASLStatus = .init, + +pub const SASLStatus = enum { + init, + @"continue", +}; + +fn hmac(password: []const u8, data: []const u8) ?[32]u8 { + var buf = std.mem.zeroes([bun.BoringSSL.c.EVP_MAX_MD_SIZE]u8); + + // TODO: I don't think this is failable. 
+ const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null; + + assert(result.len == 32); + return buf[0..32].*; +} + +pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void { + this.salted_password_created = true; + if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) { + return error.PBKDFD2; + } +} + +pub fn saltedPassword(this: *const SASL) []const u8 { + assert(this.salted_password_created); + return this.salted_password_bytes[0..salted_password_byte_len]; +} + +pub fn serverSignature(this: *const SASL) []const u8 { + assert(this.server_signature_len > 0); + return this.server_signature_base64_bytes[0..this.server_signature_len]; +} + +pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void { + assert(this.server_signature_len == 0); + + const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey; + const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature; + this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes)); +} + +pub fn clientKey(this: *const SASL) [32]u8 { + return hmac(this.saltedPassword(), "Client Key").?; +} + +pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 { + var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest); + bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine()); + return hmac(&sha_digest, auth_string).?; +} + +pub fn nonce(this: *SASL) []const u8 { + if (this.nonce_len == 0) { + var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len; + bun.csprng(&bytes); + this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes)); + } + return this.nonce_base64_bytes[0..this.nonce_len]; +} + +pub fn deinit(this: 
*SASL) void { + this.nonce_len = 0; + this.salted_password_created = false; + this.server_signature_len = 0; + this.status = .init; +} + +// @sortImports + +const PostgresSQLConnection = @import("./PostgresSQLConnection.zig"); +const SASL = @This(); +const std = @import("std"); + +const bun = @import("bun"); +const JSC = bun.JSC; +const assert = bun.assert; +const Crypto = JSC.API.Bun.Crypto; diff --git a/src/sql/postgres/SSLMode.zig b/src/sql/postgres/SSLMode.zig new file mode 100644 index 0000000000..adc78ff605 --- /dev/null +++ b/src/sql/postgres/SSLMode.zig @@ -0,0 +1,9 @@ +pub const SSLMode = enum(u8) { + disable = 0, + prefer = 1, + require = 2, + verify_ca = 3, + verify_full = 4, +}; + +// @sortImports diff --git a/src/sql/postgres/Signature.zig b/src/sql/postgres/Signature.zig new file mode 100644 index 0000000000..37720ef626 --- /dev/null +++ b/src/sql/postgres/Signature.zig @@ -0,0 +1,113 @@ +fields: []const int4, +name: []const u8, +query: []const u8, +prepared_statement_name: []const u8, + +pub fn empty() Signature { + return Signature{ + .fields = &[_]int4{}, + .name = &[_]u8{}, + .query = &[_]u8{}, + .prepared_statement_name = &[_]u8{}, + }; +} + +pub fn deinit(this: *Signature) void { + if (this.prepared_statement_name.len > 0) { + bun.default_allocator.free(this.prepared_statement_name); + } + if (this.name.len > 0) { + bun.default_allocator.free(this.name); + } + if (this.fields.len > 0) { + bun.default_allocator.free(this.fields); + } + if (this.query.len > 0) { + bun.default_allocator.free(this.query); + } +} + +pub fn hash(this: *const Signature) u64 { + var hasher = std.hash.Wyhash.init(0); + hasher.update(this.name); + hasher.update(std.mem.sliceAsBytes(this.fields)); + return hasher.final(); +} + +pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue, prepared_statement_id: u64, unnamed: bool) !Signature { + var fields = std.ArrayList(int4).init(bun.default_allocator); + var name = try 
std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); + + name.appendSliceAssumeCapacity(query); + + errdefer { + fields.deinit(); + name.deinit(); + } + + var iter = try QueryBindingIterator.init(array_value, columns, globalObject); + + while (try iter.next()) |value| { + if (value.isEmptyOrUndefinedOrNull()) { + // Allow postgres to decide the type + try fields.append(0); + try name.appendSlice(".null"); + continue; + } + + const tag = try types.Tag.fromJS(globalObject, value); + + switch (tag) { + .int8 => try name.appendSlice(".int8"), + .int4 => try name.appendSlice(".int4"), + // .int4_array => try name.appendSlice(".int4_array"), + .int2 => try name.appendSlice(".int2"), + .float8 => try name.appendSlice(".float8"), + .float4 => try name.appendSlice(".float4"), + .numeric => try name.appendSlice(".numeric"), + .json, .jsonb => try name.appendSlice(".json"), + .bool => try name.appendSlice(".bool"), + .timestamp => try name.appendSlice(".timestamp"), + .timestamptz => try name.appendSlice(".timestamptz"), + .bytea => try name.appendSlice(".bytea"), + else => try name.appendSlice(".string"), + } + + switch (tag) { + .bool, .int4, .int8, .float8, .int2, .numeric, .float4, .bytea => { + // We decide the type + try fields.append(@intFromEnum(tag)); + }, + else => { + // Allow postgres to decide the type + try fields.append(0); + }, + } + } + + if (iter.anyFailed()) { + return error.InvalidQueryBinding; + } + // max u64 length is 20, max prepared_statement_name length is 63 + const prepared_statement_name = if (unnamed) "" else try std.fmt.allocPrint(bun.default_allocator, "P{s}${d}", .{ name.items[0..@min(40, name.items.len)], prepared_statement_id }); + + return Signature{ + .prepared_statement_name = prepared_statement_name, + .name = name.items, + .fields = fields.items, + .query = try bun.default_allocator.dupe(u8, query), + }; +} + +// @sortImports + +const Signature = @This(); +const bun = @import("bun"); +const std = @import("std"); +const 
QueryBindingIterator = @import("./QueryBindingIterator.zig").QueryBindingIterator; + +const types = @import("./PostgresTypes.zig"); +const int4 = types.int4; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/SocketMonitor.zig b/src/sql/postgres/SocketMonitor.zig new file mode 100644 index 0000000000..7765d31017 --- /dev/null +++ b/src/sql/postgres/SocketMonitor.zig @@ -0,0 +1,23 @@ +pub fn write(data: []const u8) void { + if (comptime bun.Environment.isDebug) { + DebugSocketMonitorWriter.check.call(); + if (DebugSocketMonitorWriter.enabled) { + DebugSocketMonitorWriter.write(data); + } + } +} + +pub fn read(data: []const u8) void { + if (comptime bun.Environment.isDebug) { + DebugSocketMonitorReader.check.call(); + if (DebugSocketMonitorReader.enabled) { + DebugSocketMonitorReader.write(data); + } + } +} + +// @sortImports + +const DebugSocketMonitorReader = @import("./DebugSocketMonitorReader.zig"); +const DebugSocketMonitorWriter = @import("./DebugSocketMonitorWriter.zig"); +const bun = @import("bun"); diff --git a/src/sql/postgres/Status.zig b/src/sql/postgres/Status.zig new file mode 100644 index 0000000000..f4a0e9290e --- /dev/null +++ b/src/sql/postgres/Status.zig @@ -0,0 +1,11 @@ +pub const Status = enum { + disconnected, + connecting, + // Prevent sending the startup message multiple times. + // Particularly relevant for TLS connections. + sent_startup_message, + connected, + failed, +}; + +// @sortImports diff --git a/src/sql/postgres/TLSStatus.zig b/src/sql/postgres/TLSStatus.zig new file mode 100644 index 0000000000..a711af013a --- /dev/null +++ b/src/sql/postgres/TLSStatus.zig @@ -0,0 +1,11 @@ +pub const TLSStatus = union(enum) { + none, + pending, + + /// Number of bytes sent of the 8-byte SSL request message. + /// Since we may send a partial message, we need to know how many bytes were sent. 
+ message_sent: u8, + + ssl_not_available, + ssl_ok, +}; diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig deleted file mode 100644 index 111e96b0ac..0000000000 --- a/src/sql/postgres/postgres_protocol.zig +++ /dev/null @@ -1,1551 +0,0 @@ -const std = @import("std"); -const bun = @import("bun"); -const postgres = bun.api.Postgres; -const Data = postgres.Data; -const protocol = @This(); -const PostgresInt32 = postgres.PostgresInt32; -const PostgresShort = postgres.PostgresShort; -const String = bun.String; -const debug = postgres.debug; -const JSValue = JSC.JSValue; -const JSC = bun.JSC; -const short = postgres.short; -const int4 = postgres.int4; -const int8 = postgres.int8; -const PostgresInt64 = postgres.PostgresInt64; -const types = postgres.types; -const AnyPostgresError = postgres.AnyPostgresError; -pub const ArrayList = struct { - array: *std.ArrayList(u8), - - pub fn offset(this: @This()) usize { - return this.array.items.len; - } - - pub fn write(this: @This(), bytes: []const u8) AnyPostgresError!void { - try this.array.appendSlice(bytes); - } - - pub fn pwrite(this: @This(), bytes: []const u8, i: usize) AnyPostgresError!void { - @memcpy(this.array.items[i..][0..bytes.len], bytes); - } - - pub const Writer = NewWriter(@This()); -}; - -pub const StackReader = struct { - buffer: []const u8 = "", - offset: *usize, - message_start: *usize, - - pub fn markMessageStart(this: @This()) void { - this.message_start.* = this.offset.*; - } - - pub fn ensureLength(this: @This(), length: usize) bool { - return this.buffer.len >= (this.offset.* + length); - } - - pub fn init(buffer: []const u8, offset: *usize, message_start: *usize) protocol.NewReader(StackReader) { - return .{ - .wrapped = .{ - .buffer = buffer, - .offset = offset, - .message_start = message_start, - }, - }; - } - - pub fn peek(this: StackReader) []const u8 { - return this.buffer[this.offset.*..]; - } - pub fn skip(this: StackReader, count: usize) void { - if 
(this.offset.* + count > this.buffer.len) { - this.offset.* = this.buffer.len; - return; - } - - this.offset.* += count; - } - pub fn ensureCapacity(this: StackReader, count: usize) bool { - return this.buffer.len >= (this.offset.* + count); - } - pub fn read(this: StackReader, count: usize) AnyPostgresError!Data { - const offset = this.offset.*; - if (!this.ensureCapacity(count)) { - return error.ShortRead; - } - - this.skip(count); - return Data{ - .temporary = this.buffer[offset..this.offset.*], - }; - } - pub fn readZ(this: StackReader) AnyPostgresError!Data { - const remaining = this.peek(); - if (bun.strings.indexOfChar(remaining, 0)) |zero| { - this.skip(zero + 1); - return Data{ - .temporary = remaining[0..zero], - }; - } - - return error.ShortRead; - } -}; - -pub fn NewWriterWrap( - comptime Context: type, - comptime offsetFn_: (fn (ctx: Context) usize), - comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) AnyPostgresError!void), - comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) AnyPostgresError!void), -) type { - return struct { - wrapped: Context, - - const writeFn = writeFunction_; - const pwriteFn = pwriteFunction_; - const offsetFn = offsetFn_; - pub const Ctx = Context; - - pub const WrappedWriter = @This(); - - pub inline fn write(this: @This(), data: []const u8) AnyPostgresError!void { - try writeFn(this.wrapped, data); - } - - pub const LengthWriter = struct { - index: usize, - context: WrappedWriter, - - pub fn write(this: LengthWriter) AnyPostgresError!void { - try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index); - } - - pub fn writeExcludingSelf(this: LengthWriter) AnyPostgresError!void { - try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index); - } - }; - - pub inline fn length(this: @This()) AnyPostgresError!LengthWriter { - const i = this.offset(); - try this.int4(0); - return LengthWriter{ - .index = i, - .context = this, - }; - } - - pub 
inline fn offset(this: @This()) usize { - return offsetFn(this.wrapped); - } - - pub inline fn pwrite(this: @This(), data: []const u8, i: usize) AnyPostgresError!void { - try pwriteFn(this.wrapped, data, i); - } - - pub fn int4(this: @This(), value: PostgresInt32) !void { - try this.write(std.mem.asBytes(&@byteSwap(value))); - } - - pub fn int8(this: @This(), value: PostgresInt64) !void { - try this.write(std.mem.asBytes(&@byteSwap(value))); - } - - pub fn sint4(this: @This(), value: i32) !void { - try this.write(std.mem.asBytes(&@byteSwap(value))); - } - - pub fn @"f64"(this: @This(), value: f64) !void { - try this.write(std.mem.asBytes(&@byteSwap(@as(u64, @bitCast(value))))); - } - - pub fn @"f32"(this: @This(), value: f32) !void { - try this.write(std.mem.asBytes(&@byteSwap(@as(u32, @bitCast(value))))); - } - - pub fn short(this: @This(), value: anytype) !void { - try this.write(std.mem.asBytes(&@byteSwap(@as(u16, @intCast(value))))); - } - - pub fn string(this: @This(), value: []const u8) !void { - try this.write(value); - if (value.len == 0 or value[value.len - 1] != 0) - try this.write(&[_]u8{0}); - } - - pub fn bytes(this: @This(), value: []const u8) !void { - try this.write(value); - if (value.len == 0 or value[value.len - 1] != 0) - try this.write(&[_]u8{0}); - } - - pub fn @"bool"(this: @This(), value: bool) !void { - try this.write(if (value) "t" else "f"); - } - - pub fn @"null"(this: @This()) !void { - try this.int4(std.math.maxInt(PostgresInt32)); - } - - pub fn String(this: @This(), value: bun.String) !void { - if (value.isEmpty()) { - try this.write(&[_]u8{0}); - return; - } - - var sliced = value.toUTF8(bun.default_allocator); - defer sliced.deinit(); - const slice = sliced.slice(); - - try this.write(slice); - if (slice.len == 0 or slice[slice.len - 1] != 0) - try this.write(&[_]u8{0}); - } - }; -} - -pub const FieldType = enum(u8) { - /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, 
INFO, or LOG (in a notice message), or a localized translation of one of these. Always present. - severity = 'S', - - /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). This is identical to the S field except that the contents are never localized. This is present only in messages generated by PostgreSQL versions 9.6 and later. - localized_severity = 'V', - - /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. - code = 'C', - - /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). Always present. - message = 'M', - - /// Detail: an optional secondary error message carrying more detail about the problem. Might run to multiple lines. - detail = 'D', - - /// Hint: an optional suggestion what to do about the problem. This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. Might run to multiple lines. - hint = 'H', - - /// Position: the field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. The first character has index 1, and positions are measured in characters not bytes. - position = 'P', - - /// Internal position: this is defined the same as the P field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. The q field will always appear when this field appears. - internal_position = 'p', - - /// Internal query: the text of a failed internally-generated command. This could be, for example, an SQL query issued by a PL/pgSQL function. - internal = 'q', - - /// Where: an indication of the context in which the error occurred. Presently this includes a call stack traceback of active procedural language functions and internally-generated queries. 
The trace is one entry per line, most recent first. - where = 'W', - - /// Schema name: if the error was associated with a specific database object, the name of the schema containing that object, if any. - schema = 's', - - /// Table name: if the error was associated with a specific table, the name of the table. (Refer to the schema name field for the name of the table's schema.) - table = 't', - - /// Column name: if the error was associated with a specific table column, the name of the column. (Refer to the schema and table name fields to identify the table.) - column = 'c', - - /// Data type name: if the error was associated with a specific data type, the name of the data type. (Refer to the schema name field for the name of the data type's schema.) - datatype = 'd', - - /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. Refer to fields listed above for the associated table or domain. (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) - constraint = 'n', - - /// File: the file name of the source-code location where the error was reported. - file = 'F', - - /// Line: the line number of the source-code location where the error was reported. - line = 'L', - - /// Routine: the name of the source-code routine reporting the error. 
- routine = 'R', - - _, -}; - -pub const FieldMessage = union(FieldType) { - severity: String, - localized_severity: String, - code: String, - message: String, - detail: String, - hint: String, - position: String, - internal_position: String, - internal: String, - where: String, - schema: String, - table: String, - column: String, - datatype: String, - constraint: String, - file: String, - line: String, - routine: String, - - pub fn format(this: FieldMessage, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - switch (this) { - inline else => |str| { - try std.fmt.format(writer, "{}", .{str}); - }, - } - } - - pub fn deinit(this: *FieldMessage) void { - switch (this.*) { - inline else => |*message| { - message.deref(); - }, - } - } - - pub fn decodeList(comptime Context: type, reader: NewReader(Context)) !std.ArrayListUnmanaged(FieldMessage) { - var messages = std.ArrayListUnmanaged(FieldMessage){}; - while (true) { - const field_int = try reader.int(u8); - if (field_int == 0) break; - const field: FieldType = @enumFromInt(field_int); - - var message = try reader.readZ(); - defer message.deinit(); - if (message.slice().len == 0) break; - - try messages.append(bun.default_allocator, FieldMessage.init(field, message.slice()) catch continue); - } - - return messages; - } - - pub fn init(tag: FieldType, message: []const u8) !FieldMessage { - return switch (tag) { - .severity => FieldMessage{ .severity = String.createUTF8(message) }, - // Ignore this one for now. 
- // .localized_severity => FieldMessage{ .localized_severity = String.createUTF8(message) }, - .code => FieldMessage{ .code = String.createUTF8(message) }, - .message => FieldMessage{ .message = String.createUTF8(message) }, - .detail => FieldMessage{ .detail = String.createUTF8(message) }, - .hint => FieldMessage{ .hint = String.createUTF8(message) }, - .position => FieldMessage{ .position = String.createUTF8(message) }, - .internal_position => FieldMessage{ .internal_position = String.createUTF8(message) }, - .internal => FieldMessage{ .internal = String.createUTF8(message) }, - .where => FieldMessage{ .where = String.createUTF8(message) }, - .schema => FieldMessage{ .schema = String.createUTF8(message) }, - .table => FieldMessage{ .table = String.createUTF8(message) }, - .column => FieldMessage{ .column = String.createUTF8(message) }, - .datatype => FieldMessage{ .datatype = String.createUTF8(message) }, - .constraint => FieldMessage{ .constraint = String.createUTF8(message) }, - .file => FieldMessage{ .file = String.createUTF8(message) }, - .line => FieldMessage{ .line = String.createUTF8(message) }, - .routine => FieldMessage{ .routine = String.createUTF8(message) }, - else => error.UnknownFieldType, - }; - } -}; - -pub fn NewReaderWrap( - comptime Context: type, - comptime markMessageStartFn_: (fn (ctx: Context) void), - comptime peekFn_: (fn (ctx: Context) []const u8), - comptime skipFn_: (fn (ctx: Context, count: usize) void), - comptime ensureCapacityFn_: (fn (ctx: Context, count: usize) bool), - comptime readFunction_: (fn (ctx: Context, count: usize) AnyPostgresError!Data), - comptime readZ_: (fn (ctx: Context) AnyPostgresError!Data), -) type { - return struct { - wrapped: Context, - const readFn = readFunction_; - const readZFn = readZ_; - const ensureCapacityFn = ensureCapacityFn_; - const skipFn = skipFn_; - const peekFn = peekFn_; - const markMessageStartFn = markMessageStartFn_; - - pub const Ctx = Context; - - pub inline fn markMessageStart(this: 
@This()) void { - markMessageStartFn(this.wrapped); - } - - pub inline fn read(this: @This(), count: usize) AnyPostgresError!Data { - return try readFn(this.wrapped, count); - } - - pub inline fn eatMessage(this: @This(), comptime msg_: anytype) AnyPostgresError!void { - const msg = msg_[1..]; - try this.ensureCapacity(msg.len); - - var input = try readFn(this.wrapped, msg.len); - defer input.deinit(); - if (bun.strings.eqlComptime(input.slice(), msg)) return; - return error.InvalidMessage; - } - - pub fn skip(this: @This(), count: usize) AnyPostgresError!void { - skipFn(this.wrapped, count); - } - - pub fn peek(this: @This()) []const u8 { - return peekFn(this.wrapped); - } - - pub inline fn readZ(this: @This()) AnyPostgresError!Data { - return try readZFn(this.wrapped); - } - - pub inline fn ensureCapacity(this: @This(), count: usize) AnyPostgresError!void { - if (!ensureCapacityFn(this.wrapped, count)) { - return error.ShortRead; - } - } - - pub fn int(this: @This(), comptime Int: type) !Int { - var data = try this.read(@sizeOf((Int))); - defer data.deinit(); - if (comptime Int == u8) { - return @as(Int, data.slice()[0]); - } - return @byteSwap(@as(Int, @bitCast(data.slice()[0..@sizeOf(Int)].*))); - } - - pub fn peekInt(this: @This(), comptime Int: type) ?Int { - const remain = this.peek(); - if (remain.len < @sizeOf(Int)) { - return null; - } - return @byteSwap(@as(Int, @bitCast(remain[0..@sizeOf(Int)].*))); - } - - pub fn expectInt(this: @This(), comptime Int: type, comptime value: comptime_int) !bool { - const actual = try this.int(Int); - return actual == value; - } - - pub fn int4(this: @This()) !PostgresInt32 { - return this.int(PostgresInt32); - } - - pub fn short(this: @This()) !PostgresShort { - return this.int(PostgresShort); - } - - pub fn length(this: @This()) !PostgresInt32 { - const expected = try this.int(PostgresInt32); - if (expected > -1) { - try this.ensureCapacity(@intCast(expected -| 4)); - } - - return expected; - } - - pub const bytes = 
read; - - pub fn String(this: @This()) !bun.String { - var result = try this.readZ(); - defer result.deinit(); - return bun.String.fromUTF8(result.slice()); - } - }; -} - -pub fn NewReader(comptime Context: type) type { - return NewReaderWrap(Context, Context.markMessageStart, Context.peek, Context.skip, Context.ensureLength, Context.read, Context.readZ); -} - -pub fn NewWriter(comptime Context: type) type { - return NewWriterWrap(Context, Context.offset, Context.write, Context.pwrite); -} - -fn decoderWrap(comptime Container: type, comptime decodeFn: anytype) type { - return struct { - pub fn decode(this: *Container, context: anytype) AnyPostgresError!void { - const Context = @TypeOf(context); - try decodeFn(this, Context, NewReader(Context){ .wrapped = context }); - } - }; -} - -fn writeWrap(comptime Container: type, comptime writeFn: anytype) type { - return struct { - pub fn write(this: *Container, context: anytype) AnyPostgresError!void { - const Context = @TypeOf(context); - try writeFn(this, Context, NewWriter(Context){ .wrapped = context }); - } - }; -} - -pub const Authentication = union(enum) { - Ok: void, - ClearTextPassword: struct {}, - MD5Password: struct { - salt: [4]u8, - }, - KerberosV5: struct {}, - SCMCredential: struct {}, - GSS: struct {}, - GSSContinue: struct { - data: Data, - }, - SSPI: struct {}, - SASL: struct {}, - SASLContinue: struct { - data: Data, - r: []const u8, - s: []const u8, - i: []const u8, - - pub fn iterationCount(this: *const @This()) !u32 { - return try std.fmt.parseInt(u32, this.i, 0); - } - }, - SASLFinal: struct { - data: Data, - }, - Unknown: void, - - pub fn deinit(this: *@This()) void { - switch (this.*) { - .MD5Password => {}, - .SASL => {}, - .SASLContinue => { - this.SASLContinue.data.zdeinit(); - }, - .SASLFinal => { - this.SASLFinal.data.zdeinit(); - }, - else => {}, - } - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const message_length = try 
reader.length(); - - switch (try reader.int4()) { - 0 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ .Ok = {} }; - }, - 2 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .KerberosV5 = .{}, - }; - }, - 3 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .ClearTextPassword = .{}, - }; - }, - 5 => { - if (message_length != 12) return error.InvalidMessageLength; - var salt_data = try reader.bytes(4); - defer salt_data.deinit(); - this.* = .{ - .MD5Password = .{ - .salt = salt_data.slice()[0..4].*, - }, - }; - }, - 7 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .GSS = .{}, - }; - }, - - 8 => { - if (message_length < 9) return error.InvalidMessageLength; - const bytes = try reader.read(message_length - 8); - this.* = .{ - .GSSContinue = .{ - .data = bytes, - }, - }; - }, - 9 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .SSPI = .{}, - }; - }, - - 10 => { - if (message_length < 9) return error.InvalidMessageLength; - try reader.skip(message_length - 8); - this.* = .{ - .SASL = .{}, - }; - }, - - 11 => { - if (message_length < 9) return error.InvalidMessageLength; - var bytes = try reader.bytes(message_length - 8); - errdefer { - bytes.deinit(); - } - - var iter = bun.strings.split(bytes.slice(), ","); - var r: ?[]const u8 = null; - var i: ?[]const u8 = null; - var s: ?[]const u8 = null; - - while (iter.next()) |item| { - if (item.len > 2) { - const key = item[0]; - const after_equals = item[2..]; - if (key == 'r') { - r = after_equals; - } else if (key == 's') { - s = after_equals; - } else if (key == 'i') { - i = after_equals; - } - } - } - - if (r == null) { - debug("Missing r", .{}); - } - - if (s == null) { - debug("Missing s", .{}); - } - - if (i == null) { - debug("Missing i", .{}); - } - - this.* = .{ - .SASLContinue = .{ - .data = bytes, - .r = r orelse return error.InvalidMessage, - .s = s 
orelse return error.InvalidMessage, - .i = i orelse return error.InvalidMessage, - }, - }; - }, - - 12 => { - if (message_length < 9) return error.InvalidMessageLength; - const remaining: usize = message_length - 8; - - const bytes = try reader.read(remaining); - this.* = .{ - .SASLFinal = .{ - .data = bytes, - }, - }; - }, - - else => { - this.* = .{ .Unknown = {} }; - }, - } - } - - pub const decode = decoderWrap(Authentication, decodeInternal).decode; -}; - -pub const ParameterStatus = struct { - name: Data = .{ .empty = {} }, - value: Data = .{ .empty = {} }, - - pub fn deinit(this: *@This()) void { - this.name.deinit(); - this.value.deinit(); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - this.* = .{ - .name = try reader.readZ(), - .value = try reader.readZ(), - }; - } - - pub const decode = decoderWrap(ParameterStatus, decodeInternal).decode; -}; - -pub const BackendKeyData = struct { - process_id: u32 = 0, - secret_key: u32 = 0, - pub const decode = decoderWrap(BackendKeyData, decodeInternal).decode; - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - if (!try reader.expectInt(u32, 12)) { - return error.InvalidBackendKeyData; - } - - this.* = .{ - .process_id = @bitCast(try reader.int4()), - .secret_key = @bitCast(try reader.int4()), - }; - } -}; - -pub const ErrorResponse = struct { - messages: std.ArrayListUnmanaged(FieldMessage) = .{}, - - pub fn format(formatter: ErrorResponse, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - for (formatter.messages.items) |message| { - try std.fmt.format(writer, "{}\n", .{message}); - } - } - - pub fn deinit(this: *ErrorResponse) void { - for (this.messages.items) |*message| { - message.deinit(); - } - this.messages.deinit(bun.default_allocator); - } - - pub fn decodeInternal(this: *@This(), comptime 
Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - if (remaining_bytes < 4) return error.InvalidMessageLength; - remaining_bytes -|= 4; - - if (remaining_bytes > 0) { - this.* = .{ - .messages = try FieldMessage.decodeList(Container, reader), - }; - } - } - - pub const decode = decoderWrap(ErrorResponse, decodeInternal).decode; - - pub fn toJS(this: ErrorResponse, globalObject: *JSC.JSGlobalObject) JSValue { - var b = bun.StringBuilder{}; - defer b.deinit(bun.default_allocator); - - // Pre-calculate capacity to avoid reallocations - for (this.messages.items) |*msg| { - b.cap += switch (msg.*) { - inline else => |m| m.utf8ByteLength(), - } + 1; - } - b.allocate(bun.default_allocator) catch {}; - - // Build a more structured error message - var severity: String = String.dead; - var code: String = String.dead; - var message: String = String.dead; - var detail: String = String.dead; - var hint: String = String.dead; - var position: String = String.dead; - var where: String = String.dead; - var schema: String = String.dead; - var table: String = String.dead; - var column: String = String.dead; - var datatype: String = String.dead; - var constraint: String = String.dead; - var file: String = String.dead; - var line: String = String.dead; - var routine: String = String.dead; - - for (this.messages.items) |*msg| { - switch (msg.*) { - .severity => |str| severity = str, - .code => |str| code = str, - .message => |str| message = str, - .detail => |str| detail = str, - .hint => |str| hint = str, - .position => |str| position = str, - .where => |str| where = str, - .schema => |str| schema = str, - .table => |str| table = str, - .column => |str| column = str, - .datatype => |str| datatype = str, - .constraint => |str| constraint = str, - .file => |str| file = str, - .line => |str| line = str, - .routine => |str| routine = str, - else => {}, - } - } - - var needs_newline = false; - construct_message: { - if (!message.isEmpty()) 
{ - _ = b.appendStr(message); - needs_newline = true; - break :construct_message; - } - if (!detail.isEmpty()) { - if (needs_newline) { - _ = b.append("\n"); - } else { - _ = b.append(" "); - } - needs_newline = true; - _ = b.appendStr(detail); - } - if (!hint.isEmpty()) { - if (needs_newline) { - _ = b.append("\n"); - } else { - _ = b.append(" "); - } - needs_newline = true; - _ = b.appendStr(hint); - } - } - - const possible_fields = .{ - .{ "detail", detail, void }, - .{ "hint", hint, void }, - .{ "column", column, void }, - .{ "constraint", constraint, void }, - .{ "datatype", datatype, void }, - // in the past this was set to i32 but postgres returns a strings lets keep it compatible - .{ "errno", code, void }, - .{ "position", position, i32 }, - .{ "schema", schema, void }, - .{ "table", table, void }, - .{ "where", where, void }, - }; - const error_code: JSC.Error = - // https://www.postgresql.org/docs/8.1/errcodes-appendix.html - if (code.eqlComptime("42601")) - .POSTGRES_SYNTAX_ERROR - else - .POSTGRES_SERVER_ERROR; - const err = error_code.fmt(globalObject, "{s}", .{b.allocatedSlice()[0..b.len]}); - - inline for (possible_fields) |field| { - if (!field.@"1".isEmpty()) { - const value = brk: { - if (field.@"2" == i32) { - if (field.@"1".toInt32()) |val| { - break :brk JSC.JSValue.jsNumberFromInt32(val); - } - } - - break :brk field.@"1".toJS(globalObject); - }; - - err.put(globalObject, JSC.ZigString.static(field.@"0"), value); - } - } - - return err; - } -}; - -pub const PortalOrPreparedStatement = union(enum) { - portal: []const u8, - prepared_statement: []const u8, - - pub fn slice(this: @This()) []const u8 { - return switch (this) { - .portal => this.portal, - .prepared_statement => this.prepared_statement, - }; - } - - pub fn tag(this: @This()) u8 { - return switch (this) { - .portal => 'P', - .prepared_statement => 'S', - }; - } -}; - -/// Close (F) -/// Byte1('C') -/// - Identifies the message as a Close command. 
-/// Int32 -/// - Length of message contents in bytes, including self. -/// Byte1 -/// - 'S' to close a prepared statement; or 'P' to close a portal. -/// String -/// - The name of the prepared statement or portal to close (an empty string selects the unnamed prepared statement or portal). -pub const Close = struct { - p: PortalOrPreparedStatement, - - fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const p = this.p; - const count: u32 = @sizeOf((u32)) + 1 + p.slice().len + 1; - const header = [_]u8{ - 'C', - } ++ @byteSwap(count) ++ [_]u8{ - p.tag(), - }; - try writer.write(&header); - try writer.write(p.slice()); - try writer.write(&[_]u8{0}); - } - - pub const write = writeWrap(@This(), writeInternal); -}; - -pub const CloseComplete = [_]u8{'3'} ++ toBytes(Int32(4)); -pub const EmptyQueryResponse = [_]u8{'I'} ++ toBytes(Int32(4)); -pub const Terminate = [_]u8{'X'} ++ toBytes(Int32(4)); - -fn Int32(value: anytype) [4]u8 { - return @bitCast(@byteSwap(@as(int4, @intCast(value)))); -} - -const toBytes = std.mem.toBytes; - -pub const TransactionStatusIndicator = enum(u8) { - /// if idle (not in a transaction block) - I = 'I', - - /// if in a transaction block - T = 'T', - - /// if in a failed transaction block - E = 'E', - - _, -}; - -pub const ReadyForQuery = struct { - status: TransactionStatusIndicator = .I, - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - const status = try reader.int(u8); - this.* = .{ - .status = @enumFromInt(status), - }; - } - - pub const decode = decoderWrap(ReadyForQuery, decodeInternal).decode; -}; - -pub const null_int4 = 4294967295; - -pub const DataRow = struct { - pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) AnyPostgresError!bool) 
AnyPostgresError!void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - const remaining_fields: usize = @intCast(@max(try reader.short(), 0)); - - for (0..remaining_fields) |index| { - const byte_length = try reader.int4(); - switch (byte_length) { - 0 => { - var empty = Data.Empty; - if (!try forEach(context, @intCast(index), &empty)) break; - }, - null_int4 => { - if (!try forEach(context, @intCast(index), null)) break; - }, - else => { - var bytes = try reader.bytes(@intCast(byte_length)); - if (!try forEach(context, @intCast(index), &bytes)) break; - }, - } - } - } -}; - -pub const BindComplete = [_]u8{'2'} ++ toBytes(Int32(4)); - -pub const ColumnIdentifier = union(enum) { - name: Data, - index: u32, - duplicate: void, - - pub fn init(name: Data) !@This() { - if (switch (name.slice().len) { - 1..."4294967295".len => true, - 0 => return .{ .name = .{ .empty = {} } }, - else => false, - }) might_be_int: { - // use a u64 to avoid overflow - var int: u64 = 0; - for (name.slice()) |byte| { - int = int * 10 + switch (byte) { - '0'...'9' => @as(u64, byte - '0'), - else => break :might_be_int, - }; - } - - // JSC only supports indexed property names up to 2^32 - if (int < std.math.maxInt(u32)) - return .{ .index = @intCast(int) }; - } - - return .{ .name = .{ .owned = try name.toOwned() } }; - } - - pub fn deinit(this: *@This()) void { - switch (this.*) { - .name => |*name| name.deinit(), - else => {}, - } - } -}; -pub const FieldDescription = struct { - /// JavaScriptCore treats numeric property names differently than string property names. - /// so we do the work to figure out if the property name is a number ahead of time. 
- name_or_index: ColumnIdentifier = .{ - .name = .{ .empty = {} }, - }, - table_oid: int4 = 0, - column_index: short = 0, - type_oid: int4 = 0, - binary: bool = false, - pub fn typeTag(this: @This()) types.Tag { - return @enumFromInt(@as(short, @truncate(this.type_oid))); - } - - pub fn deinit(this: *@This()) void { - this.name_or_index.deinit(); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) AnyPostgresError!void { - var name = try reader.readZ(); - errdefer { - name.deinit(); - } - - // Field name (null-terminated string) - const field_name = try ColumnIdentifier.init(name); - // Table OID (4 bytes) - // If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. - const table_oid = try reader.int4(); - - // Column attribute number (2 bytes) - // If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. - const column_index = try reader.short(); - - // Data type OID (4 bytes) - // The object ID of the field's data type. The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. - const type_oid = try reader.int4(); - - // Data type size (2 bytes) The data type size (see pg_type.typlen). Note that negative values denote variable-width types. - // Type modifier (4 bytes) The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. - try reader.skip(6); - - // Format code (2 bytes) - // The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. 
- const binary = switch (try reader.short()) { - 0 => false, - 1 => true, - else => return error.UnknownFormatCode, - }; - this.* = .{ - .table_oid = table_oid, - .column_index = column_index, - .type_oid = type_oid, - .binary = binary, - .name_or_index = field_name, - }; - } - - pub const decode = decoderWrap(FieldDescription, decodeInternal).decode; -}; - -pub const RowDescription = struct { - fields: []FieldDescription = &[_]FieldDescription{}, - pub fn deinit(this: *@This()) void { - for (this.fields) |*field| { - field.deinit(); - } - - bun.default_allocator.free(this.fields); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - const field_count: usize = @intCast(@max(try reader.short(), 0)); - var fields = try bun.default_allocator.alloc( - FieldDescription, - field_count, - ); - var remaining = fields; - errdefer { - for (fields[0 .. field_count - remaining.len]) |*field| { - field.deinit(); - } - - bun.default_allocator.free(fields); - } - while (remaining.len > 0) { - try remaining[0].decodeInternal(Container, reader); - remaining = remaining[1..]; - } - this.* = .{ - .fields = fields, - }; - } - - pub const decode = decoderWrap(RowDescription, decodeInternal).decode; -}; - -pub const ParameterDescription = struct { - parameters: []int4 = &[_]int4{}, - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - const count = try reader.short(); - const parameters = try bun.default_allocator.alloc(int4, @intCast(@max(count, 0))); - - var data = try reader.read(@as(usize, @intCast(@max(count, 0))) * @sizeOf((int4))); - defer data.deinit(); - const input_params: []align(1) const int4 = toInt32Slice(int4, data.slice()); - for (input_params, parameters) |src, *dest| { - dest.* = @byteSwap(src); - } - - this.* = .{ - 
.parameters = parameters, - }; - } - - pub const decode = decoderWrap(ParameterDescription, decodeInternal).decode; -}; - -// workaround for zig compiler TODO -fn toInt32Slice(comptime Int: type, slice: []const u8) []align(1) const Int { - return @as([*]align(1) const Int, @ptrCast(slice.ptr))[0 .. slice.len / @sizeOf((Int))]; -} - -pub const NotificationResponse = struct { - pid: int4 = 0, - channel: bun.ByteList = .{}, - payload: bun.ByteList = .{}, - - pub fn deinit(this: *@This()) void { - this.channel.deinitWithAllocator(bun.default_allocator); - this.payload.deinitWithAllocator(bun.default_allocator); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - this.* = .{ - .pid = try reader.int4(), - .channel = (try reader.readZ()).toOwned(), - .payload = (try reader.readZ()).toOwned(), - }; - } - - pub const decode = decoderWrap(NotificationResponse, decodeInternal).decode; -}; - -pub const CommandComplete = struct { - command_tag: Data = .{ .empty = {} }, - - pub fn deinit(this: *@This()) void { - this.command_tag.deinit(); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - const tag = try reader.readZ(); - this.* = .{ - .command_tag = tag, - }; - } - - pub const decode = decoderWrap(CommandComplete, decodeInternal).decode; -}; - -pub const Parse = struct { - name: []const u8 = "", - query: []const u8 = "", - params: []const int4 = &.{}, - - pub fn deinit(this: *Parse) void { - _ = this; - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const parameters = this.params; - const count: usize = @sizeOf((u32)) + @sizeOf(u16) + (parameters.len * @sizeOf(u32)) + @max(zCount(this.name), 1) + @max(zCount(this.query), 1); - const header = [_]u8{ - 
'P', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(this.name); - try writer.string(this.query); - try writer.short(parameters.len); - for (parameters) |parameter| { - try writer.int4(parameter); - } - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const ParseComplete = [_]u8{'1'} ++ toBytes(Int32(4)); - -pub const PasswordMessage = struct { - password: Data = .{ .empty = {} }, - - pub fn deinit(this: *PasswordMessage) void { - this.password.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const password = this.password.slice(); - const count: usize = @sizeOf((u32)) + password.len + 1; - const header = [_]u8{ - 'p', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(password); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const CopyData = struct { - data: Data = .{ .empty = {} }, - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - - const data = try reader.read(@intCast(length -| 5)); - this.* = .{ - .data = data, - }; - } - - pub const decode = decoderWrap(CopyData, decodeInternal).decode; - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const data = this.data.slice(); - const count: u32 = @sizeOf((u32)) + data.len + 1; - const header = [_]u8{ - 'd', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(data); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const CopyDone = [_]u8{'c'} ++ toBytes(Int32(4)); -pub const Sync = [_]u8{'S'} ++ toBytes(Int32(4)); -pub const Flush = [_]u8{'H'} ++ toBytes(Int32(4)); -pub const SSLRequest = toBytes(Int32(8)) ++ toBytes(Int32(80877103)); -pub const NoData = [_]u8{'n'} ++ toBytes(Int32(4)); - -pub fn 
writeQuery(query: []const u8, comptime Context: type, writer: NewWriter(Context)) !void { - const count: u32 = @sizeOf((u32)) + @as(u32, @intCast(query.len)) + 1; - const header = [_]u8{ - 'Q', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(query); -} -pub const SASLInitialResponse = struct { - mechanism: Data = .{ .empty = {} }, - data: Data = .{ .empty = {} }, - - pub fn deinit(this: *SASLInitialResponse) void { - this.mechanism.deinit(); - this.data.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const mechanism = this.mechanism.slice(); - const data = this.data.slice(); - const count: usize = @sizeOf(u32) + mechanism.len + 1 + data.len + @sizeOf(u32); - const header = [_]u8{ - 'p', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(mechanism); - try writer.int4(@truncate(data.len)); - try writer.write(data); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const SASLResponse = struct { - data: Data = .{ .empty = {} }, - - pub fn deinit(this: *SASLResponse) void { - this.data.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const data = this.data.slice(); - const count: usize = @sizeOf(u32) + data.len; - const header = [_]u8{ - 'p', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.write(data); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const StartupMessage = struct { - user: Data, - database: Data, - options: Data = Data{ .empty = {} }, - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const user = this.user.slice(); - const database = this.database.slice(); - const options = this.options.slice(); - const count: usize = @sizeOf((int4)) + @sizeOf((int4)) + zFieldCount("user", user) 
+ zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + options.len + 1; - - const header = toBytes(Int32(@as(u32, @truncate(count)))); - try writer.write(&header); - try writer.int4(196608); - - try writer.string("user"); - if (user.len > 0) - try writer.string(user); - - try writer.string("database"); - - if (database.len == 0) { - // The database to connect to. Defaults to the user name. - try writer.string(user); - } else { - try writer.string(database); - } - try writer.string("client_encoding"); - try writer.string("UTF8"); - if (options.len > 0) { - try writer.write(options); - } - try writer.write(&[_]u8{0}); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -fn zCount(slice: []const u8) usize { - return if (slice.len > 0) slice.len + 1 else 0; -} - -fn zFieldCount(prefix: []const u8, slice: []const u8) usize { - if (slice.len > 0) { - return zCount(prefix) + zCount(slice); - } - - return zCount(prefix); -} - -pub const Execute = struct { - max_rows: int4 = 0, - p: PortalOrPreparedStatement, - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - try writer.write("E"); - const length = try writer.length(); - if (this.p == .portal) - try writer.string(this.p.portal) - else - try writer.write(&[_]u8{0}); - try writer.int4(this.max_rows); - try length.write(); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const Describe = struct { - p: PortalOrPreparedStatement, - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const message = this.p.slice(); - try writer.write(&[_]u8{ - 'D', - }); - const length = try writer.length(); - try writer.write(&[_]u8{ - this.p.tag(), - }); - try writer.string(message); - try length.write(); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const NegotiateProtocolVersion = struct { - version: int4 
= 0, - unrecognized_options: std.ArrayListUnmanaged(String) = .{}, - - pub fn decodeInternal( - this: *@This(), - comptime Container: type, - reader: NewReader(Container), - ) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - const version = try reader.int4(); - this.* = .{ - .version = version, - }; - - const unrecognized_options_count: u32 = @intCast(@max(try reader.int4(), 0)); - try this.unrecognized_options.ensureTotalCapacity(bun.default_allocator, unrecognized_options_count); - errdefer { - for (this.unrecognized_options.items) |*option| { - option.deinit(); - } - this.unrecognized_options.deinit(bun.default_allocator); - } - for (0..unrecognized_options_count) |_| { - var option = try reader.readZ(); - if (option.slice().len == 0) break; - defer option.deinit(); - this.unrecognized_options.appendAssumeCapacity( - String.fromUTF8(option), - ); - } - } -}; - -pub const NoticeResponse = struct { - messages: std.ArrayListUnmanaged(FieldMessage) = .{}, - pub fn deinit(this: *NoticeResponse) void { - for (this.messages.items) |*message| { - message.deinit(); - } - this.messages.deinit(bun.default_allocator); - } - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - if (remaining_bytes > 0) { - this.* = .{ - .messages = try FieldMessage.decodeList(Container, reader), - }; - } - } - pub const decode = decoderWrap(NoticeResponse, decodeInternal).decode; - - pub fn toJS(this: NoticeResponse, globalObject: *JSC.JSGlobalObject) JSValue { - var b = bun.StringBuilder{}; - defer b.deinit(bun.default_allocator); - - for (this.messages.items) |msg| { - b.cap += switch (msg) { - inline else => |m| m.utf8ByteLength(), - } + 1; - } - b.allocate(bun.default_allocator) catch {}; - - for (this.messages.items) |msg| { - var str = switch (msg) { - inline else => |m| m.toUTF8(bun.default_allocator), - }; - defer str.deinit(); - _ = 
b.append(str.slice()); - _ = b.append("\n"); - } - - return JSC.ZigString.init(b.allocatedSlice()[0..b.len]).toJS(globalObject); - } -}; - -pub const CopyFail = struct { - message: Data = .{ .empty = {} }, - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = try reader.int4(); - - const message = try reader.readZ(); - this.* = .{ - .message = message, - }; - } - - pub const decode = decoderWrap(CopyFail, decodeInternal).decode; - - pub fn writeInternal( - this: *@This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const message = this.message.slice(); - const count: u32 = @sizeOf((u32)) + message.len + 1; - const header = [_]u8{ - 'f', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(message); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - -pub const CopyInResponse = struct { - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = reader; - _ = this; - TODO(@This()); - } - - pub const decode = decoderWrap(CopyInResponse, decodeInternal).decode; -}; - -pub const CopyOutResponse = struct { - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = reader; - _ = this; - TODO(@This()); - } - - pub const decode = decoderWrap(CopyInResponse, decodeInternal).decode; -}; - -fn TODO(comptime Type: type) !void { - bun.Output.panic("TODO: not implemented {s}", .{bun.meta.typeBaseName(@typeName(Type))}); -} diff --git a/src/sql/postgres/protocol/ArrayList.zig b/src/sql/postgres/protocol/ArrayList.zig new file mode 100644 index 0000000000..0fff3a0c0f --- /dev/null +++ b/src/sql/postgres/protocol/ArrayList.zig @@ -0,0 +1,22 @@ +array: *std.ArrayList(u8), + +pub fn offset(this: @This()) usize { + return this.array.items.len; +} + +pub fn write(this: @This(), bytes: []const u8) AnyPostgresError!void { + try this.array.appendSlice(bytes); 
+} + +pub fn pwrite(this: @This(), bytes: []const u8, i: usize) AnyPostgresError!void { + @memcpy(this.array.items[i..][0..bytes.len], bytes); +} + +pub const Writer = NewWriter(@This()); + +// @sortImports + +const ArrayList = @This(); +const std = @import("std"); +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; +const NewWriter = @import("./NewWriter.zig").NewWriter; diff --git a/src/sql/postgres/protocol/Authentication.zig b/src/sql/postgres/protocol/Authentication.zig new file mode 100644 index 0000000000..91b838d332 --- /dev/null +++ b/src/sql/postgres/protocol/Authentication.zig @@ -0,0 +1,182 @@ +pub const Authentication = union(enum) { + Ok: void, + ClearTextPassword: struct {}, + MD5Password: struct { + salt: [4]u8, + }, + KerberosV5: struct {}, + SCMCredential: struct {}, + GSS: struct {}, + GSSContinue: struct { + data: Data, + }, + SSPI: struct {}, + SASL: struct {}, + SASLContinue: struct { + data: Data, + r: []const u8, + s: []const u8, + i: []const u8, + + pub fn iterationCount(this: *const @This()) !u32 { + return try std.fmt.parseInt(u32, this.i, 0); + } + }, + SASLFinal: struct { + data: Data, + }, + Unknown: void, + + pub fn deinit(this: *@This()) void { + switch (this.*) { + .MD5Password => {}, + .SASL => {}, + .SASLContinue => { + this.SASLContinue.data.zdeinit(); + }, + .SASLFinal => { + this.SASLFinal.data.zdeinit(); + }, + else => {}, + } + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const message_length = try reader.length(); + + switch (try reader.int4()) { + 0 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ .Ok = {} }; + }, + 2 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .KerberosV5 = .{}, + }; + }, + 3 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .ClearTextPassword = .{}, + }; + }, + 5 => { + if (message_length != 12) return 
error.InvalidMessageLength; + var salt_data = try reader.bytes(4); + defer salt_data.deinit(); + this.* = .{ + .MD5Password = .{ + .salt = salt_data.slice()[0..4].*, + }, + }; + }, + 7 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .GSS = .{}, + }; + }, + + 8 => { + if (message_length < 9) return error.InvalidMessageLength; + const bytes = try reader.read(message_length - 8); + this.* = .{ + .GSSContinue = .{ + .data = bytes, + }, + }; + }, + 9 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .SSPI = .{}, + }; + }, + + 10 => { + if (message_length < 9) return error.InvalidMessageLength; + try reader.skip(message_length - 8); + this.* = .{ + .SASL = .{}, + }; + }, + + 11 => { + if (message_length < 9) return error.InvalidMessageLength; + var bytes = try reader.bytes(message_length - 8); + errdefer { + bytes.deinit(); + } + + var iter = bun.strings.split(bytes.slice(), ","); + var r: ?[]const u8 = null; + var i: ?[]const u8 = null; + var s: ?[]const u8 = null; + + while (iter.next()) |item| { + if (item.len > 2) { + const key = item[0]; + const after_equals = item[2..]; + if (key == 'r') { + r = after_equals; + } else if (key == 's') { + s = after_equals; + } else if (key == 'i') { + i = after_equals; + } + } + } + + if (r == null) { + debug("Missing r", .{}); + } + + if (s == null) { + debug("Missing s", .{}); + } + + if (i == null) { + debug("Missing i", .{}); + } + + this.* = .{ + .SASLContinue = .{ + .data = bytes, + .r = r orelse return error.InvalidMessage, + .s = s orelse return error.InvalidMessage, + .i = i orelse return error.InvalidMessage, + }, + }; + }, + + 12 => { + if (message_length < 9) return error.InvalidMessageLength; + const remaining: usize = message_length - 8; + + const bytes = try reader.read(remaining); + this.* = .{ + .SASLFinal = .{ + .data = bytes, + }, + }; + }, + + else => { + this.* = .{ .Unknown = {} }; + }, + } + } + + pub const decode = DecoderWrap(Authentication, 
decodeInternal).decode; +}; + +const debug = bun.Output.scoped(.Postgres, true); + +// @sortImports + +const bun = @import("bun"); +const std = @import("std"); +const Data = @import("../Data.zig").Data; +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/BackendKeyData.zig b/src/sql/postgres/protocol/BackendKeyData.zig new file mode 100644 index 0000000000..7df3e20971 --- /dev/null +++ b/src/sql/postgres/protocol/BackendKeyData.zig @@ -0,0 +1,20 @@ +process_id: u32 = 0, +secret_key: u32 = 0, +pub const decode = DecoderWrap(BackendKeyData, decodeInternal).decode; + +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + if (!try reader.expectInt(u32, 12)) { + return error.InvalidBackendKeyData; + } + + this.* = .{ + .process_id = @bitCast(try reader.int4()), + .secret_key = @bitCast(try reader.int4()), + }; +} + +// @sortImports + +const BackendKeyData = @This(); +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/Close.zig b/src/sql/postgres/protocol/Close.zig new file mode 100644 index 0000000000..baac29e9e8 --- /dev/null +++ b/src/sql/postgres/protocol/Close.zig @@ -0,0 +1,39 @@ +/// Close (F) +/// Byte1('C') +/// - Identifies the message as a Close command. +/// Int32 +/// - Length of message contents in bytes, including self. +/// Byte1 +/// - 'S' to close a prepared statement; or 'P' to close a portal. +/// String +/// - The name of the prepared statement or portal to close (an empty string selects the unnamed prepared statement or portal). 
+pub const Close = struct { + p: PortalOrPreparedStatement, + + fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const p = this.p; + const count: u32 = @sizeOf((u32)) + 1 + p.slice().len + 1; + const header = [_]u8{ + 'C', + } ++ @byteSwap(count) ++ [_]u8{ + p.tag(), + }; + try writer.write(&header); + try writer.write(p.slice()); + try writer.write(&[_]u8{0}); + } + + pub const write = WriteWrap(@This(), writeInternal); +}; + +// @sortImports + +const NewWriter = @import("./NewWriter.zig").NewWriter; + +const PortalOrPreparedStatement = @import("./PortalOrPreparedStatement.zig").PortalOrPreparedStatement; + +const WriteWrap = @import("./WriteWrap.zig").WriteWrap; diff --git a/src/sql/postgres/protocol/ColumnIdentifier.zig b/src/sql/postgres/protocol/ColumnIdentifier.zig new file mode 100644 index 0000000000..026a6a843e --- /dev/null +++ b/src/sql/postgres/protocol/ColumnIdentifier.zig @@ -0,0 +1,40 @@ +pub const ColumnIdentifier = union(enum) { + name: Data, + index: u32, + duplicate: void, + + pub fn init(name: Data) !@This() { + if (switch (name.slice().len) { + 1..."4294967295".len => true, + 0 => return .{ .name = .{ .empty = {} } }, + else => false, + }) might_be_int: { + // use a u64 to avoid overflow + var int: u64 = 0; + for (name.slice()) |byte| { + int = int * 10 + switch (byte) { + '0'...'9' => @as(u64, byte - '0'), + else => break :might_be_int, + }; + } + + // JSC only supports indexed property names up to 2^32 + if (int < std.math.maxInt(u32)) + return .{ .index = @intCast(int) }; + } + + return .{ .name = .{ .owned = try name.toOwned() } }; + } + + pub fn deinit(this: *@This()) void { + switch (this.*) { + .name => |*name| name.deinit(), + else => {}, + } + } +}; + +// @sortImports + +const std = @import("std"); +const Data = @import("../Data.zig").Data; diff --git a/src/sql/postgres/protocol/CommandComplete.zig b/src/sql/postgres/protocol/CommandComplete.zig new file mode 100644 index 
0000000000..a9299cd1a6 --- /dev/null +++ b/src/sql/postgres/protocol/CommandComplete.zig @@ -0,0 +1,25 @@ +command_tag: Data = .{ .empty = {} }, + +pub fn deinit(this: *@This()) void { + this.command_tag.deinit(); +} + +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + bun.assert(length >= 4); + + const tag = try reader.readZ(); + this.* = .{ + .command_tag = tag, + }; +} + +pub const decode = DecoderWrap(CommandComplete, decodeInternal).decode; + +// @sortImports + +const CommandComplete = @This(); +const bun = @import("bun"); +const Data = @import("../Data.zig").Data; +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/CopyData.zig b/src/sql/postgres/protocol/CopyData.zig new file mode 100644 index 0000000000..885bb2960e --- /dev/null +++ b/src/sql/postgres/protocol/CopyData.zig @@ -0,0 +1,40 @@ +data: Data = .{ .empty = {} }, + +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + + const data = try reader.read(@intCast(length -| 5)); + this.* = .{ + .data = data, + }; +} + +pub const decode = DecoderWrap(CopyData, decodeInternal).decode; + +pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), +) !void { + const data = this.data.slice(); + const count: u32 = @sizeOf((u32)) + data.len + 1; + const header = [_]u8{ + 'd', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(data); +} + +pub const write = WriteWrap(@This(), writeInternal).write; + +// @sortImports + +const CopyData = @This(); +const std = @import("std"); +const Data = @import("../Data.zig").Data; +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const Int32 = @import("../types/int_types.zig").Int32; +const NewReader = 
@import("./NewReader.zig").NewReader; +const NewWriter = @import("./NewWriter.zig").NewWriter; +const WriteWrap = @import("./WriteWrap.zig").WriteWrap; +const toBytes = std.mem.toBytes; diff --git a/src/sql/postgres/protocol/CopyFail.zig b/src/sql/postgres/protocol/CopyFail.zig new file mode 100644 index 0000000000..f006cafb76 --- /dev/null +++ b/src/sql/postgres/protocol/CopyFail.zig @@ -0,0 +1,42 @@ +message: Data = .{ .empty = {} }, + +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + _ = try reader.int4(); + + const message = try reader.readZ(); + this.* = .{ + .message = message, + }; +} + +pub const decode = DecoderWrap(CopyFail, decodeInternal).decode; + +pub fn writeInternal( + this: *@This(), + comptime Context: type, + writer: NewWriter(Context), +) !void { + const message = this.message.slice(); + const count: u32 = @sizeOf((u32)) + message.len + 1; + const header = [_]u8{ + 'f', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(message); +} + +pub const write = WriteWrap(@This(), writeInternal).write; + +// @sortImports + +const CopyFail = @This(); +const std = @import("std"); +const Data = @import("../Data.zig").Data; +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; +const NewWriter = @import("./NewWriter.zig").NewWriter; +const WriteWrap = @import("./WriteWrap.zig").WriteWrap; +const toBytes = std.mem.toBytes; + +const int_types = @import("../types/int_types.zig"); +const Int32 = int_types.Int32; diff --git a/src/sql/postgres/protocol/CopyInResponse.zig b/src/sql/postgres/protocol/CopyInResponse.zig new file mode 100644 index 0000000000..47dbdd850f --- /dev/null +++ b/src/sql/postgres/protocol/CopyInResponse.zig @@ -0,0 +1,14 @@ +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + _ = reader; + _ = this; + bun.Output.panic("TODO: not implemented 
{s}", .{bun.meta.typeBaseName(@typeName(@This()))}); +} + +pub const decode = DecoderWrap(CopyInResponse, decodeInternal).decode; + +// @sortImports + +const CopyInResponse = @This(); +const bun = @import("bun"); +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/CopyOutResponse.zig b/src/sql/postgres/protocol/CopyOutResponse.zig new file mode 100644 index 0000000000..45650a3f41 --- /dev/null +++ b/src/sql/postgres/protocol/CopyOutResponse.zig @@ -0,0 +1,14 @@ +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + _ = reader; + _ = this; + bun.Output.panic("TODO: not implemented {s}", .{bun.meta.typeBaseName(@typeName(@This()))}); +} + +pub const decode = DecoderWrap(CopyOutResponse, decodeInternal).decode; + +// @sortImports + +const CopyOutResponse = @This(); +const bun = @import("bun"); +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/DataRow.zig b/src/sql/postgres/protocol/DataRow.zig new file mode 100644 index 0000000000..125a25f1f2 --- /dev/null +++ b/src/sql/postgres/protocol/DataRow.zig @@ -0,0 +1,33 @@ +pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) AnyPostgresError!bool) AnyPostgresError!void { + var remaining_bytes = try reader.length(); + remaining_bytes -|= 4; + + const remaining_fields: usize = @intCast(@max(try reader.short(), 0)); + + for (0..remaining_fields) |index| { + const byte_length = try reader.int4(); + switch (byte_length) { + 0 => { + var empty = Data.Empty; + if (!try forEach(context, @intCast(index), &empty)) break; + }, + null_int4 => { + if (!try forEach(context, @intCast(index), null)) break; + }, + else => { + var bytes = try 
reader.bytes(@intCast(byte_length)); + if (!try forEach(context, @intCast(index), &bytes)) break; + }, + } + } +} + +pub const null_int4 = 4294967295; + +// @sortImports + +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; + +const Data = @import("../Data.zig").Data; + +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/DecoderWrap.zig b/src/sql/postgres/protocol/DecoderWrap.zig new file mode 100644 index 0000000000..fe2b78902f --- /dev/null +++ b/src/sql/postgres/protocol/DecoderWrap.zig @@ -0,0 +1,14 @@ +pub fn DecoderWrap(comptime Container: type, comptime decodeFn: anytype) type { + return struct { + pub fn decode(this: *Container, context: anytype) AnyPostgresError!void { + const Context = @TypeOf(context); + try decodeFn(this, Context, NewReader(Context){ .wrapped = context }); + } + }; +} + +// @sortImports + +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; + +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/Describe.zig b/src/sql/postgres/protocol/Describe.zig new file mode 100644 index 0000000000..4dc9fd2728 --- /dev/null +++ b/src/sql/postgres/protocol/Describe.zig @@ -0,0 +1,28 @@ +p: PortalOrPreparedStatement, + +pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), +) !void { + const message = this.p.slice(); + try writer.write(&[_]u8{ + 'D', + }); + const length = try writer.length(); + try writer.write(&[_]u8{ + this.p.tag(), + }); + try writer.string(message); + try length.write(); +} + +pub const write = WriteWrap(@This(), writeInternal).write; + +// @sortImports + +const NewWriter = @import("./NewWriter.zig").NewWriter; + +const PortalOrPreparedStatement = @import("./PortalOrPreparedStatement.zig").PortalOrPreparedStatement; + +const WriteWrap = @import("./WriteWrap.zig").WriteWrap; diff --git a/src/sql/postgres/protocol/ErrorResponse.zig 
b/src/sql/postgres/protocol/ErrorResponse.zig new file mode 100644 index 0000000000..e70d2215a1 --- /dev/null +++ b/src/sql/postgres/protocol/ErrorResponse.zig @@ -0,0 +1,159 @@ +messages: std.ArrayListUnmanaged(FieldMessage) = .{}, + +pub fn format(formatter: ErrorResponse, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + for (formatter.messages.items) |message| { + try std.fmt.format(writer, "{}\n", .{message}); + } +} + +pub fn deinit(this: *ErrorResponse) void { + for (this.messages.items) |*message| { + message.deinit(); + } + this.messages.deinit(bun.default_allocator); +} + +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + var remaining_bytes = try reader.length(); + if (remaining_bytes < 4) return error.InvalidMessageLength; + remaining_bytes -|= 4; + + if (remaining_bytes > 0) { + this.* = .{ + .messages = try FieldMessage.decodeList(Container, reader), + }; + } +} + +pub const decode = DecoderWrap(ErrorResponse, decodeInternal).decode; + +pub fn toJS(this: ErrorResponse, globalObject: *JSC.JSGlobalObject) JSValue { + var b = bun.StringBuilder{}; + defer b.deinit(bun.default_allocator); + + // Pre-calculate capacity to avoid reallocations + for (this.messages.items) |*msg| { + b.cap += switch (msg.*) { + inline else => |m| m.utf8ByteLength(), + } + 1; + } + b.allocate(bun.default_allocator) catch {}; + + // Build a more structured error message + var severity: String = String.dead; + var code: String = String.dead; + var message: String = String.dead; + var detail: String = String.dead; + var hint: String = String.dead; + var position: String = String.dead; + var where: String = String.dead; + var schema: String = String.dead; + var table: String = String.dead; + var column: String = String.dead; + var datatype: String = String.dead; + var constraint: String = String.dead; + var file: String = String.dead; + var line: String = String.dead; + var routine: String = 
String.dead; + + for (this.messages.items) |*msg| { + switch (msg.*) { + .severity => |str| severity = str, + .code => |str| code = str, + .message => |str| message = str, + .detail => |str| detail = str, + .hint => |str| hint = str, + .position => |str| position = str, + .where => |str| where = str, + .schema => |str| schema = str, + .table => |str| table = str, + .column => |str| column = str, + .datatype => |str| datatype = str, + .constraint => |str| constraint = str, + .file => |str| file = str, + .line => |str| line = str, + .routine => |str| routine = str, + else => {}, + } + } + + var needs_newline = false; + construct_message: { + if (!message.isEmpty()) { + _ = b.appendStr(message); + needs_newline = true; + break :construct_message; + } + if (!detail.isEmpty()) { + if (needs_newline) { + _ = b.append("\n"); + } else { + _ = b.append(" "); + } + needs_newline = true; + _ = b.appendStr(detail); + } + if (!hint.isEmpty()) { + if (needs_newline) { + _ = b.append("\n"); + } else { + _ = b.append(" "); + } + needs_newline = true; + _ = b.appendStr(hint); + } + } + + const possible_fields = .{ + .{ "detail", detail, void }, + .{ "hint", hint, void }, + .{ "column", column, void }, + .{ "constraint", constraint, void }, + .{ "datatype", datatype, void }, + // in the past this was set to i32 but postgres returns a strings lets keep it compatible + .{ "errno", code, void }, + .{ "position", position, i32 }, + .{ "schema", schema, void }, + .{ "table", table, void }, + .{ "where", where, void }, + }; + const error_code: JSC.Error = + // https://www.postgresql.org/docs/8.1/errcodes-appendix.html + if (code.eqlComptime("42601")) + .POSTGRES_SYNTAX_ERROR + else + .POSTGRES_SERVER_ERROR; + const err = error_code.fmt(globalObject, "{s}", .{b.allocatedSlice()[0..b.len]}); + + inline for (possible_fields) |field| { + if (!field.@"1".isEmpty()) { + const value = brk: { + if (field.@"2" == i32) { + if (field.@"1".toInt32()) |val| { + break :brk 
JSC.JSValue.jsNumberFromInt32(val); + } + } + + break :brk field.@"1".toJS(globalObject); + }; + + err.put(globalObject, JSC.ZigString.static(field.@"0"), value); + } + } + + return err; +} + +// @sortImports + +const ErrorResponse = @This(); +const std = @import("std"); +const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; +const FieldMessage = @import("./FieldMessage.zig").FieldMessage; +const NewReader = @import("./NewReader.zig").NewReader; + +const bun = @import("bun"); +const String = bun.String; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/protocol/Execute.zig b/src/sql/postgres/protocol/Execute.zig new file mode 100644 index 0000000000..648d39da4f --- /dev/null +++ b/src/sql/postgres/protocol/Execute.zig @@ -0,0 +1,28 @@ +max_rows: int4 = 0, +p: PortalOrPreparedStatement, + +pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), +) !void { + try writer.write("E"); + const length = try writer.length(); + if (this.p == .portal) + try writer.string(this.p.portal) + else + try writer.write(&[_]u8{0}); + try writer.int4(this.max_rows); + try length.write(); +} + +pub const write = WriteWrap(@This(), writeInternal).write; + +// @sortImports + +const NewWriter = @import("./NewWriter.zig").NewWriter; +const PortalOrPreparedStatement = @import("./PortalOrPreparedStatement.zig").PortalOrPreparedStatement; +const WriteWrap = @import("./WriteWrap.zig").WriteWrap; + +const int_types = @import("../types/int_types.zig"); +const int4 = int_types.int4; diff --git a/src/sql/postgres/protocol/FieldDescription.zig b/src/sql/postgres/protocol/FieldDescription.zig new file mode 100644 index 0000000000..860176c5b3 --- /dev/null +++ b/src/sql/postgres/protocol/FieldDescription.zig @@ -0,0 +1,70 @@ +/// JavaScriptCore treats numeric property names differently than string property names. +/// so we do the work to figure out if the property name is a number ahead of time. 
+name_or_index: ColumnIdentifier = .{ + .name = .{ .empty = {} }, +}, +table_oid: int4 = 0, +column_index: short = 0, +type_oid: int4 = 0, +binary: bool = false, +pub fn typeTag(this: @This()) types.Tag { + return @enumFromInt(@as(short, @truncate(this.type_oid))); +} + +pub fn deinit(this: *@This()) void { + this.name_or_index.deinit(); +} + +pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) AnyPostgresError!void { + var name = try reader.readZ(); + errdefer { + name.deinit(); + } + + // Field name (null-terminated string) + const field_name = try ColumnIdentifier.init(name); + // Table OID (4 bytes) + // If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. + const table_oid = try reader.int4(); + + // Column attribute number (2 bytes) + // If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. + const column_index = try reader.short(); + + // Data type OID (4 bytes) + // The object ID of the field's data type. The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + const type_oid = try reader.int4(); + + // Data type size (2 bytes) The data type size (see pg_type.typlen). Note that negative values denote variable-width types. + // Type modifier (4 bytes) The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + try reader.skip(6); + + // Format code (2 bytes) + // The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. 
+    const binary = switch (try reader.short()) {
+        0 => false,
+        1 => true,
+        else => return error.UnknownFormatCode,
+    };
+    this.* = .{
+        .table_oid = table_oid,
+        .column_index = column_index,
+        .type_oid = type_oid,
+        .binary = binary,
+        .name_or_index = field_name,
+    };
+}
+
+pub const decode = DecoderWrap(FieldDescription, decodeInternal).decode;
+
+// @sortImports
+
+const FieldDescription = @This();
+const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError;
+const ColumnIdentifier = @import("./ColumnIdentifier.zig").ColumnIdentifier;
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const NewReader = @import("./NewReader.zig").NewReader;
+
+const types = @import("../PostgresTypes.zig");
+const int4 = types.int4;
+const short = types.short;
diff --git a/src/sql/postgres/protocol/FieldMessage.zig b/src/sql/postgres/protocol/FieldMessage.zig
new file mode 100644
index 0000000000..d3d2c1fdbf
--- /dev/null
+++ b/src/sql/postgres/protocol/FieldMessage.zig
@@ -0,0 +1,87 @@
+pub const FieldMessage = union(FieldType) {
+    severity: String,
+    localized_severity: String,
+    code: String,
+    message: String,
+    detail: String,
+    hint: String,
+    position: String,
+    internal_position: String,
+    internal: String,
+    where: String,
+    schema: String,
+    table: String,
+    column: String,
+    datatype: String,
+    constraint: String,
+    file: String,
+    line: String,
+    routine: String,
+
+    pub fn format(this: FieldMessage, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
+        switch (this) {
+            inline else => |str| {
+                try std.fmt.format(writer, "{}", .{str});
+            },
+        }
+    }
+
+    pub fn deinit(this: *FieldMessage) void {
+        switch (this.*) {
+            inline else => |*message| {
+                message.deref();
+            },
+        }
+    }
+
+    pub fn decodeList(comptime Context: type, reader: NewReader(Context)) !std.ArrayListUnmanaged(FieldMessage) {
+        var messages = std.ArrayListUnmanaged(FieldMessage){};
+        while (true) {
+            const field_int = try reader.int(u8);
+            if (field_int == 0) break;
+            const field: FieldType = @enumFromInt(field_int);
+
+            var message = try reader.readZ();
+            defer message.deinit();
+            if (message.slice().len == 0) break;
+
+            try messages.append(bun.default_allocator, FieldMessage.init(field, message.slice()) catch continue);
+        }
+
+        return messages;
+    }
+
+    pub fn init(tag: FieldType, message: []const u8) !FieldMessage {
+        return switch (tag) {
+            .severity => FieldMessage{ .severity = String.createUTF8(message) },
+            // Ignore this one for now.
+            // .localized_severity => FieldMessage{ .localized_severity = String.createUTF8(message) },
+            .code => FieldMessage{ .code = String.createUTF8(message) },
+            .message => FieldMessage{ .message = String.createUTF8(message) },
+            .detail => FieldMessage{ .detail = String.createUTF8(message) },
+            .hint => FieldMessage{ .hint = String.createUTF8(message) },
+            .position => FieldMessage{ .position = String.createUTF8(message) },
+            .internal_position => FieldMessage{ .internal_position = String.createUTF8(message) },
+            .internal => FieldMessage{ .internal = String.createUTF8(message) },
+            .where => FieldMessage{ .where = String.createUTF8(message) },
+            .schema => FieldMessage{ .schema = String.createUTF8(message) },
+            .table => FieldMessage{ .table = String.createUTF8(message) },
+            .column => FieldMessage{ .column = String.createUTF8(message) },
+            .datatype => FieldMessage{ .datatype = String.createUTF8(message) },
+            .constraint => FieldMessage{ .constraint = String.createUTF8(message) },
+            .file => FieldMessage{ .file = String.createUTF8(message) },
+            .line => FieldMessage{ .line = String.createUTF8(message) },
+            .routine => FieldMessage{ .routine = String.createUTF8(message) },
+            else => error.UnknownFieldType,
+        };
+    }
+};
+
+// @sortImports
+
+const std = @import("std");
+const FieldType = @import("./FieldType.zig").FieldType;
+const NewReader = @import("./NewReader.zig").NewReader;
+
+const bun = @import("bun");
+const String = bun.String;
diff --git a/src/sql/postgres/protocol/FieldType.zig b/src/sql/postgres/protocol/FieldType.zig
new file mode 100644
index 0000000000..b5e6c860fe
--- /dev/null
+++ b/src/sql/postgres/protocol/FieldType.zig
@@ -0,0 +1,57 @@
+pub const FieldType = enum(u8) {
+    /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a localized translation of one of these. Always present.
+    severity = 'S',
+
+    /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). This is identical to the S field except that the contents are never localized. This is present only in messages generated by PostgreSQL versions 9.6 and later.
+    localized_severity = 'V',
+
+    /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present.
+    code = 'C',
+
+    /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). Always present.
+    message = 'M',
+
+    /// Detail: an optional secondary error message carrying more detail about the problem. Might run to multiple lines.
+    detail = 'D',
+
+    /// Hint: an optional suggestion what to do about the problem. This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. Might run to multiple lines.
+    hint = 'H',
+
+    /// Position: the field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. The first character has index 1, and positions are measured in characters not bytes.
+    position = 'P',
+
+    /// Internal position: this is defined the same as the P field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. The q field will always appear when this field appears.
+    internal_position = 'p',
+
+    /// Internal query: the text of a failed internally-generated command. This could be, for example, an SQL query issued by a PL/pgSQL function.
+    internal = 'q',
+
+    /// Where: an indication of the context in which the error occurred. Presently this includes a call stack traceback of active procedural language functions and internally-generated queries. The trace is one entry per line, most recent first.
+    where = 'W',
+
+    /// Schema name: if the error was associated with a specific database object, the name of the schema containing that object, if any.
+    schema = 's',
+
+    /// Table name: if the error was associated with a specific table, the name of the table. (Refer to the schema name field for the name of the table's schema.)
+    table = 't',
+
+    /// Column name: if the error was associated with a specific table column, the name of the column. (Refer to the schema and table name fields to identify the table.)
+    column = 'c',
+
+    /// Data type name: if the error was associated with a specific data type, the name of the data type. (Refer to the schema name field for the name of the data type's schema.)
+    datatype = 'd',
+
+    /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. Refer to fields listed above for the associated table or domain. (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.)
+    constraint = 'n',
+
+    /// File: the file name of the source-code location where the error was reported.
+    file = 'F',
+
+    /// Line: the line number of the source-code location where the error was reported.
+    line = 'L',
+
+    /// Routine: the name of the source-code routine reporting the error.
+    routine = 'R',
+
+    _,
+};
diff --git a/src/sql/postgres/protocol/NegotiateProtocolVersion.zig b/src/sql/postgres/protocol/NegotiateProtocolVersion.zig
new file mode 100644
index 0000000000..9b80f0fdd2
--- /dev/null
+++ b/src/sql/postgres/protocol/NegotiateProtocolVersion.zig
@@ -0,0 +1,44 @@
+version: int4 = 0,
+unrecognized_options: std.ArrayListUnmanaged(String) = .{},
+
+pub fn decodeInternal(
+    this: *@This(),
+    comptime Container: type,
+    reader: NewReader(Container),
+) !void {
+    const length = try reader.length();
+    bun.assert(length >= 4);
+
+    const version = try reader.int4();
+    this.* = .{
+        .version = version,
+    };
+
+    const unrecognized_options_count: u32 = @intCast(@max(try reader.int4(), 0));
+    try this.unrecognized_options.ensureTotalCapacity(bun.default_allocator, unrecognized_options_count);
+    errdefer {
+        for (this.unrecognized_options.items) |*option| {
+            option.deinit();
+        }
+        this.unrecognized_options.deinit(bun.default_allocator);
+    }
+    for (0..unrecognized_options_count) |_| {
+        var option = try reader.readZ();
+        if (option.slice().len == 0) break;
+        defer option.deinit();
+        this.unrecognized_options.appendAssumeCapacity(
+            String.fromUTF8(option),
+        );
+    }
+}
+
+// @sortImports
+
+const std = @import("std");
+const NewReader = @import("./NewReader.zig").NewReader;
+
+const int_types = @import("../types/int_types.zig");
+const int4 = int_types.int4;
+
+const bun = @import("bun");
+const String = bun.String;
diff --git a/src/sql/postgres/protocol/NewReader.zig b/src/sql/postgres/protocol/NewReader.zig
new file mode 100644
index 0000000000..932d4d334d
--- /dev/null
+++ b/src/sql/postgres/protocol/NewReader.zig
@@ -0,0 +1,118 @@
+pub fn NewReaderWrap(
+    comptime Context: type,
+    comptime markMessageStartFn_: (fn (ctx: Context) void),
+    comptime peekFn_: (fn (ctx: Context) []const u8),
+    comptime skipFn_: (fn (ctx: Context, count: usize) void),
+    comptime ensureCapacityFn_: (fn (ctx: Context, count: usize) bool),
+    comptime readFunction_: (fn (ctx: Context, count: usize) AnyPostgresError!Data),
+    comptime readZ_: (fn (ctx: Context) AnyPostgresError!Data),
+) type {
+    return struct {
+        wrapped: Context,
+        const readFn = readFunction_;
+        const readZFn = readZ_;
+        const ensureCapacityFn = ensureCapacityFn_;
+        const skipFn = skipFn_;
+        const peekFn = peekFn_;
+        const markMessageStartFn = markMessageStartFn_;
+
+        pub const Ctx = Context;
+
+        pub inline fn markMessageStart(this: @This()) void {
+            markMessageStartFn(this.wrapped);
+        }
+
+        pub inline fn read(this: @This(), count: usize) AnyPostgresError!Data {
+            return try readFn(this.wrapped, count);
+        }
+
+        pub inline fn eatMessage(this: @This(), comptime msg_: anytype) AnyPostgresError!void {
+            const msg = msg_[1..];
+            try this.ensureCapacity(msg.len);
+
+            var input = try readFn(this.wrapped, msg.len);
+            defer input.deinit();
+            if (bun.strings.eqlComptime(input.slice(), msg)) return;
+            return error.InvalidMessage;
+        }
+
+        pub fn skip(this: @This(), count: usize) AnyPostgresError!void {
+            skipFn(this.wrapped, count);
+        }
+
+        pub fn peek(this: @This()) []const u8 {
+            return peekFn(this.wrapped);
+        }
+
+        pub inline fn readZ(this: @This()) AnyPostgresError!Data {
+            return try readZFn(this.wrapped);
+        }
+
+        pub inline fn ensureCapacity(this: @This(), count: usize) AnyPostgresError!void {
+            if (!ensureCapacityFn(this.wrapped, count)) {
+                return error.ShortRead;
+            }
+        }
+
+        pub fn int(this: @This(), comptime Int: type) !Int {
+            var data = try this.read(@sizeOf((Int)));
+            defer data.deinit();
+            if (comptime Int == u8) {
+                return @as(Int, data.slice()[0]);
+            }
+            return @byteSwap(@as(Int, @bitCast(data.slice()[0..@sizeOf(Int)].*)));
+        }
+
+        pub fn peekInt(this: @This(), comptime Int: type) ?Int {
+            const remain = this.peek();
+            if (remain.len < @sizeOf(Int)) {
+                return null;
+            }
+            return @byteSwap(@as(Int, @bitCast(remain[0..@sizeOf(Int)].*)));
+        }
+
+        pub fn expectInt(this: @This(), comptime Int: type, comptime value: comptime_int) !bool {
+            const actual = try this.int(Int);
+            return actual == value;
+        }
+
+        pub fn int4(this: @This()) !PostgresInt32 {
+            return this.int(PostgresInt32);
+        }
+
+        pub fn short(this: @This()) !PostgresShort {
+            return this.int(PostgresShort);
+        }
+
+        pub fn length(this: @This()) !PostgresInt32 {
+            const expected = try this.int(PostgresInt32);
+            if (expected > -1) {
+                try this.ensureCapacity(@intCast(expected -| 4));
+            }
+
+            return expected;
+        }
+
+        pub const bytes = read;
+
+        pub fn String(this: @This()) !bun.String {
+            var result = try this.readZ();
+            defer result.deinit();
+            return bun.String.fromUTF8(result.slice());
+        }
+    };
+}
+
+pub fn NewReader(comptime Context: type) type {
+    return NewReaderWrap(Context, Context.markMessageStart, Context.peek, Context.skip, Context.ensureLength, Context.read, Context.readZ);
+}
+
+// @sortImports
+
+const bun = @import("bun");
+const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError;
+const Data = @import("../Data.zig").Data;
+
+const int_types = @import("../types/int_types.zig");
+const PostgresInt32 = int_types.PostgresInt32;
+const PostgresShort = int_types.PostgresShort;
diff --git a/src/sql/postgres/protocol/NewWriter.zig b/src/sql/postgres/protocol/NewWriter.zig
new file mode 100644
index 0000000000..6f6a800328
--- /dev/null
+++ b/src/sql/postgres/protocol/NewWriter.zig
@@ -0,0 +1,125 @@
+pub fn NewWriterWrap(
+    comptime Context: type,
+    comptime offsetFn_: (fn (ctx: Context) usize),
+    comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) AnyPostgresError!void),
+    comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) AnyPostgresError!void),
+) type {
+    return struct {
+        wrapped: Context,
+
+        const writeFn = writeFunction_;
+        const pwriteFn = pwriteFunction_;
+        const offsetFn = offsetFn_;
+        pub const Ctx = Context;
+
+        pub const WrappedWriter = @This();
+
+        pub inline fn write(this: @This(), data: []const u8) AnyPostgresError!void {
+            try writeFn(this.wrapped, data);
+        }
+
+        pub const LengthWriter = struct {
+            index: usize,
+            context: WrappedWriter,
+
+            pub fn write(this: LengthWriter) AnyPostgresError!void {
+                try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index);
+            }
+
+            pub fn writeExcludingSelf(this: LengthWriter) AnyPostgresError!void {
+                try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index);
+            }
+        };
+
+        pub inline fn length(this: @This()) AnyPostgresError!LengthWriter {
+            const i = this.offset();
+            try this.int4(0);
+            return LengthWriter{
+                .index = i,
+                .context = this,
+            };
+        }
+
+        pub inline fn offset(this: @This()) usize {
+            return offsetFn(this.wrapped);
+        }
+
+        pub inline fn pwrite(this: @This(), data: []const u8, i: usize) AnyPostgresError!void {
+            try pwriteFn(this.wrapped, data, i);
+        }
+
+        pub fn int4(this: @This(), value: PostgresInt32) !void {
+            try this.write(std.mem.asBytes(&@byteSwap(value)));
+        }
+
+        pub fn int8(this: @This(), value: PostgresInt64) !void {
+            try this.write(std.mem.asBytes(&@byteSwap(value)));
+        }
+
+        pub fn sint4(this: @This(), value: i32) !void {
+            try this.write(std.mem.asBytes(&@byteSwap(value)));
+        }
+
+        pub fn @"f64"(this: @This(), value: f64) !void {
+            try this.write(std.mem.asBytes(&@byteSwap(@as(u64, @bitCast(value)))));
+        }
+
+        pub fn @"f32"(this: @This(), value: f32) !void {
+            try this.write(std.mem.asBytes(&@byteSwap(@as(u32, @bitCast(value)))));
+        }
+
+        pub fn short(this: @This(), value: anytype) !void {
+            try this.write(std.mem.asBytes(&@byteSwap(@as(u16, @intCast(value)))));
+        }
+
+        pub fn string(this: @This(), value: []const u8) !void {
+            try this.write(value);
+            if (value.len == 0 or value[value.len - 1] != 0)
+                try this.write(&[_]u8{0});
+        }
+
+        pub fn bytes(this: @This(), value: []const u8) !void {
+            try this.write(value);
+            if (value.len == 0 or value[value.len - 1] != 0)
+                try this.write(&[_]u8{0});
+        }
+
+        pub fn @"bool"(this: @This(), value: bool) !void {
+            try this.write(if (value) "t" else "f");
+        }
+
+        pub fn @"null"(this: @This()) !void {
+            try this.int4(std.math.maxInt(PostgresInt32));
+        }
+
+        pub fn String(this: @This(), value: bun.String) !void {
+            if (value.isEmpty()) {
+                try this.write(&[_]u8{0});
+                return;
+            }
+
+            var sliced = value.toUTF8(bun.default_allocator);
+            defer sliced.deinit();
+            const slice = sliced.slice();
+
+            try this.write(slice);
+            if (slice.len == 0 or slice[slice.len - 1] != 0)
+                try this.write(&[_]u8{0});
+        }
+    };
+}
+
+pub fn NewWriter(comptime Context: type) type {
+    return NewWriterWrap(Context, Context.offset, Context.write, Context.pwrite);
+}
+
+// @sortImports
+
+const bun = @import("bun");
+const std = @import("std");
+const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError;
+
+const int_types = @import("../types/int_types.zig");
+const Int32 = int_types.Int32;
+const PostgresInt32 = int_types.PostgresInt32;
+const PostgresInt64 = int_types.PostgresInt64;
diff --git a/src/sql/postgres/protocol/NoticeResponse.zig b/src/sql/postgres/protocol/NoticeResponse.zig
new file mode 100644
index 0000000000..1e84eef072
--- /dev/null
+++ b/src/sql/postgres/protocol/NoticeResponse.zig
@@ -0,0 +1,53 @@
+messages: std.ArrayListUnmanaged(FieldMessage) = .{},
+pub fn deinit(this: *NoticeResponse) void {
+    for (this.messages.items) |*message| {
+        message.deinit();
+    }
+    this.messages.deinit(bun.default_allocator);
+}
+pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void {
+    var remaining_bytes = try reader.length();
+    remaining_bytes -|= 4;
+
+    if (remaining_bytes > 0) {
+        this.* = .{
+            .messages = try FieldMessage.decodeList(Container, reader),
+        };
+    }
+}
+pub const decode = DecoderWrap(NoticeResponse, decodeInternal).decode;
+
+pub fn toJS(this: NoticeResponse, globalObject: *JSC.JSGlobalObject) JSValue {
+    var b = bun.StringBuilder{};
+    defer b.deinit(bun.default_allocator);
+
+    for (this.messages.items) |msg| {
+        b.cap += switch (msg) {
+            inline else => |m| m.utf8ByteLength(),
+        } + 1;
+    }
+    b.allocate(bun.default_allocator) catch {};
+
+    for (this.messages.items) |msg| {
+        var str = switch (msg) {
+            inline else => |m| m.toUTF8(bun.default_allocator),
+        };
+        defer str.deinit();
+        _ = b.append(str.slice());
+        _ = b.append("\n");
+    }
+
+    return JSC.ZigString.init(b.allocatedSlice()[0..b.len]).toJS(globalObject);
+}
+
+// @sortImports
+
+const NoticeResponse = @This();
+const bun = @import("bun");
+const std = @import("std");
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const FieldMessage = @import("./FieldMessage.zig").FieldMessage;
+const NewReader = @import("./NewReader.zig").NewReader;
+
+const JSC = bun.JSC;
+const JSValue = JSC.JSValue;
diff --git a/src/sql/postgres/protocol/NotificationResponse.zig b/src/sql/postgres/protocol/NotificationResponse.zig
new file mode 100644
index 0000000000..936490602d
--- /dev/null
+++ b/src/sql/postgres/protocol/NotificationResponse.zig
@@ -0,0 +1,31 @@
+pid: int4 = 0,
+channel: bun.ByteList = .{},
+payload: bun.ByteList = .{},
+
+pub fn deinit(this: *@This()) void {
+    this.channel.deinitWithAllocator(bun.default_allocator);
+    this.payload.deinitWithAllocator(bun.default_allocator);
+}
+
+pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void {
+    const length = try reader.length();
+    bun.assert(length >= 4);
+
+    this.* = .{
+        .pid = try reader.int4(),
+        .channel = (try reader.readZ()).toOwned(),
+        .payload = (try reader.readZ()).toOwned(),
+    };
+}
+
+pub const decode = DecoderWrap(NotificationResponse, decodeInternal).decode;
+
+// @sortImports
+
+const NotificationResponse = @This();
+const bun = @import("bun");
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const NewReader = @import("./NewReader.zig").NewReader;
+
+const types = @import("../PostgresTypes.zig");
+const int4 = types.int4;
diff --git a/src/sql/postgres/protocol/ParameterDescription.zig b/src/sql/postgres/protocol/ParameterDescription.zig
new file mode 100644
index 0000000000..8be2737fd6
--- /dev/null
+++ b/src/sql/postgres/protocol/ParameterDescription.zig
@@ -0,0 +1,37 @@
+parameters: []int4 = &[_]int4{},
+
+pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void {
+    var remaining_bytes = try reader.length();
+    remaining_bytes -|= 4;
+
+    const count = try reader.short();
+    const parameters = try bun.default_allocator.alloc(int4, @intCast(@max(count, 0)));
+
+    var data = try reader.read(@as(usize, @intCast(@max(count, 0))) * @sizeOf((int4)));
+    defer data.deinit();
+    const input_params: []align(1) const int4 = toInt32Slice(int4, data.slice());
+    for (input_params, parameters) |src, *dest| {
+        dest.* = @byteSwap(src);
+    }
+
+    this.* = .{
+        .parameters = parameters,
+    };
+}
+
+pub const decode = DecoderWrap(ParameterDescription, decodeInternal).decode;
+
+// workaround for zig compiler TODO
+fn toInt32Slice(comptime Int: type, slice: []const u8) []align(1) const Int {
+    return @as([*]align(1) const Int, @ptrCast(slice.ptr))[0 .. slice.len / @sizeOf((Int))];
+}
+
+// @sortImports
+
+const ParameterDescription = @This();
+const bun = @import("bun");
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const NewReader = @import("./NewReader.zig").NewReader;
+
+const types = @import("../PostgresTypes.zig");
+const int4 = types.int4;
diff --git a/src/sql/postgres/protocol/ParameterStatus.zig b/src/sql/postgres/protocol/ParameterStatus.zig
new file mode 100644
index 0000000000..9575c0302d
--- /dev/null
+++ b/src/sql/postgres/protocol/ParameterStatus.zig
@@ -0,0 +1,27 @@
+name: Data = .{ .empty = {} },
+value: Data = .{ .empty = {} },
+
+pub fn deinit(this: *@This()) void {
+    this.name.deinit();
+    this.value.deinit();
+}
+
+pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void {
+    const length = try reader.length();
+    bun.assert(length >= 4);
+
+    this.* = .{
+        .name = try reader.readZ(),
+        .value = try reader.readZ(),
+    };
+}
+
+pub const decode = DecoderWrap(ParameterStatus, decodeInternal).decode;
+
+// @sortImports
+
+const ParameterStatus = @This();
+const bun = @import("bun");
+const Data = @import("../Data.zig").Data;
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const NewReader = @import("./NewReader.zig").NewReader;
diff --git a/src/sql/postgres/protocol/Parse.zig b/src/sql/postgres/protocol/Parse.zig
new file mode 100644
index 0000000000..af14f63461
--- /dev/null
+++ b/src/sql/postgres/protocol/Parse.zig
@@ -0,0 +1,43 @@
+name: []const u8 = "",
+query: []const u8 = "",
+params: []const int4 = &.{},
+
+pub fn deinit(this: *Parse) void {
+    _ = this;
+}
+
+pub fn writeInternal(
+    this: *const @This(),
+    comptime Context: type,
+    writer: NewWriter(Context),
+) !void {
+    const parameters = this.params;
+    const count: usize = @sizeOf((u32)) + @sizeOf(u16) + (parameters.len * @sizeOf(u32)) + @max(zCount(this.name), 1) + @max(zCount(this.query), 1);
+    const header = [_]u8{
+        'P',
+    } ++ toBytes(Int32(count));
+    try writer.write(&header);
+    try writer.string(this.name);
+    try writer.string(this.query);
+    try writer.short(parameters.len);
+    for (parameters) |parameter| {
+        try writer.int4(parameter);
+    }
+}
+
+pub const write = WriteWrap(@This(), writeInternal).write;
+
+// @sortImports
+
+const Parse = @This();
+const std = @import("std");
+const NewWriter = @import("./NewWriter.zig").NewWriter;
+const WriteWrap = @import("./WriteWrap.zig").WriteWrap;
+const toBytes = std.mem.toBytes;
+
+const types = @import("../PostgresTypes.zig");
+const Int32 = types.Int32;
+const int4 = types.int4;
+
+const zHelpers = @import("./zHelpers.zig");
+const zCount = zHelpers.zCount;
diff --git a/src/sql/postgres/protocol/PasswordMessage.zig b/src/sql/postgres/protocol/PasswordMessage.zig
new file mode 100644
index 0000000000..222e37b7da
--- /dev/null
+++ b/src/sql/postgres/protocol/PasswordMessage.zig
@@ -0,0 +1,31 @@
+password: Data = .{ .empty = {} },
+
+pub fn deinit(this: *PasswordMessage) void {
+    this.password.deinit();
+}
+
+pub fn writeInternal(
+    this: *const @This(),
+    comptime Context: type,
+    writer: NewWriter(Context),
+) !void {
+    const password = this.password.slice();
+    const count: usize = @sizeOf((u32)) + password.len + 1;
+    const header = [_]u8{
+        'p',
+    } ++ toBytes(Int32(count));
+    try writer.write(&header);
+    try writer.string(password);
+}
+
+pub const write = WriteWrap(@This(), writeInternal).write;
+
+// @sortImports
+
+const PasswordMessage = @This();
+const std = @import("std");
+const Data = @import("../Data.zig").Data;
+const Int32 = @import("../types/int_types.zig").Int32;
+const NewWriter = @import("./NewWriter.zig").NewWriter;
+const WriteWrap = @import("./WriteWrap.zig").WriteWrap;
+const toBytes = std.mem.toBytes;
diff --git a/src/sql/postgres/protocol/PortalOrPreparedStatement.zig b/src/sql/postgres/protocol/PortalOrPreparedStatement.zig
new file mode 100644
index 0000000000..575f5a07bd
--- /dev/null
+++ b/src/sql/postgres/protocol/PortalOrPreparedStatement.zig
@@ -0,0 +1,18 @@
+pub const PortalOrPreparedStatement = union(enum) {
+    portal: []const u8,
+    prepared_statement: []const u8,
+
+    pub fn slice(this: @This()) []const u8 {
+        return switch (this) {
+            .portal => this.portal,
+            .prepared_statement => this.prepared_statement,
+        };
+    }
+
+    pub fn tag(this: @This()) u8 {
+        return switch (this) {
+            .portal => 'P',
+            .prepared_statement => 'S',
+        };
+    }
+};
diff --git a/src/sql/postgres/protocol/ReadyForQuery.zig b/src/sql/postgres/protocol/ReadyForQuery.zig
new file mode 100644
index 0000000000..baee6bea3b
--- /dev/null
+++ b/src/sql/postgres/protocol/ReadyForQuery.zig
@@ -0,0 +1,18 @@
+const ReadyForQuery = @This();
+status: TransactionStatusIndicator = .I,
+pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void {
+    const length = try reader.length();
+    bun.assert(length >= 4);
+
+    const status = try reader.int(u8);
+    this.* = .{
+        .status = @enumFromInt(status),
+    };
+}
+
+pub const decode = DecoderWrap(ReadyForQuery, decodeInternal).decode;
+
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const NewReader = @import("./NewReader.zig").NewReader;
+const TransactionStatusIndicator = @import("./TransactionStatusIndicator.zig").TransactionStatusIndicator;
+const bun = @import("bun");
diff --git a/src/sql/postgres/protocol/RowDescription.zig b/src/sql/postgres/protocol/RowDescription.zig
new file mode 100644
index 0000000000..e3068d4aee
--- /dev/null
+++ b/src/sql/postgres/protocol/RowDescription.zig
@@ -0,0 +1,44 @@
+fields: []FieldDescription = &[_]FieldDescription{},
+pub fn deinit(this: *@This()) void {
+    for (this.fields) |*field| {
+        field.deinit();
+    }
+
+    bun.default_allocator.free(this.fields);
+}
+
+pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void {
+    var remaining_bytes = try reader.length();
+    remaining_bytes -|= 4;
+
+    const field_count: usize = @intCast(@max(try reader.short(), 0));
+    var fields = try bun.default_allocator.alloc(
+        FieldDescription,
+        field_count,
+    );
+    var remaining = fields;
+    errdefer {
+        for (fields[0 .. field_count - remaining.len]) |*field| {
+            field.deinit();
+        }
+
+        bun.default_allocator.free(fields);
+    }
+    while (remaining.len > 0) {
+        try remaining[0].decodeInternal(Container, reader);
+        remaining = remaining[1..];
+    }
+    this.* = .{
+        .fields = fields,
+    };
+}
+
+pub const decode = DecoderWrap(RowDescription, decodeInternal).decode;
+
+// @sortImports
+
+const FieldDescription = @import("./FieldDescription.zig");
+const RowDescription = @This();
+const bun = @import("bun");
+const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap;
+const NewReader = @import("./NewReader.zig").NewReader;
diff --git a/src/sql/postgres/protocol/SASLInitialResponse.zig b/src/sql/postgres/protocol/SASLInitialResponse.zig
new file mode 100644
index 0000000000..8c5ee5cf14
--- /dev/null
+++ b/src/sql/postgres/protocol/SASLInitialResponse.zig
@@ -0,0 +1,36 @@
+mechanism: Data = .{ .empty = {} },
+data: Data = .{ .empty = {} },
+
+pub fn deinit(this: *SASLInitialResponse) void {
+    this.mechanism.deinit();
+    this.data.deinit();
+}
+
+pub fn writeInternal(
+    this: *const @This(),
+    comptime Context: type,
+    writer: NewWriter(Context),
+) !void {
+    const mechanism = this.mechanism.slice();
+    const data = this.data.slice();
+    const count: usize = @sizeOf(u32) + mechanism.len + 1 + data.len + @sizeOf(u32);
+    const header = [_]u8{
+        'p',
+    } ++ toBytes(Int32(count));
+    try writer.write(&header);
+    try writer.string(mechanism);
+    try writer.int4(@truncate(data.len));
+    try writer.write(data);
+}
+
+pub const write = WriteWrap(@This(), writeInternal).write;
+
+// @sortImports
+
+const SASLInitialResponse = @This();
+const std = @import("std");
+const Data = @import("../Data.zig").Data;
+const Int32 = @import("../types/int_types.zig").Int32;
+const NewWriter = @import("./NewWriter.zig").NewWriter;
+const WriteWrap = @import("./WriteWrap.zig").WriteWrap;
+const toBytes = std.mem.toBytes;
diff --git a/src/sql/postgres/protocol/SASLResponse.zig b/src/sql/postgres/protocol/SASLResponse.zig
new file mode 100644
index 0000000000..314fabd9e2
--- /dev/null
+++ b/src/sql/postgres/protocol/SASLResponse.zig
@@ -0,0 +1,31 @@
+data: Data = .{ .empty = {} },
+
+pub fn deinit(this: *SASLResponse) void {
+    this.data.deinit();
+}
+
+pub fn writeInternal(
+    this: *const @This(),
+    comptime Context: type,
+    writer: NewWriter(Context),
+) !void {
+    const data = this.data.slice();
+    const count: usize = @sizeOf(u32) + data.len;
+    const header = [_]u8{
+        'p',
+    } ++ toBytes(Int32(count));
+    try writer.write(&header);
+    try writer.write(data);
+}
+
+pub const write = WriteWrap(@This(), writeInternal).write;
+
+// @sortImports
+
+const SASLResponse = @This();
+const std = @import("std");
+const Data = @import("../Data.zig").Data;
+const Int32 = @import("../types/int_types.zig").Int32;
+const NewWriter = @import("./NewWriter.zig").NewWriter;
+const WriteWrap = @import("./WriteWrap.zig").WriteWrap;
+const toBytes = std.mem.toBytes;
diff --git a/src/sql/postgres/protocol/StackReader.zig b/src/sql/postgres/protocol/StackReader.zig
new file mode 100644
index 0000000000..85fb93b5a9
--- /dev/null
+++ b/src/sql/postgres/protocol/StackReader.zig
@@ -0,0 +1,66 @@
+buffer: []const u8 = "",
+offset: *usize,
+message_start: *usize,
+
+pub fn markMessageStart(this: @This()) void {
+    this.message_start.* = this.offset.*;
+}
+
+pub fn ensureLength(this: @This(), length: usize) bool {
+    return this.buffer.len >= (this.offset.* + length);
+}
+
+pub fn init(buffer: []const u8, offset: *usize, message_start: *usize) NewReader(StackReader) {
+    return .{
+        .wrapped = .{
+            .buffer = buffer,
+            .offset = offset,
+            .message_start = message_start,
+        },
+    };
+}
+
+pub fn peek(this: StackReader) []const u8 {
+    return this.buffer[this.offset.*..];
+}
+pub fn skip(this: StackReader, count: usize) void {
+    if (this.offset.* + count > this.buffer.len) {
+        this.offset.* = this.buffer.len;
+        return;
+    }
+
+    this.offset.* += count;
+}
+pub fn ensureCapacity(this: StackReader, count: usize) bool {
+    return this.buffer.len >= (this.offset.* + count);
+}
+pub fn read(this: StackReader, count: usize) AnyPostgresError!Data {
+    const offset = this.offset.*;
+    if (!this.ensureCapacity(count)) {
+        return error.ShortRead;
+    }
+
+    this.skip(count);
+    return Data{
+        .temporary = this.buffer[offset..this.offset.*],
+    };
+}
+pub fn readZ(this: StackReader) AnyPostgresError!Data {
+    const remaining = this.peek();
+    if (bun.strings.indexOfChar(remaining, 0)) |zero| {
+        this.skip(zero + 1);
+        return Data{
+            .temporary = remaining[0..zero],
+        };
+    }
+
+    return error.ShortRead;
+}
+
+// @sortImports
+
+const StackReader = @This();
+const bun = @import("bun");
+const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError;
+const Data = @import("../Data.zig").Data;
+const NewReader = @import("./NewReader.zig").NewReader;
diff --git a/src/sql/postgres/protocol/StartupMessage.zig b/src/sql/postgres/protocol/StartupMessage.zig
new file mode 100644
index 0000000000..d52f65a878
--- /dev/null
+++ b/src/sql/postgres/protocol/StartupMessage.zig
@@ -0,0 +1,52 @@
+user: Data,
+database: Data,
+options: Data = Data{ .empty = {} },
+
+pub fn writeInternal(
+    this: *const @This(),
+    comptime Context: type,
+    writer: NewWriter(Context),
+) !void {
+    const user = this.user.slice();
+    const database = this.database.slice();
+    const options = this.options.slice();
+    const count: usize = @sizeOf((int4)) + @sizeOf((int4)) + zFieldCount("user", user) + zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + options.len + 1;
+
+    const header = toBytes(Int32(@as(u32, @truncate(count))));
+    try writer.write(&header);
+    try writer.int4(196608);
+
+    try writer.string("user");
+    if (user.len > 0)
+        try writer.string(user);
+
+    try writer.string("database");
+
+    if (database.len == 0) {
+        // The database to connect to. Defaults to the user name.
+        try writer.string(user);
+    } else {
+        try writer.string(database);
+    }
+    try writer.string("client_encoding");
+    try writer.string("UTF8");
+    if (options.len > 0) {
+        try writer.write(options);
+    }
+    try writer.write(&[_]u8{0});
+}
+
+pub const write = WriteWrap(@This(), writeInternal).write;
+
+// @sortImports
+
+const std = @import("std");
+const Data = @import("../Data.zig").Data;
+const NewWriter = @import("./NewWriter.zig").NewWriter;
+const WriteWrap = @import("./WriteWrap.zig").WriteWrap;
+const zFieldCount = @import("./zHelpers.zig").zFieldCount;
+const toBytes = std.mem.toBytes;
+
+const int_types = @import("../types/int_types.zig");
+const Int32 = int_types.Int32;
+const int4 = int_types.int4;
diff --git a/src/sql/postgres/protocol/TransactionStatusIndicator.zig b/src/sql/postgres/protocol/TransactionStatusIndicator.zig
new file mode 100644
index 0000000000..9650d394f1
--- /dev/null
+++ b/src/sql/postgres/protocol/TransactionStatusIndicator.zig
@@ -0,0 +1,12 @@
+pub const TransactionStatusIndicator = enum(u8) {
+    /// if idle (not in a transaction block)
+    I = 'I',
+
+    /// if in a transaction block
+    T = 'T',
+
+    /// if in a failed transaction block
+    E = 'E',
+
+    _,
+};
diff --git a/src/sql/postgres/protocol/WriteWrap.zig b/src/sql/postgres/protocol/WriteWrap.zig
new file mode 100644
index 0000000000..0fc4470b69
--- /dev/null
+++ b/src/sql/postgres/protocol/WriteWrap.zig
@@ -0,0 +1,14 @@
+pub fn WriteWrap(comptime Container: type, comptime writeFn: anytype) type {
+    return struct {
+        pub fn write(this: *Container, context: anytype) AnyPostgresError!void {
+            const Context = @TypeOf(context);
+            try writeFn(this, Context, NewWriter(Context){ .wrapped = context });
+        }
+    };
+}
+
+// @sortImports
+
+const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError;
+
+const NewWriter = @import("./NewWriter.zig").NewWriter;
diff --git a/src/sql/postgres/protocol/zHelpers.zig b/src/sql/postgres/protocol/zHelpers.zig
new file mode 100644
index 0000000000..9e28f9d0d1
--- /dev/null
+++ b/src/sql/postgres/protocol/zHelpers.zig
@@ -0,0 +1,11 @@
+pub fn zCount(slice: []const u8) usize {
+    return if (slice.len > 0) slice.len + 1 else 0;
+}
+
+pub fn zFieldCount(prefix: []const u8, slice: []const u8) usize {
+    if (slice.len > 0) {
+        return zCount(prefix) + zCount(slice);
+    }
+
+    return zCount(prefix);
+}
diff --git a/src/sql/postgres/types/PostgresString.zig b/src/sql/postgres/types/PostgresString.zig
new file mode 100644
index 0000000000..f2e4cb4292
--- /dev/null
+++ b/src/sql/postgres/types/PostgresString.zig
@@ -0,0 +1,52 @@
+pub const to = 25;
+pub const from = [_]short{1002};
+
+pub fn toJSWithType(
+    globalThis: *JSC.JSGlobalObject,
+    comptime Type: type,
+    value: Type,
+) AnyPostgresError!JSValue {
+    switch (comptime Type) {
+        [:0]u8, []u8, []const u8, [:0]const u8 => {
+            var str = bun.String.fromUTF8(value);
+            defer str.deinit();
+            return str.toJS(globalThis);
+        },
+
+        bun.String => {
+            return value.toJS(globalThis);
+        },
+
+        *Data => {
+            var str = bun.String.fromUTF8(value.slice());
+            defer str.deinit();
+            defer value.deinit();
+            return str.toJS(globalThis);
+        },
+
+        else => {
+            @compileError("unsupported type " ++ @typeName(Type));
+        },
+    }
+}
+
+pub fn toJS(
+    globalThis: *JSC.JSGlobalObject,
+    value: anytype,
+) !JSValue {
+    var str = try toJSWithType(globalThis, @TypeOf(value), value);
+    defer str.deinit();
+    return str.toJS(globalThis);
+}
+
+// @sortImports
+
+const bun = @import("bun");
+const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError;
+const Data = @import("../Data.zig").Data;
+
+const int_types = @import("./int_types.zig");
+const short = int_types.short;
+
+const JSC = bun.JSC;
+const JSValue = JSC.JSValue;
diff --git a/src/sql/postgres/postgres_types.zig b/src/sql/postgres/types/Tag.zig
similarity index 72%
rename from src/sql/postgres/postgres_types.zig
rename to src/sql/postgres/types/Tag.zig
index 4119ad23da..7c77c5a390 100644 --- a/src/sql/postgres/postgres_types.zig +++ b/src/sql/postgres/types/Tag.zig @@ -1,14 +1,3 @@ -const std = @import("std"); -const bun = @import("bun"); -const postgres = bun.api.Postgres; -const Data = postgres.Data; -const String = bun.String; -const JSValue = JSC.JSValue; -const JSC = bun.JSC; -const short = postgres.short; -const int4 = postgres.int4; -const AnyPostgresError = postgres.AnyPostgresError; - // select b.typname, b.oid, b.typarray // from pg_catalog.pg_type a // left join pg_catalog.pg_type b on b.oid = a.typelem @@ -402,153 +391,21 @@ pub const Tag = enum(short) { } }; -pub const string = struct { - pub const to = 25; - pub const from = [_]short{1002}; +const @"bool" = @import("./bool.zig"); - pub fn toJSWithType( - globalThis: *JSC.JSGlobalObject, - comptime Type: type, - value: Type, - ) AnyPostgresError!JSValue { - switch (comptime Type) { - [:0]u8, []u8, []const u8, [:0]const u8 => { - var str = String.fromUTF8(value); - defer str.deinit(); - return str.toJS(globalThis); - }, +// @sortImports - bun.String => { - return value.toJS(globalThis); - }, +const bun = @import("bun"); +const bytea = @import("./bytea.zig"); +const date = @import("./date.zig"); +const json = @import("./json.zig"); +const numeric = @import("./numeric.zig"); +const std = @import("std"); +const string = @import("./PostgresString.zig"); +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; - *Data => { - var str = String.fromUTF8(value.slice()); - defer str.deinit(); - defer value.deinit(); - return str.toJS(globalThis); - }, +const int_types = @import("./int_types.zig"); +const short = int_types.short; - else => { - @compileError("unsupported type " ++ @typeName(Type)); - }, - } - } - - pub fn toJS( - globalThis: *JSC.JSGlobalObject, - value: anytype, - ) !JSValue { - var str = try toJSWithType(globalThis, @TypeOf(value), value); - defer str.deinit(); - return str.toJS(globalThis); - } -}; - -pub const numeric = 
struct { - pub const to = 0; - pub const from = [_]short{ 21, 23, 26, 700, 701 }; - - pub fn toJS( - _: *JSC.JSGlobalObject, - value: anytype, - ) AnyPostgresError!JSValue { - return JSValue.jsNumber(value); - } -}; - -pub const json = struct { - pub const to = 114; - pub const from = [_]short{ 114, 3802 }; - - pub fn toJS( - globalObject: *JSC.JSGlobalObject, - value: *Data, - ) AnyPostgresError!JSValue { - defer value.deinit(); - var str = bun.String.fromUTF8(value.slice()); - defer str.deref(); - const parse_result = JSValue.parse(str.toJS(globalObject), globalObject); - if (parse_result.AnyPostgresError()) { - return globalObject.throwValue(parse_result); - } - - return parse_result; - } -}; - -pub const @"bool" = struct { - pub const to = 16; - pub const from = [_]short{16}; - - pub fn toJS( - _: *JSC.JSGlobalObject, - value: bool, - ) AnyPostgresError!JSValue { - return JSValue.jsBoolean(value); - } -}; - -pub const date = struct { - pub const to = 1184; - pub const from = [_]short{ 1082, 1114, 1184 }; - - // Postgres stores timestamp and timestampz as microseconds since 2000-01-01 - // This is a signed 64-bit integer. 
- const POSTGRES_EPOCH_DATE = 946684800000; - - pub fn fromBinary(bytes: []const u8) f64 { - const microseconds = std.mem.readInt(i64, bytes[0..8], .big); - const double_microseconds: f64 = @floatFromInt(microseconds); - return (double_microseconds / std.time.us_per_ms) + POSTGRES_EPOCH_DATE; - } - - pub fn fromJS(globalObject: *JSC.JSGlobalObject, value: JSValue) i64 { - const double_value = if (value.isDate()) - value.getUnixTimestamp() - else if (value.isNumber()) - value.asNumber() - else if (value.isString()) brk: { - var str = value.toBunString(globalObject) catch @panic("unreachable"); - defer str.deref(); - break :brk str.parseDate(globalObject); - } else return 0; - - const unix_timestamp: i64 = @intFromFloat(double_value); - return (unix_timestamp - POSTGRES_EPOCH_DATE) * std.time.us_per_ms; - } - - pub fn toJS( - globalObject: *JSC.JSGlobalObject, - value: anytype, - ) JSValue { - switch (@TypeOf(value)) { - i64 => { - // Convert from Postgres timestamp (μs since 2000-01-01) to Unix timestamp (ms) - const ms = @divFloor(value, std.time.us_per_ms) + POSTGRES_EPOCH_DATE; - return JSValue.fromDateNumber(globalObject, @floatFromInt(ms)); - }, - *Data => { - defer value.deinit(); - return JSValue.fromDateString(globalObject, value.sliceZ().ptr); - }, - else => @compileError("unsupported type " ++ @typeName(@TypeOf(value))), - } - } -}; - -pub const bytea = struct { - pub const to = 17; - pub const from = [_]short{17}; - - pub fn toJS( - globalObject: *JSC.JSGlobalObject, - value: *Data, - ) AnyPostgresError!JSValue { - defer value.deinit(); - - // var slice = value.slice()[@min(1, value.len)..]; - // _ = slice; - return JSValue.createBuffer(globalObject, value.slice(), null); - } -}; +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/types/bool.zig b/src/sql/postgres/types/bool.zig new file mode 100644 index 0000000000..0a00d07084 --- /dev/null +++ b/src/sql/postgres/types/bool.zig @@ -0,0 +1,20 @@ +pub const to = 16; +pub const 
from = [_]short{16}; + +pub fn toJS( + _: *JSC.JSGlobalObject, + value: bool, +) AnyPostgresError!JSValue { + return JSValue.jsBoolean(value); +} + +// @sortImports + +const bun = @import("bun"); +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; + +const int_types = @import("./int_types.zig"); +const short = int_types.short; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/types/bytea.zig b/src/sql/postgres/types/bytea.zig new file mode 100644 index 0000000000..dec468e524 --- /dev/null +++ b/src/sql/postgres/types/bytea.zig @@ -0,0 +1,25 @@ +pub const to = 17; +pub const from = [_]short{17}; + +pub fn toJS( + globalObject: *JSC.JSGlobalObject, + value: *Data, +) AnyPostgresError!JSValue { + defer value.deinit(); + + // var slice = value.slice()[@min(1, value.len)..]; + // _ = slice; + return JSValue.createBuffer(globalObject, value.slice(), null); +} + +// @sortImports + +const bun = @import("bun"); +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; +const Data = @import("../Data.zig").Data; + +const int_types = @import("./int_types.zig"); +const short = int_types.short; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/types/date.zig b/src/sql/postgres/types/date.zig new file mode 100644 index 0000000000..cdec908240 --- /dev/null +++ b/src/sql/postgres/types/date.zig @@ -0,0 +1,57 @@ +pub const to = 1184; +pub const from = [_]short{ 1082, 1114, 1184 }; + +// Postgres stores timestamp and timestampz as microseconds since 2000-01-01 +// This is a signed 64-bit integer. 
+const POSTGRES_EPOCH_DATE = 946684800000; + +pub fn fromBinary(bytes: []const u8) f64 { + const microseconds = std.mem.readInt(i64, bytes[0..8], .big); + const double_microseconds: f64 = @floatFromInt(microseconds); + return (double_microseconds / std.time.us_per_ms) + POSTGRES_EPOCH_DATE; +} + +pub fn fromJS(globalObject: *JSC.JSGlobalObject, value: JSValue) i64 { + const double_value = if (value.isDate()) + value.getUnixTimestamp() + else if (value.isNumber()) + value.asNumber() + else if (value.isString()) brk: { + var str = value.toBunString(globalObject) catch @panic("unreachable"); + defer str.deref(); + break :brk str.parseDate(globalObject); + } else return 0; + + const unix_timestamp: i64 = @intFromFloat(double_value); + return (unix_timestamp - POSTGRES_EPOCH_DATE) * std.time.us_per_ms; +} + +pub fn toJS( + globalObject: *JSC.JSGlobalObject, + value: anytype, +) JSValue { + switch (@TypeOf(value)) { + i64 => { + // Convert from Postgres timestamp (μs since 2000-01-01) to Unix timestamp (ms) + const ms = @divFloor(value, std.time.us_per_ms) + POSTGRES_EPOCH_DATE; + return JSValue.fromDateNumber(globalObject, @floatFromInt(ms)); + }, + *Data => { + defer value.deinit(); + return JSValue.fromDateString(globalObject, value.sliceZ().ptr); + }, + else => @compileError("unsupported type " ++ @typeName(@TypeOf(value))), + } +} + +// @sortImports + +const bun = @import("bun"); +const std = @import("std"); +const Data = @import("../Data.zig").Data; + +const int_types = @import("./int_types.zig"); +const short = int_types.short; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/types/int_types.zig b/src/sql/postgres/types/int_types.zig new file mode 100644 index 0000000000..5489a5309d --- /dev/null +++ b/src/sql/postgres/types/int_types.zig @@ -0,0 +1,10 @@ +pub const int4 = u32; +pub const PostgresInt32 = int4; +pub const int8 = i64; +pub const PostgresInt64 = int8; +pub const short = u16; +pub const PostgresShort = u16; + +pub fn 
Int32(value: anytype) [4]u8 { + return @bitCast(@byteSwap(@as(int4, @intCast(value)))); +} diff --git a/src/sql/postgres/types/json.zig b/src/sql/postgres/types/json.zig new file mode 100644 index 0000000000..0aaa37c173 --- /dev/null +++ b/src/sql/postgres/types/json.zig @@ -0,0 +1,29 @@ +pub const to = 114; +pub const from = [_]short{ 114, 3802 }; + +pub fn toJS( + globalObject: *JSC.JSGlobalObject, + value: *Data, +) AnyPostgresError!JSValue { + defer value.deinit(); + var str = bun.String.fromUTF8(value.slice()); + defer str.deref(); + const parse_result = JSValue.parse(str.toJS(globalObject), globalObject); + if (parse_result.isAnyError()) { + return globalObject.throwValue(parse_result); + } + + return parse_result; +} + +// @sortImports + +const bun = @import("bun"); +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; +const Data = @import("../Data.zig").Data; + +const int_types = @import("./int_types.zig"); +const short = int_types.short; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue; diff --git a/src/sql/postgres/types/numeric.zig b/src/sql/postgres/types/numeric.zig new file mode 100644 index 0000000000..01897396dc --- /dev/null +++ b/src/sql/postgres/types/numeric.zig @@ -0,0 +1,20 @@ +pub const to = 0; +pub const from = [_]short{ 21, 23, 26, 700, 701 }; + +pub fn toJS( + _: *JSC.JSGlobalObject, + value: anytype, +) AnyPostgresError!JSValue { + return JSValue.jsNumber(value); +} + +// @sortImports + +const bun = @import("bun"); +const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; + +const int_types = @import("./int_types.zig"); +const short = int_types.short; + +const JSC = bun.JSC; +const JSValue = JSC.JSValue;