From 76bfceae81afd87f70163c04daada1849d0b123c Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Mon, 30 Dec 2024 13:25:01 -0800 Subject: [PATCH] Support jsonb, `idle_timeout`, `connection_timeout`, `max_lifetime` timeouts in bun:sql. Add `onopen` and `onclose` callbacks. Fix missing `"code"` property appearing in errors. Add error codes for postgres. (#16045) --- src/bun.js/api/Timer.zig | 11 + src/bun.js/api/postgres.classes.ts | 26 +- src/bun.js/bindings/ErrorCode.cpp | 3 + src/bun.js/bindings/ErrorCode.ts | 32 +- src/bun.js/bindings/JSPropertyIterator.cpp | 30 + src/bun.js/bindings/JSPropertyIterator.zig | 26 +- src/bun.js/bindings/bindings.zig | 4 +- src/bun.js/javascript.zig | 3 +- src/js/bun/sql.ts | 171 ++++- src/sql/postgres.zig | 689 +++++++++++++++------ src/sql/postgres/postgres_protocol.zig | 295 ++++++--- src/sql/postgres/postgres_types.zig | 34 +- src/string.zig | 6 + src/string_builder.zig | 6 + test/js/sql/sql.test.ts | 163 ++++- 15 files changed, 1152 insertions(+), 347 deletions(-) diff --git a/src/bun.js/api/Timer.zig b/src/bun.js/api/Timer.zig index 6f9dcb62e3..ab532acabb 100644 --- a/src/bun.js/api/Timer.zig +++ b/src/bun.js/api/Timer.zig @@ -731,6 +731,8 @@ pub const EventLoopTimer = struct { StatWatcherScheduler, UpgradedDuplex, WindowsNamedPipe, + PostgresSQLConnectionTimeout, + PostgresSQLConnectionMaxLifetime, pub fn Type(comptime T: Tag) type { return switch (T) { @@ -740,6 +742,8 @@ pub const EventLoopTimer = struct { .StatWatcherScheduler => StatWatcherScheduler, .UpgradedDuplex => uws.UpgradedDuplex, .WindowsNamedPipe => uws.WindowsNamedPipe, + .PostgresSQLConnectionTimeout => JSC.Postgres.PostgresSQLConnection, + .PostgresSQLConnectionMaxLifetime => JSC.Postgres.PostgresSQLConnection, }; } } else enum { @@ -748,6 +752,8 @@ pub const EventLoopTimer = struct { TestRunner, StatWatcherScheduler, UpgradedDuplex, + PostgresSQLConnectionTimeout, + PostgresSQLConnectionMaxLifetime, pub fn Type(comptime T: Tag) type { return switch (T) { @@ -756,6 +762,8 @@ pub const EventLoopTimer = struct { .TestRunner => JSC.Jest.TestRunner, .StatWatcherScheduler => StatWatcherScheduler, .UpgradedDuplex => uws.UpgradedDuplex, + .PostgresSQLConnectionTimeout => JSC.Postgres.PostgresSQLConnection, + .PostgresSQLConnectionMaxLifetime => JSC.Postgres.PostgresSQLConnection, }; } }; @@ -808,11 +816,14 @@ pub const EventLoopTimer = struct { pub fn fire(this: *EventLoopTimer, now: *const timespec, vm: *VirtualMachine) Arm { switch (this.tag) { + .PostgresSQLConnectionTimeout => return @as(*JSC.Postgres.PostgresSQLConnection, @alignCast(@fieldParentPtr("timer", this))).onConnectionTimeout(), + .PostgresSQLConnectionMaxLifetime => return @as(*JSC.Postgres.PostgresSQLConnection, @alignCast(@fieldParentPtr("max_lifetime_timer", this))).onMaxLifetimeTimeout(), inline else => |t| { var container: *t.Type() = @alignCast(@fieldParentPtr("event_loop_timer", this)); if (comptime t.Type() == TimerObject) { return container.fire(now, vm); } + if (comptime t.Type() == StatWatcherScheduler) { return container.timerCallback(); } diff --git a/src/bun.js/api/postgres.classes.ts b/src/bun.js/api/postgres.classes.ts index 04097296fc..40664ecb2b 100644 --- a/src/bun.js/api/postgres.classes.ts +++ b/src/bun.js/api/postgres.classes.ts @@ -5,8 +5,8 @@ export default [ name: "PostgresSQLConnection", construct: true, finalize: true, - hasPendingActivity: true, configurable: false, + hasPendingActivity: true, klass: { // escapeString: { // fn: "escapeString", @@ -20,9 +20,6 @@ export default [ close: { fn: "doClose", }, - flush: { - fn: "doFlush", - }, connected: { getter: "getConnected", }, @@ -32,17 +29,30 @@ export default [ unref: { fn: "doUnref", }, - query: { - fn: "createQuery", + + queries: { + getter: "getQueries", + this: true, + }, + onconnect: { + getter: "getOnConnect", + setter: "setOnConnect", + this: true, + }, + onclose: { + getter: "getOnClose", + setter: "setOnClose", + this: true, }, }, + values: ["onconnect", "onclose", "queries"], }), define({ name: "PostgresSQLQuery", construct: true, finalize: true, configurable: false, - hasPendingActivity: true, + JSType: "0b11101110", klass: {}, proto: { @@ -59,7 +69,7 @@ export default [ length: 0, }, }, - values: ["pendingValue", "columns", "binding"], + values: ["pendingValue", "target", "columns", "binding"], estimatedSize: true, }), ]; diff --git a/src/bun.js/bindings/ErrorCode.cpp b/src/bun.js/bindings/ErrorCode.cpp index 1fc96a4da2..8a7d9c2e97 100644 --- a/src/bun.js/bindings/ErrorCode.cpp +++ b/src/bun.js/bindings/ErrorCode.cpp @@ -50,6 +50,9 @@ static JSC::JSObject* createErrorPrototype(JSC::VM& vm, JSC::JSGlobalObject* glo case JSC::ErrorType::URIError: prototype = JSC::constructEmptyObject(globalObject, globalObject->m_URIErrorStructure.prototype(globalObject)); break; + case JSC::ErrorType::SyntaxError: + prototype = JSC::constructEmptyObject(globalObject, globalObject->m_syntaxErrorStructure.prototype(globalObject)); + break; default: { RELEASE_ASSERT_NOT_REACHED_WITH_MESSAGE("TODO: Add support for more error types"); break; diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index bb72e160d2..14e93b2c85 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -98,7 +98,37 @@ export default [ ["ERR_ASYNC_CALLBACK", TypeError], // Postgres - ["ERR_POSTGRES_ERROR_RESPONSE", Error, "PostgresError"], + ["ERR_POSTGRES_AUTHENTICATION_FAILED_PBKDF2", Error, "PostgresError"], + ["ERR_POSTGRES_SERVER_ERROR", Error, "PostgresError"], + ["ERR_POSTGRES_SYNTAX_ERROR", SyntaxError, "PostgresError"], + ["ERR_POSTGRES_CONNECTION_CLOSED", Error, "PostgresError"], + ["ERR_POSTGRES_EXPECTED_REQUEST", Error, "PostgresError"], + ["ERR_POSTGRES_EXPECTED_STATEMENT", Error, "PostgresError"], + ["ERR_POSTGRES_INVALID_BACKEND_KEY_DATA", TypeError, "PostgresError"], + ["ERR_POSTGRES_INVALID_BINARY_DATA", TypeError, "PostgresError"], + ["ERR_POSTGRES_INVALID_BYTE_SEQUENCE_FOR_ENCODING", TypeError, "PostgresError"], + ["ERR_POSTGRES_INVALID_BYTE_SEQUENCE", TypeError, "PostgresError"], + ["ERR_POSTGRES_INVALID_CHARACTER", TypeError, "PostgresError"], + ["ERR_POSTGRES_INVALID_MESSAGE_LENGTH", Error, "PostgresError"], + ["ERR_POSTGRES_INVALID_MESSAGE", Error, "PostgresError"], + ["ERR_POSTGRES_INVALID_QUERY_BINDING", Error, "PostgresError"], + ["ERR_POSTGRES_INVALID_SERVER_KEY", Error, "PostgresError"], + ["ERR_POSTGRES_INVALID_SERVER_SIGNATURE", Error, "PostgresError"], + ["ERR_POSTGRES_MULTIDIMENSIONAL_ARRAY_NOT_SUPPORTED_YET", Error, "PostgresError"], + ["ERR_POSTGRES_NULLS_IN_ARRAY_NOT_SUPPORTED_YET", Error, "PostgresError"], + ["ERR_POSTGRES_OVERFLOW", TypeError, "PostgresError"], + ["ERR_POSTGRES_SASL_SIGNATURE_INVALID_BASE64", Error, "PostgresError"], + ["ERR_POSTGRES_SASL_SIGNATURE_MISMATCH", Error, "PostgresError"], + ["ERR_POSTGRES_TLS_NOT_AVAILABLE", Error, "PostgresError"], + ["ERR_POSTGRES_TLS_UPGRADE_FAILED", Error, "PostgresError"], + ["ERR_POSTGRES_UNEXPECTED_MESSAGE", Error, "PostgresError"], + ["ERR_POSTGRES_UNKNOWN_AUTHENTICATION_METHOD", Error, "PostgresError"], + ["ERR_POSTGRES_UNSUPPORTED_AUTHENTICATION_METHOD", Error, "PostgresError"], + ["ERR_POSTGRES_UNSUPPORTED_BYTEA_FORMAT", TypeError, "PostgresError"], + ["ERR_POSTGRES_UNSUPPORTED_INTEGER_SIZE", TypeError, "PostgresError"], + ["ERR_POSTGRES_IDLE_TIMEOUT", Error, "PostgresError"], + ["ERR_POSTGRES_CONNECTION_TIMEOUT", Error, "PostgresError"], + ["ERR_POSTGRES_LIFETIME_TIMEOUT", Error, "PostgresError"], // AWS ["ERR_AWS_MISSING_CREDENTIALS", Error], diff --git a/src/bun.js/bindings/JSPropertyIterator.cpp b/src/bun.js/bindings/JSPropertyIterator.cpp index 8d4657f469..89e64a8ac8 100644 --- a/src/bun.js/bindings/JSPropertyIterator.cpp +++ b/src/bun.js/bindings/JSPropertyIterator.cpp @@ -90,6 +90,36 @@ extern "C" JSPropertyIterator* Bun__JSPropertyIterator__create(JSC::JSGlobalObje return JSPropertyIterator::create(vm, array.releaseData()); } +// The only non-own property that we sometimes want to get is the code property. +extern "C" EncodedJSValue Bun__JSPropertyIterator__getCodeProperty(JSPropertyIterator* iter, JSC::JSGlobalObject* globalObject, JSC::JSObject* object) +{ + if (UNLIKELY(!iter)) { + return {}; + } + + auto& vm = iter->vm; + auto scope = DECLARE_THROW_SCOPE(vm); + RETURN_IF_EXCEPTION(scope, {}); + if (UNLIKELY(object->type() == JSC::ProxyObjectType)) { + return {}; + } + + auto& builtinNames = WebCore::builtinNames(vm); + + PropertySlot slot(object, PropertySlot::InternalMethodType::VMInquiry, vm.ptr()); + if (!object->getNonIndexPropertySlot(globalObject, builtinNames.codePublicName(), slot)) { + return {}; + } + + if (slot.isAccessor() || slot.isCustom()) { + return {}; + } + + RETURN_IF_EXCEPTION(scope, {}); + + return JSValue::encode(slot.getPureResult()); +} + extern "C" size_t Bun__JSPropertyIterator__getLongestPropertyName(JSPropertyIterator* iter, JSC::JSGlobalObject* globalObject, JSC::JSObject* object) { size_t longest = 0; diff --git a/src/bun.js/bindings/JSPropertyIterator.zig b/src/bun.js/bindings/JSPropertyIterator.zig index ff4a4e77f6..bd55ea078c 100644 --- a/src/bun.js/bindings/JSPropertyIterator.zig +++ b/src/bun.js/bindings/JSPropertyIterator.zig @@ -8,7 +8,7 @@ extern "C" fn Bun__JSPropertyIterator__getNameAndValueNonObservable(iter: ?*anyo extern "C" fn Bun__JSPropertyIterator__getName(iter: ?*anyopaque, propertyName: *bun.String, i: usize) void; extern "C" fn Bun__JSPropertyIterator__deinit(iter: ?*anyopaque) void; extern "C" fn Bun__JSPropertyIterator__getLongestPropertyName(iter: ?*anyopaque, globalObject: *JSC.JSGlobalObject, object: *anyopaque) usize; - +extern "C" fn Bun__JSPropertyIterator__getCodeProperty(iter: ?*anyopaque, globalObject: *JSC.JSGlobalObject, object: *anyopaque) JSC.JSValue; pub const JSPropertyIteratorOptions = struct { skip_empty_name: bool, include_value: bool, @@ -27,6 +27,7 @@ pub fn JSPropertyIterator(comptime options: JSPropertyIteratorOptions) type { globalObject: *JSC.JSGlobalObject, object: *JSC.JSCell = undefined, value: JSC.JSValue = .zero, + tried_code_property: bool = false, pub fn getLongestPropertyName(this: *@This()) usize { if (this.impl == null) return 0; @@ -53,6 +54,7 @@ pub fn JSPropertyIterator(comptime options: JSPropertyIteratorOptions) type { pub fn reset(this: *@This()) void { this.iter_i = 0; this.i = 0; + this.tried_code_property = false; } /// The bun.String returned has not incremented it's reference count. @@ -90,5 +92,27 @@ pub fn JSPropertyIterator(comptime options: JSPropertyIteratorOptions) type { return name; } + + /// "code" is not always an own property, and we want to get it without risking exceptions. + pub fn getCodeProperty(this: *@This()) ?bun.String { + if (comptime !options.include_value) { + @compileError("TODO"); + } + + if (this.tried_code_property) { + return null; + } + + this.tried_code_property = true; + + const current = Bun__JSPropertyIterator__getCodeProperty(this.impl, this.globalObject, this.object); + if (current == .zero) { + return null; + } + current.ensureStillAlive(); + this.value = current; + + return bun.String.static("code"); + } }; } diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig index b143919289..9786b87e46 100644 --- a/src/bun.js/bindings/bindings.zig +++ b/src/bun.js/bindings/bindings.zig @@ -6739,10 +6739,10 @@ pub const CallFrame = opaque { /// arguments(n).mut() -> `var args = argumentsAsArray(n); &args` pub fn arguments_old(self: *const CallFrame, comptime max: usize) Arguments(max) { const slice = self.arguments(); - comptime bun.assert(max <= 10); + comptime bun.assert(max <= 13); return switch (@as(u4, @min(slice.len, max))) { 0 => .{ .ptr = undefined, .len = 0 }, - inline 1...10 => |count| Arguments(max).init(comptime @min(count, max), slice.ptr), + inline 1...13 => |count| Arguments(max).init(comptime @min(count, max), slice.ptr), else => unreachable, }; } diff --git a/src/bun.js/javascript.zig b/src/bun.js/javascript.zig index 6c5274dd0a..3444695992 100644 --- a/src/bun.js/javascript.zig +++ b/src/bun.js/javascript.zig @@ -3956,7 +3956,7 @@ pub const VirtualMachine = struct { defer iterator.deinit(); const longest_name = @min(iterator.getLongestPropertyName(), 10); var is_first_property = true; - while (iterator.next()) |field| { + while (iterator.next() orelse iterator.getCodeProperty()) |field| { const value = iterator.value; if (field.eqlComptime("message") or field.eqlComptime("name") or field.eqlComptime("stack")) { continue; @@ -3966,6 +3966,7 @@ pub const VirtualMachine = struct { if (field.eqlComptime("code")) { if (value.isString()) { const str = value.toBunString(this.global); + defer str.deref(); if (!str.isEmpty()) { if (str.eql(name)) { continue; diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index f4f29f431f..8ce069cb72 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -75,6 +75,15 @@ class Query extends PublicPromise { [_handler]; [_queryStatus] = 0; + [Symbol.for("nodejs.util.inspect.custom")]() { + const status = this[_queryStatus]; + const active = (status & QueryStatus.active) != 0; + const cancelled = (status & QueryStatus.cancelled) != 0; + const executed = (status & QueryStatus.executed) != 0; + const error = (status & QueryStatus.error) != 0; + return `PostgresQuery { ${active ? "active" : ""} ${cancelled ? "cancelled" : ""} ${executed ? "executed" : ""} ${error ? "error" : ""} }`; + } + constructor(handle, handler) { var resolve_, reject_; super((resolve, reject) => { @@ -182,7 +191,7 @@ class Query extends PublicPromise { Object.defineProperty(Query, Symbol.species, { value: PublicPromise }); Object.defineProperty(Query, Symbol.toStringTag, { value: "Query" }); init( - function (query, result, commandTag, count) { + function onResolvePostgresQuery(query, result, commandTag, count, queries) { $assert(result instanceof SQLResultArray, "Invalid result array"); if (typeof commandTag === "string") { if (commandTag.length > 0) { @@ -194,18 +203,48 @@ init( result.count = count || 0; + if (queries) { + const queriesIndex = queries.indexOf(query); + if (queriesIndex !== -1) { + queries.splice(queriesIndex, 1); + } + } + try { query.resolve(result); } catch (e) {} }, - function (query, reject) { + function onRejectPostgresQuery(query, reject, queries) { + if (queries) { + const queriesIndex = queries.indexOf(query); + if (queriesIndex !== -1) { + queries.splice(queriesIndex, 1); + } + } + try { query.reject(reject); } catch (e) {} }, ); -function createConnection({ hostname, port, username, password, tls, query, database, sslMode }, onConnected, onClose) { +function createConnection( + { + hostname, + port, + username, + password, + tls, + query, + database, + sslMode, + idleTimeout = 0, + connectionTimeout = 30 * 1000, + maxLifetime = 0, + }, + onConnected, + onClose, +) { return _createConnection( hostname, Number(port), @@ -221,6 +260,9 @@ function createConnection({ hostname, port, username, password, tls, query, data query || "", onConnected, onClose, + idleTimeout, + connectionTimeout, + maxLifetime, ); } @@ -312,7 +354,20 @@ class SQLArrayParameter { } function loadOptions(o) { - var hostname, port, username, password, database, tls, url, query, adapter; + var hostname, + port, + username, + password, + database, + tls, + url, + query, + adapter, + idleTimeout, + connectionTimeout, + maxLifetime, + onconnect, + onclose; const env = Bun.env; var sslMode: SSLMode = SSLMode.disable; @@ -375,6 +430,48 @@ function loadOptions(o) { tls ||= o.tls || o.ssl; adapter ||= o.adapter || "postgres"; + idleTimeout ??= o.idleTimeout; + idleTimeout ??= o.idle_timeout; + connectionTimeout ??= o.connectionTimeout; + connectionTimeout ??= o.connection_timeout; + maxLifetime ??= o.maxLifetime; + maxLifetime ??= o.max_lifetime; + + onconnect ??= o.onconnect; + onclose ??= o.onclose; + if (onconnect !== undefined) { + if (!$isCallable(onconnect)) { + throw $ERR_INVALID_ARG_TYPE("onconnect", "function", onconnect); + } + } + + if (onclose !== undefined) { + if (!$isCallable(onclose)) { + throw $ERR_INVALID_ARG_TYPE("onclose", "function", onclose); + } + } + + if (idleTimeout != null) { + idleTimeout = Number(idleTimeout); + if (idleTimeout > 2 ** 31 || idleTimeout < 0 || idleTimeout !== idleTimeout) { + throw $ERR_INVALID_ARG_VALUE("idle_timeout must be a non-negative integer less than 2^31"); + } + } + + if (connectionTimeout != null) { + connectionTimeout = Number(connectionTimeout); + if (connectionTimeout > 2 ** 31 || connectionTimeout < 0 || connectionTimeout !== connectionTimeout) { + throw $ERR_INVALID_ARG_VALUE("connection_timeout must be a non-negative integer less than 2^31"); + } + } + + if (maxLifetime != null) { + maxLifetime = Number(maxLifetime); + if (maxLifetime > 2 ** 31 || maxLifetime < 0 || maxLifetime !== maxLifetime) { + throw $ERR_INVALID_ARG_VALUE("max_lifetime must be a non-negative integer less than 2^31"); + } + } + if (sslMode !== SSLMode.disable && !tls?.serverName) { if (hostname) { tls = { @@ -398,7 +495,23 @@ function loadOptions(o) { throw new Error(`Unsupported adapter: ${adapter}. Only \"postgres\" is supported for now`); } - return { hostname, port, username, password, database, tls, query, sslMode }; + const ret: any = { hostname, port, username, password, database, tls, query, sslMode }; + if (idleTimeout != null) { + ret.idleTimeout = idleTimeout; + } + if (connectionTimeout != null) { + ret.connectionTimeout = connectionTimeout; + } + if (maxLifetime != null) { + ret.maxLifetime = maxLifetime; + } + if (onconnect !== undefined) { + ret.onconnect = onconnect; + } + if (onclose !== undefined) { + ret.onclose = onclose; + } + return ret; } function SQL(o) { @@ -407,6 +520,7 @@ function SQL(o) { connecting = false, closed = false, onConnect: any[] = [], + storedErrorForClosedConnection, connectionInfo = loadOptions(o); function connectedHandler(query, handle, err) { @@ -415,7 +529,7 @@ function SQL(o) { } if (!connected) { - return query.reject(new Error("Not connected")); + return query.reject(storedErrorForClosedConnection || new Error("Not connected")); } if (query.cancelled) { @@ -423,6 +537,10 @@ function SQL(o) { } handle.run(connection, query); + + // if the above throws, we don't want it to be in the array. + // This array exists mostly to keep the in-flight queries alive. + connection.queries.push(query); } function pendingConnectionHandler(query, handle) { @@ -434,7 +552,7 @@ function SQL(o) { } function closedConnectionHandler(query, handle) { - query.reject(new Error("Connection closed")); + query.reject(storedErrorForClosedConnection || new Error("Connection closed")); } function onConnected(err, result) { @@ -443,11 +561,31 @@ function SQL(o) { handler(err); } onConnect = []; + + if (connected && connectionInfo?.onconnect) { + connectionInfo.onconnect(err); + } } - function onClose(err) { + function onClose(err, queries) { closed = true; + storedErrorForClosedConnection = err; + if (sql === lazyDefaultSQL) { + resetDefaultSQL(initialDefaultSQL); + } + onConnected(err, undefined); + if (queries) { + const queriesCopy = queries.slice(); + queries.length = 0; + for (const handler of queriesCopy) { + handler.reject(err); + } + } + + if (connectionInfo?.onclose) { + connectionInfo.onclose(err); + } } function doCreateQuery(strings, values) { @@ -568,18 +706,23 @@ function SQL(o) { } var lazyDefaultSQL; -var defaultSQLObject = function sql(strings, ...values) { + +function resetDefaultSQL(sql) { + lazyDefaultSQL = sql; + Object.assign(defaultSQLObject, lazyDefaultSQL); + exportsObject.default = exportsObject.sql = lazyDefaultSQL; +} + +var initialDefaultSQL; +var defaultSQLObject = (initialDefaultSQL = function sql(strings, ...values) { if (new.target) { return SQL(strings); } - if (!lazyDefaultSQL) { - lazyDefaultSQL = SQL(undefined); - Object.assign(defaultSQLObject, lazyDefaultSQL); - exportsObject.default = exportsObject.sql = lazyDefaultSQL; + resetDefaultSQL(SQL(undefined)); } return lazyDefaultSQL(strings, ...values); -}; +}); var exportsObject = { sql: defaultSQLObject, diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 026706100e..296ce7ee1c 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -13,6 +13,37 @@ pub const PostgresShort = u16; const Crypto = JSC.API.Bun.Crypto; const JSValue = JSC.JSValue; const BoringSSL = @import("../boringssl.zig"); +pub const AnyPostgresError = error{ + ConnectionClosed, + ExpectedRequest, + ExpectedStatement, + InvalidBackendKeyData, + InvalidBinaryData, + InvalidByteSequence, + InvalidByteSequenceForEncoding, + InvalidCharacter, + InvalidMessage, + InvalidMessageLength, + InvalidQueryBinding, + InvalidServerKey, + InvalidServerSignature, + JSError, + MultidimensionalArrayNotSupportedYet, + NullsInArrayNotSupportedYet, + OutOfMemory, + Overflow, + PBKDFD2, + SASL_SIGNATURE_MISMATCH, + SASL_SIGNATURE_INVALID_BASE64, + ShortRead, + TLSNotAvailable, + TLSUpgradeFailed, + UnexpectedMessage, + UNKNOWN_AUTHENTICATION_METHOD, + UNSUPPORTED_AUTHENTICATION_METHOD, + UnsupportedByteaFormat, + UnsupportedIntegerSize, +}; pub const SSLMode = enum(u8) { disable = 0, @@ -176,16 +207,23 @@ pub const PostgresSQLQuery = struct { statement: ?*PostgresSQLStatement = null, query: bun.String = bun.String.empty, cursor_name: bun.String = bun.String.empty, + + // Kept alive by being in the "queries" array from JS. thisValue: JSValue = .undefined, - target: JSC.Strong = JSC.Strong.init(), + status: Status = Status.pending, is_done: bool = false, ref_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(1), binary: bool = false, - pending_value: JSC.Strong = .{}, pub usingnamespace JSC.Codegen.JSPostgresSQLQuery; + pub fn getTarget(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject) JSC.JSValue { + const target = PostgresSQLQuery.targetGetCached(this.thisValue) orelse return .zero; + PostgresSQLQuery.targetSetCached(this.thisValue, globalObject, .zero); + return target; + } + pub const Status = enum(u8) { pending, written, @@ -209,9 +247,6 @@ pub const PostgresSQLQuery = struct { } this.query.deref(); this.cursor_name.deref(); - this.target.deinit(); - this.pending_value.deinit(); - bun.default_allocator.destroy(this); } @@ -233,12 +268,12 @@ pub const PostgresSQLQuery = struct { bun.assert(this.ref_count.fetchAdd(1, .monotonic) > 0); } - pub fn onNoData(this: *@This(), globalObject: *JSC.JSGlobalObject) void { + pub fn onNoData(this: *@This(), globalObject: *JSC.JSGlobalObject, queries_array: JSValue) void { this.status = .success; defer this.deref(); const thisValue = this.thisValue; - const targetValue = this.target.trySwap() orelse JSValue.zero; + const targetValue = this.getTarget(globalObject); if (thisValue == .zero or targetValue == .zero) { return; } @@ -251,13 +286,18 @@ pub const PostgresSQLQuery = struct { this.pending_value.trySwap() orelse .undefined, JSValue.jsNumber(0), JSValue.jsNumber(0), + queries_array, }); } - pub fn onWriteFail(this: *@This(), err: anyerror, globalObject: *JSC.JSGlobalObject) void { + pub fn onWriteFail( + this: *@This(), + err: AnyPostgresError, + globalObject: *JSC.JSGlobalObject, + queries_array: JSValue, + ) void { this.status = .fail; - this.pending_value.deinit(); const thisValue = this.thisValue; - const targetValue = this.target.trySwap() orelse JSValue.zero; + const targetValue = this.getTarget(globalObject); if (thisValue == .zero or targetValue == .zero) { return; } @@ -271,6 +311,7 @@ pub const PostgresSQLQuery = struct { event_loop.runCallback(function, globalObject, thisValue, &.{ targetValue, instance, + queries_array, }); } @@ -279,7 +320,7 @@ pub const PostgresSQLQuery = struct { defer this.deref(); const thisValue = this.thisValue; - const targetValue = this.target.trySwap() orelse JSValue.zero; + const targetValue = this.getTarget(globalObject); if (thisValue == .zero or targetValue == .zero) { return; } @@ -388,14 +429,31 @@ pub const PostgresSQLQuery = struct { } }; - pub fn onSuccess(this: *@This(), command_tag_str: []const u8, globalObject: *JSC.JSGlobalObject) void { + pub fn allowGC(thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) void { + if (thisValue == .zero) { + return; + } + + defer thisValue.ensureStillAlive(); + PostgresSQLQuery.bindingSetCached(thisValue, globalObject, .zero); + PostgresSQLQuery.pendingValueSetCached(thisValue, globalObject, .zero); + PostgresSQLQuery.targetSetCached(thisValue, globalObject, .zero); + } + + fn consumePendingValue(thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) ?JSValue { + const pending_value = PostgresSQLQuery.pendingValueGetCached(thisValue) orelse return null; + PostgresSQLQuery.pendingValueSetCached(thisValue, globalObject, .zero); + return pending_value; + } + + pub fn onSuccess(this: *@This(), command_tag_str: []const u8, globalObject: *JSC.JSGlobalObject, connection: JSC.JSValue) void { this.status = .success; defer this.deref(); const thisValue = this.thisValue; - const targetValue = this.target.trySwap() orelse JSValue.zero; + const targetValue = this.getTarget(globalObject); + defer allowGC(thisValue, globalObject); if (thisValue == .zero or targetValue == .zero) { - this.pending_value.deinit(); return; } @@ -407,9 +465,10 @@ pub const PostgresSQLQuery = struct { event_loop.runCallback(function, globalObject, thisValue, &.{ targetValue, - this.pending_value.trySwap() orelse .undefined, + consumePendingValue(thisValue, globalObject) orelse .undefined, tag.toJSTag(globalObject), tag.toJSNumber(), + PostgresSQLConnection.queriesGetCached(connection) orelse .undefined, }); } @@ -458,7 +517,6 @@ pub const PostgresSQLQuery = struct { if (columns != .undefined) { PostgresSQLQuery.columnsSetCached(this_value, globalThis, columns); } - ptr.pending_value.set(globalThis, pending_value); return this_value; } @@ -477,7 +535,7 @@ pub const PostgresSQLQuery = struct { pub fn doRun(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { var arguments_ = callframe.arguments_old(2); const arguments = arguments_.slice(); - var connection = arguments[0].as(PostgresSQLConnection) orelse { + const connection: *PostgresSQLConnection = arguments[0].as(PostgresSQLConnection) orelse { return globalObject.throw("connection must be a PostgresSQLConnection", .{}); }; var query = arguments[1]; @@ -486,11 +544,11 @@ pub const PostgresSQLQuery = struct { return globalObject.throwInvalidArgumentType("run", "query", "Query"); } - this.target.set(globalObject, query); - const binding_value = PostgresSQLQuery.bindingGetCached(callframe.this()) orelse .zero; + const this_value = callframe.this(); + const binding_value = PostgresSQLQuery.bindingGetCached(this_value) orelse .zero; var query_str = this.query.toUTF8(bun.default_allocator); defer query_str.deinit(); - const columns_value = PostgresSQLQuery.columnsGetCached(callframe.this()) orelse .undefined; + const columns_value = PostgresSQLQuery.columnsGetCached(this_value) orelse .undefined; var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value) catch |err| { if (!globalObject.hasException()) @@ -568,9 +626,12 @@ pub const PostgresSQLQuery = struct { connection.requests.writeItem(this) catch {}; this.ref(); this.status = if (did_write) .binding else .pending; + PostgresSQLQuery.targetSetCached(this_value, globalObject, query); if (connection.is_ready_for_query) - connection.flushData(); + connection.flushDataAndResetTimeout() + else if (did_write) + connection.resetConnectionTimeout(); return .undefined; } @@ -665,8 +726,10 @@ pub const PostgresRequest = struct { try writer.int4(@bitCast(@as(i32, -1))); continue; } + if (comptime bun.Environment.enable_logs) { + debug(" -> {s}", .{tag.name() orelse "(unknown)"}); + } - debug(" -> {s}", .{@tagName(tag)}); switch ( // If they pass a value as a string, let's avoid attempting to // convert it to the binary representation. This minimizes the room @@ -674,7 +737,7 @@ pub const PostgresRequest = struct { // differently than what Postgres does when given a timestamp with // timezone. if (tag.isBinaryFormatSupported() and value.isString()) .text else tag) { - .json => { + .jsonb, .json => { var str = bun.String.empty; defer str.deref(); value.jsonStringify(globalObject, 0, &str); @@ -761,7 +824,7 @@ pub const PostgresRequest = struct { params: []const int4, comptime Context: type, writer: protocol.NewWriter(Context), - ) !void { + ) AnyPostgresError!void { { var q = protocol.Parse{ .name = name, @@ -790,7 +853,7 @@ pub const PostgresRequest = struct { comptime Context: type, writer: protocol.NewWriter(Context), signature: *Signature, - ) !void { + ) AnyPostgresError!void { try writeQuery(query, signature.name, signature.fields, Context, writer); try writeBind(signature.name, bun.String.empty, globalObject, array_value, .zero, &.{}, &.{}, Context, writer); var exec = protocol.Execute{ @@ -863,7 +926,7 @@ pub const PostgresRequest = struct { connection.tls_status = .ssl_not_available; debug("Server does not support SSL", .{}); if (connection.ssl_mode == .require) { - connection.fail("Server does not support SSL", error.SSLNotAvailable); + connection.fail("Server does not support SSL", error.TLSNotAvailable); return; } continue; @@ -912,9 +975,6 @@ pub const PostgresSQLConnection = struct { pending_disconnect: bool = false, - on_connect: JSC.Strong = .{}, - on_close: JSC.Strong = .{}, - database: []const u8 = "", user: []const u8 = "", password: []const u8 = "", @@ -928,6 +988,31 @@ pub const PostgresSQLConnection = struct { tls_status: TLSStatus = .none, ssl_mode: SSLMode = .disable, + idle_timeout_interval_ms: u32 = 0, + connection_timeout_ms: u32 = 0, + + /// Before being connected, this is a connection timeout timer. + /// After being connected, this is an idle timeout timer. + timer: JSC.BunTimer.EventLoopTimer = .{ + .tag = .PostgresSQLConnectionTimeout, + .next = .{ + .sec = 0, + .nsec = 0, + }, + }, + + /// This timer controls the maximum lifetime of a connection. + /// It starts when the connection successfully starts (i.e. after handshake is complete). + /// It stops when the connection is closed. + max_lifetime_interval_ms: u32 = 0, + max_lifetime_timer: JSC.BunTimer.EventLoopTimer = .{ + .tag = .PostgresSQLConnectionMaxLifetime, + .next = .{ + .sec = 0, + .nsec = 0, + }, + }, + pub const TLSStatus = union(enum) { none, pending, @@ -942,100 +1027,107 @@ pub const PostgresSQLConnection = struct { pub const AuthenticationState = union(enum) { pending: void, - SASL: SASL, + none: void, ok: void, + SASL: SASL, + md5: void, pub fn zero(this: *AuthenticationState) void { - const bytes = std.mem.asBytes(this); - @memset(bytes, 0); + switch (this.*) { + .SASL => |*sasl| { + sasl.deinit(); + }, + else => {}, + } + this.* = .{ .none = {} }; + } + }; + + pub const SASL = struct { + const nonce_byte_len = 18; + const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len); + + const server_signature_byte_len = 32; + const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len); + + const salted_password_byte_len = 32; + + nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len, + nonce_len: u8 = 0, + + server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len, + server_signature_len: u8 = 0, + + salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len, + salted_password_created: bool = false, + + status: SASLStatus = .init, + + pub const SASLStatus = enum { + init, + @"continue", + }; + + fn hmac(password: []const u8, data: []const u8) ?[32]u8 { + var buf = std.mem.zeroes([bun.BoringSSL.EVP_MAX_MD_SIZE]u8); + + // TODO: I don't think this is failable. + const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null; + + assert(result.len == 32); + return buf[0..32].*; } - pub const SASL = struct { - const nonce_byte_len = 18; - const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len); - - const server_signature_byte_len = 32; - const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len); - - const salted_password_byte_len = 32; - - nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len, - nonce_len: u8 = 0, - - server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len, - server_signature_len: u8 = 0, - - salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len, - salted_password_created: bool = false, - - status: SASLStatus = .init, - - pub const SASLStatus = enum { - init, - @"continue", - }; - - fn hmac(password: []const u8, data: []const u8) ?[32]u8 { - var buf = std.mem.zeroes([bun.BoringSSL.EVP_MAX_MD_SIZE]u8); - - // TODO: I don't think this is failable. - const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null; - - assert(result.len == 32); - return buf[0..32].*; + pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void { + this.salted_password_created = true; + if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) { + return error.PBKDFD2; } + } - pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void { - this.salted_password_created = true; - if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) { - return error.PBKDF2Failed; - } + pub fn saltedPassword(this: *const SASL) []const u8 { + assert(this.salted_password_created); + return this.salted_password_bytes[0..salted_password_byte_len]; + } + + pub fn serverSignature(this: *const SASL) []const u8 { + assert(this.server_signature_len > 0); + return this.server_signature_base64_bytes[0..this.server_signature_len]; + } + + pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void { + assert(this.server_signature_len == 0); + + const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey; + const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature; + this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes)); + } + + pub fn clientKey(this: *const SASL) [32]u8 { + return hmac(this.saltedPassword(), "Client Key").?; + } + + pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 { + var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest); + bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine()); + return hmac(&sha_digest, auth_string).?; + } + + pub fn nonce(this: *SASL) []const u8 { + if (this.nonce_len == 0) { + var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len; + bun.rand(&bytes); + this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes)); } + return this.nonce_base64_bytes[0..this.nonce_len]; + } - pub fn saltedPassword(this: *const SASL) []const u8 { - assert(this.salted_password_created); - return this.salted_password_bytes[0..salted_password_byte_len]; - } - - pub fn serverSignature(this: *const SASL) []const u8 { - assert(this.server_signature_len > 0); - return this.server_signature_base64_bytes[0..this.server_signature_len]; - } - - pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void { - assert(this.server_signature_len == 0); - - const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey; - const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature; - this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes)); - } - - pub fn clientKey(this: *const SASL) [32]u8 { - return hmac(this.saltedPassword(), "Client Key").?; - } - - pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 { - var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest); - bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine()); - return hmac(&sha_digest, auth_string).?; - } - - pub fn nonce(this: *SASL) []const u8 { - if (this.nonce_len == 0) { - var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len; - bun.rand(&bytes); - this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes)); - } - return this.nonce_base64_bytes[0..this.nonce_len]; - } - - pub fn deinit(this: *SASL) void { - this.nonce_len = 0; - this.salted_password_created = false; - this.server_signature_len = 0; - this.status = .init; - } - }; + pub fn deinit(this: *SASL) void { + this.nonce_len = 0; + this.salted_password_created = false; + this.server_signature_len = 0; + this.status = .init; + } }; pub const Status = enum { @@ -1050,6 +1142,64 @@ pub const PostgresSQLConnection = struct { pub usingnamespace JSC.Codegen.JSPostgresSQLConnection; + fn getTimeoutInterval(this: *const PostgresSQLConnection) u32 { + return switch (this.status) { + .connected => this.idle_timeout_interval_ms, + .failed => 0, + else => this.connection_timeout_ms, + }; + } + + pub fn resetConnectionTimeout(this: *PostgresSQLConnection) void { + const interval = this.getTimeoutInterval(); + if (this.timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.timer); + } + if (interval == 0) { + return; + } + + this.timer.next = bun.timespec.msFromNow(@intCast(interval)); + this.globalObject.bunVM().timer.insert(&this.timer); + } + + pub fn getQueries(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject) JSC.JSValue { + if (PostgresSQLConnection.queriesGetCached(thisValue)) |value| { + return value; + } + + const array = JSC.JSValue.createEmptyArray(globalObject, 0); + PostgresSQLConnection.queriesSetCached(thisValue, globalObject, array); + + return array; + } + + pub fn getOnConnect(_: *PostgresSQLConnection, thisValue: JSC.JSValue, _: *JSC.JSGlobalObject) JSC.JSValue { + if (PostgresSQLConnection.onconnectGetCached(thisValue)) |value| { + return value; + } + + return .undefined; + } + + pub fn setOnConnect(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) bool { + PostgresSQLConnection.onconnectSetCached(thisValue, globalObject, value); + return true; + } + + pub fn getOnClose(_: *PostgresSQLConnection, thisValue: JSC.JSValue, _: *JSC.JSGlobalObject) JSC.JSValue { + if (PostgresSQLConnection.oncloseGetCached(thisValue)) |value| { + return value; + } + + return .undefined; + } + + pub fn setOnClose(_: *PostgresSQLConnection, thisValue: JSC.JSValue, globalObject: *JSC.JSGlobalObject, value: JSC.JSValue) bool { + PostgresSQLConnection.oncloseSetCached(thisValue, globalObject, value); + return true; + } + pub fn setupTLS(this: *PostgresSQLConnection) void { debug("setupTLS", .{}); const new_socket = uws.us_socket_upgrade_to_tls(this.socket.SocketTCP.socket.connected, this.tls_ctx.?, this.tls_config.server_name) orelse { @@ -1066,8 +1216,47 @@ pub const PostgresSQLConnection = struct { this.start(); } + fn setupMaxLifetimeTimerIfNecessary(this: *PostgresSQLConnection) void { + if (this.max_lifetime_interval_ms == 0) return; + if (this.max_lifetime_timer.state == .ACTIVE) return; + + this.max_lifetime_timer.next = bun.timespec.msFromNow(@intCast(this.max_lifetime_interval_ms)); + this.globalObject.bunVM().timer.insert(&this.max_lifetime_timer); + } + + pub fn onConnectionTimeout(this: *PostgresSQLConnection) JSC.BunTimer.EventLoopTimer.Arm { + debug("onConnectionTimeout", .{}); + this.timer.state = .FIRED; + if (this.getTimeoutInterval() == 0) { + this.resetConnectionTimeout(); + return .disarm; + } + + switch (this.status) { + .connected => { + this.failFmt(.ERR_POSTGRES_IDLE_TIMEOUT, "Idle timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.idle_timeout_interval_ms) *| std.time.ns_per_ms)}); + }, + else => { + this.failFmt(.ERR_POSTGRES_CONNECTION_TIMEOUT, "Connection timeout after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)}); + }, + .sent_startup_message => { + this.failFmt(.ERR_POSTGRES_CONNECTION_TIMEOUT, "Connection timed out after {} (sent startup message, but never received response)", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)}); + }, + } + return .disarm; + } + + pub fn onMaxLifetimeTimeout(this: *PostgresSQLConnection) JSC.BunTimer.EventLoopTimer.Arm { + debug("onMaxLifetimeTimeout", .{}); + this.max_lifetime_timer.state = .FIRED; + if (this.status == .failed) return .disarm; + this.failFmt(.ERR_POSTGRES_LIFETIME_TIMEOUT, "Max lifetime timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.max_lifetime_interval_ms) *| std.time.ns_per_ms)}); + return .disarm; + } fn start(this: *PostgresSQLConnection) void { + this.setupMaxLifetimeTimerIfNecessary(); + this.resetConnectionTimeout(); this.sendStartupMessage(); const event_loop = this.globalObject.bunVM().eventLoop(); @@ -1094,10 +1283,11 @@ pub const PostgresSQLConnection = struct { if (this.status == status) return; this.status = status; + this.resetConnectionTimeout(); + switch (status) { .connected => { - const on_connect = this.on_connect.swap(); - if (on_connect == .zero) return; + const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return; const js_value = this.js_value; js_value.ensureStillAlive(); this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ JSValue.jsNull(), js_value }); @@ -1110,10 +1300,16 @@ pub const PostgresSQLConnection = struct { pub fn finalize(this: *PostgresSQLConnection) void { debug("PostgresSQLConnection finalize", .{}); + this.stopTimers(); this.js_value = .zero; this.deref(); } + pub fn flushDataAndResetTimeout(this: *PostgresSQLConnection) void { + this.resetConnectionTimeout(); + this.flushData(); + } + pub fn flushData(this: *PostgresSQLConnection) void { const chunk = this.write_buffer.remaining(); if (chunk.len == 0) return; @@ -1126,27 +1322,78 @@ pub const PostgresSQLConnection = struct { pub fn failWithJSValue(this: *PostgresSQLConnection, value: JSValue) void { defer this.updateHasPendingActivity(); + this.stopTimers(); if (this.status == .failed) return; this.status = .failed; - if (!this.socket.isClosed()) this.socket.close(); - const on_close = this.on_close.swap(); - if (on_close == .zero) return; + this.ref(); + defer this.deref(); + if (!this.socket.isClosed()) this.socket.close(); + const on_close = this.consumeOnCloseCallback(this.globalObject) orelse return; + + const loop = this.globalObject.bunVM().eventLoop(); + loop.enter(); + defer loop.exit(); _ = on_close.call( this.globalObject, this.js_value, &[_]JSValue{ value, + this.getQueriesArray(), }, ) catch |e| this.globalObject.reportActiveExceptionAsUnhandled(e); } - pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: anyerror) void { + pub fn failFmt(this: *PostgresSQLConnection, comptime error_code: JSC.Error, comptime fmt: [:0]const u8, args: anytype) void { + this.failWithJSValue(error_code.fmt(this.globalObject, fmt, args)); + } + + pub fn fail(this: *PostgresSQLConnection, message: []const u8, err: AnyPostgresError) void { debug("failed: {s}: {s}", .{ message, @errorName(err) }); - const instance = this.globalObject.createErrorInstance("{s}", .{message}); - instance.put(this.globalObject, JSC.ZigString.static("code"), String.init(@errorName(err)).toJS(this.globalObject)); - this.failWithJSValue(instance); + + const globalObject = this.globalObject; + const error_code: JSC.Error = switch (err) { + error.ConnectionClosed => JSC.Error.ERR_POSTGRES_CONNECTION_CLOSED, + error.ExpectedRequest => JSC.Error.ERR_POSTGRES_EXPECTED_REQUEST, + error.ExpectedStatement => JSC.Error.ERR_POSTGRES_EXPECTED_STATEMENT, + error.InvalidBackendKeyData => JSC.Error.ERR_POSTGRES_INVALID_BACKEND_KEY_DATA, + error.InvalidBinaryData => JSC.Error.ERR_POSTGRES_INVALID_BINARY_DATA, + error.InvalidByteSequence => JSC.Error.ERR_POSTGRES_INVALID_BYTE_SEQUENCE, + error.InvalidByteSequenceForEncoding => JSC.Error.ERR_POSTGRES_INVALID_BYTE_SEQUENCE_FOR_ENCODING, + error.InvalidCharacter => JSC.Error.ERR_POSTGRES_INVALID_CHARACTER, + error.InvalidMessage => JSC.Error.ERR_POSTGRES_INVALID_MESSAGE, + error.InvalidMessageLength => JSC.Error.ERR_POSTGRES_INVALID_MESSAGE_LENGTH, + error.InvalidQueryBinding => JSC.Error.ERR_POSTGRES_INVALID_QUERY_BINDING, + error.InvalidServerKey => JSC.Error.ERR_POSTGRES_INVALID_SERVER_KEY, + error.InvalidServerSignature => JSC.Error.ERR_POSTGRES_INVALID_SERVER_SIGNATURE, + error.MultidimensionalArrayNotSupportedYet => JSC.Error.ERR_POSTGRES_MULTIDIMENSIONAL_ARRAY_NOT_SUPPORTED_YET, + error.NullsInArrayNotSupportedYet => JSC.Error.ERR_POSTGRES_NULLS_IN_ARRAY_NOT_SUPPORTED_YET, + error.Overflow => JSC.Error.ERR_POSTGRES_OVERFLOW, + error.PBKDFD2 => JSC.Error.ERR_POSTGRES_AUTHENTICATION_FAILED_PBKDF2, + error.SASL_SIGNATURE_MISMATCH => JSC.Error.ERR_POSTGRES_SASL_SIGNATURE_MISMATCH, + error.SASL_SIGNATURE_INVALID_BASE64 => JSC.Error.ERR_POSTGRES_SASL_SIGNATURE_INVALID_BASE64, + error.TLSNotAvailable => JSC.Error.ERR_POSTGRES_TLS_NOT_AVAILABLE, + error.TLSUpgradeFailed => JSC.Error.ERR_POSTGRES_TLS_UPGRADE_FAILED, + error.UnexpectedMessage => JSC.Error.ERR_POSTGRES_UNEXPECTED_MESSAGE, + error.UNKNOWN_AUTHENTICATION_METHOD => JSC.Error.ERR_POSTGRES_UNKNOWN_AUTHENTICATION_METHOD, + error.UNSUPPORTED_AUTHENTICATION_METHOD => JSC.Error.ERR_POSTGRES_UNSUPPORTED_AUTHENTICATION_METHOD, + error.UnsupportedByteaFormat => JSC.Error.ERR_POSTGRES_UNSUPPORTED_BYTEA_FORMAT, + error.UnsupportedIntegerSize => JSC.Error.ERR_POSTGRES_UNSUPPORTED_INTEGER_SIZE, + error.JSError => { + this.failWithJSValue(globalObject.takeException(error.JSError)); + return; + }, + error.OutOfMemory => { + // TODO: add binding for creating an out of memory error? + this.failWithJSValue(globalObject.takeException(globalObject.throwOutOfMemory())); + return; + }, + error.ShortRead => { + bun.unreachablePanic("Assertion failed: ShortRead should be handled by the caller in postgres", .{}); + }, + }; + this.failWithJSValue(error_code.fmt(globalObject, "{s}", .{message})); } pub fn onClose(this: *PostgresSQLConnection) void { @@ -1210,22 +1457,35 @@ pub const PostgresSQLConnection = struct { pub fn onHandshake(this: *PostgresSQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void { debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no }); + if (this.tls_config.reject_unauthorized == 0) { + return; + } + + const do_tls_verification = switch (this.ssl_mode) { + // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 + .verify_ca, .verify_full => true, + else => false, + }; + + if (!do_tls_verification) { + return; + } + if (success != 1) { this.failWithJSValue(ssl_error.toJS(this.globalObject)); return; } - if (this.tls_config.reject_unauthorized == 1) { - if (ssl_error.error_no != 0) { + if (ssl_error.error_no != 0) { + this.failWithJSValue(ssl_error.toJS(this.globalObject)); + return; + } + + const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); + if (BoringSSL.SSL_get_servername(ssl_ptr, 0)) |servername| { + const hostname = servername[0..bun.len(servername)]; + if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) { this.failWithJSValue(ssl_error.toJS(this.globalObject)); - return; - } - const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); - if (BoringSSL.SSL_get_servername(ssl_ptr, 0)) |servername| { - const hostname = servername[0..bun.len(servername)]; - if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) { - this.failWithJSValue(ssl_error.toJS(this.globalObject)); - } } } } @@ -1264,6 +1524,7 @@ pub const PostgresSQLConnection = struct { this.poll_ref.ref(vm); } + this.resetConnectionTimeout(); this.deref(); } @@ -1299,11 +1560,8 @@ pub const PostgresSQLConnection = struct { this.read_buffer.byte_list.len = 0; this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer"); } else { - if (comptime bun.Environment.allow_assert) { - if (@errorReturnTrace()) |trace| { - debug("Error: {s}\n{}", .{ @errorName(err), trace }); - } - } + bun.handleErrorReturnTrace(err, @errorReturnTrace()); + this.fail("Failed to read data", err); } }; @@ -1315,11 +1573,7 @@ pub const PostgresSQLConnection = struct { this.read_buffer.write(bun.default_allocator, data) catch @panic("failed to write to read buffer"); PostgresRequest.onData(this, Reader, this.bufferedReader()) catch |err| { if (err != error.ShortRead) { - if (comptime bun.Environment.allow_assert) { - if (@errorReturnTrace()) |trace| { - debug("Error: {s}\n{}", .{ @errorName(err), trace }); - } - } + bun.handleErrorReturnTrace(err, @errorReturnTrace()); this.fail("Failed to read data", err); return; } @@ -1363,7 +1617,7 @@ pub const PostgresSQLConnection = struct { pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { var vm = globalObject.bunVM(); - const arguments = callframe.arguments_old(10).slice(); + const arguments = callframe.arguments_old(13).slice(); const hostname_str = arguments[0].toBunString(globalObject); defer hostname_str.deref(); const port = arguments[1].coerce(i32, globalObject); @@ -1460,13 +1714,15 @@ pub const PostgresSQLConnection = struct { const on_connect = arguments[8]; const on_close = arguments[9]; + const idle_timeout = arguments[10].toInt32(); + const connection_timeout = arguments[11].toInt32(); + const max_lifetime = arguments[12].toInt32(); - var ptr = try bun.default_allocator.create(PostgresSQLConnection); + const ptr: *PostgresSQLConnection = try bun.default_allocator.create(PostgresSQLConnection); ptr.* = PostgresSQLConnection{ .globalObject = globalObject, - .on_connect = JSC.Strong.create(on_connect, globalObject), - .on_close = JSC.Strong.create(on_close, globalObject), + .database = database, .user = username, .password = password, @@ -1479,6 +1735,9 @@ pub const PostgresSQLConnection = struct { .tls_ctx = tls_ctx, .ssl_mode = ssl_mode, .tls_status = if (ssl_mode != .disable) .pending else .none, + .idle_timeout_interval_ms = @intCast(idle_timeout), + .connection_timeout_ms = @intCast(connection_timeout), + .max_lifetime_interval_ms = @intCast(max_lifetime), }; ptr.updateHasPendingActivity(); @@ -1487,6 +1746,9 @@ pub const PostgresSQLConnection = struct { js_value.ensureStillAlive(); ptr.js_value = js_value; + PostgresSQLConnection.onconnectSetCached(js_value, globalObject, on_connect); + PostgresSQLConnection.oncloseSetCached(js_value, globalObject, on_close); + { const hostname = hostname_str.toUTF8(bun.default_allocator); defer hostname.deinit(); @@ -1498,6 +1760,7 @@ pub const PostgresSQLConnection = struct { vm.rareData().postgresql_context.tcp = ctx_; break :brk ctx_; }; + ptr.socket = .{ .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| { tls_config.deinit(); @@ -1508,6 +1771,8 @@ pub const PostgresSQLConnection = struct { return globalObject.throwError(err, "failed to connect to postgresql"); }, }; + + ptr.resetConnectionTimeout(); } return js_value; @@ -1600,7 +1865,17 @@ pub const PostgresSQLConnection = struct { return .undefined; } + pub fn stopTimers(this: *PostgresSQLConnection) void { + if (this.timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.timer); + } + if (this.max_lifetime_timer.state == .ACTIVE) { + this.globalObject.bunVM().timer.remove(&this.max_lifetime_timer); + } + } + pub fn deinit(this: *@This()) void { + this.stopTimers(); var iter = this.statements.valueIterator(); while (iter.next()) |stmt_ptr| { var stmt = stmt_ptr.*; @@ -1609,8 +1884,6 @@ pub const PostgresSQLConnection = struct { this.statements.deinit(bun.default_allocator); this.write_buffer.deinit(bun.default_allocator); this.read_buffer.deinit(bun.default_allocator); - this.on_close.deinit(); - this.on_connect.deinit(); this.backend_parameters.deinit(); bun.default_allocator.free(this.options_buf); this.tls_config.deinit(); @@ -1618,6 +1891,8 @@ pub const PostgresSQLConnection = struct { } pub fn disconnect(this: *@This()) void { + this.stopTimers(); + if (this.status == .connected) { this.status = .disconnected; this.poll_ref.disable(); @@ -1636,12 +1911,12 @@ pub const PostgresSQLConnection = struct { pub const Writer = struct { connection: *PostgresSQLConnection, - pub fn write(this: Writer, data: []const u8) anyerror!void { + pub fn write(this: Writer, data: []const u8) AnyPostgresError!void { var buffer = &this.connection.write_buffer; try buffer.write(bun.default_allocator, data); } - pub fn pwrite(this: Writer, data: []const u8, index: usize) anyerror!void { + pub fn pwrite(this: Writer, data: []const u8, index: usize) AnyPostgresError!void { @memcpy(this.connection.write_buffer.byte_list.slice()[index..][0..data.len], data); } @@ -1676,7 +1951,7 @@ pub const PostgresSQLConnection = struct { pub fn ensureCapacity(this: Reader, count: usize) bool { return @as(usize, this.connection.read_buffer.head) + count <= @as(usize, this.connection.read_buffer.byte_list.len); } - pub fn read(this: Reader, count: usize) anyerror!Data { + pub fn read(this: Reader, count: usize) AnyPostgresError!Data { var remaining = this.connection.read_buffer.remaining(); if (@as(usize, remaining.len) < count) { return error.ShortRead; @@ -1687,7 +1962,7 @@ pub const PostgresSQLConnection = struct { .temporary = remaining[0..count], }; } - pub fn readZ(this: Reader) anyerror!Data { + pub fn readZ(this: Reader) AnyPostgresError!Data { const remain = this.connection.read_buffer.remaining(); if (bun.strings.indexOfChar(remain, 0)) |zero| { @@ -1799,7 +2074,7 @@ pub const PostgresSQLConnection = struct { } } - pub fn fromBytes(binary: bool, oid: int4, bytes: []const u8, globalObject: *JSC.JSGlobalObject) anyerror!DataCell { + pub fn fromBytes(binary: bool, oid: int4, bytes: []const u8, globalObject: *JSC.JSGlobalObject) !DataCell { switch (@as(types.Tag, @enumFromInt(@as(short, @intCast(oid))))) { // TODO: .int2_array, .float8_array inline .int4_array, .float4_array => |tag| { @@ -1876,7 +2151,7 @@ pub const PostgresSQLConnection = struct { return DataCell{ .tag = .float8, .value = .{ .float8 = float4 } }; } }, - .json => { + .jsonb, .json => { return DataCell{ .tag = .json, .value = .{ .json = String.createUTF8(bytes).value.WTFStringImpl }, .free_value = 1 }; }, .bool => { @@ -1956,7 +2231,7 @@ pub const PostgresSQLConnection = struct { return pg_ntoT(32, x); } - pub fn parseBinary(comptime tag: types.Tag, comptime ReturnType: type, bytes: []const u8) anyerror!ReturnType { + pub fn parseBinary(comptime tag: types.Tag, comptime ReturnType: type, bytes: []const u8) AnyPostgresError!ReturnType { switch (comptime tag) { .float8 => { return @as(f64, @bitCast(try parseBinary(.int8, i64, bytes))); @@ -2017,7 +2292,7 @@ pub const PostgresSQLConnection = struct { return JSC__constructObjectFromDataCell(globalObject, array, structure, this.list.ptr, @truncate(this.fields.len)); } - pub fn put(this: *Putter, index: u32, optional_bytes: ?*Data) anyerror!bool { + pub fn put(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { const oid = this.fields[index].type_oid; debug("index: {d}, oid: {d}", .{ index, oid }); @@ -2072,7 +2347,7 @@ pub const PostgresSQLConnection = struct { const binding_value = PostgresSQLQuery.bindingGetCached(req.thisValue) orelse .zero; const columns_value = PostgresSQLQuery.columnsGetCached(req.thisValue) orelse .zero; PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, columns_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { - req.onWriteFail(err, this.globalObject); + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); req.deref(); this.requests.discard(1); continue; @@ -2091,7 +2366,11 @@ pub const PostgresSQLConnection = struct { return any; } - pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.EnumLiteral), comptime Context: type, reader: protocol.NewReader(Context)) !void { + pub fn getQueriesArray(this: *const PostgresSQLConnection) JSValue { + return PostgresSQLConnection.queriesGetCached(this.js_value) orelse .zero; + } + + pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.EnumLiteral), comptime Context: type, reader: protocol.NewReader(Context)) AnyPostgresError!void { debug("on({s})", .{@tagName(MessageType)}); if (comptime MessageType != .ReadyForQuery) { this.is_ready_for_query = false; @@ -2181,7 +2460,7 @@ pub const PostgresSQLConnection = struct { debug("-> {s}", .{cmd.command_tag.slice()}); _ = this.requests.discard(1); defer this.updateRef(); - request.onSuccess(cmd.command_tag.slice(), this.globalObject); + request.onSuccess(cmd.command_tag.slice(), this.globalObject, this.js_value); }, .BindComplete => { try reader.eatMessage(protocol.BindComplete); @@ -2255,7 +2534,12 @@ pub const PostgresSQLConnection = struct { const iteration_count = try cont.iterationCount(); - const server_salt_decoded_base64 = try bun.base64.decodeAlloc(bun.z_allocator, cont.s); + const server_salt_decoded_base64 = bun.base64.decodeAlloc(bun.z_allocator, cont.s) catch |err| { + return switch (err) { + error.DecodingFailed => error.SASL_SIGNATURE_INVALID_BASE64, + else => |e| e, + }; + }; defer bun.z_allocator.free(server_salt_decoded_base64); try sasl.computeSaltedPassword(server_salt_decoded_base64, iteration_count, this); @@ -2352,8 +2636,46 @@ pub const PostgresSQLConnection = struct { this.flushData(); }, + .MD5Password => |md5| { + debug("MD5Password", .{}); + // Format is: md5 + md5(md5(password + username) + salt) + var first_hash_buf: bun.sha.MD5.Digest = undefined; + var first_hash_str: [32]u8 = undefined; + var final_hash_buf: bun.sha.MD5.Digest = undefined; + var final_hash_str: [32]u8 = undefined; + var final_password_buf: [36]u8 = undefined; + + // First hash: md5(password + username) + var first_hasher = bun.sha.MD5.init(); + first_hasher.update(this.password); + first_hasher.update(this.user); + first_hasher.final(&first_hash_buf); + const first_hash_str_output = std.fmt.bufPrint(&first_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&first_hash_buf)}) catch unreachable; + + // Second hash: md5(first_hash + salt) + var final_hasher = bun.sha.MD5.init(); + final_hasher.update(first_hash_str_output); + final_hasher.update(&md5.salt); + final_hasher.final(&final_hash_buf); + const final_hash_str_output = std.fmt.bufPrint(&final_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&final_hash_buf)}) catch unreachable; + + // Format final password as "md5" + final_hash + const final_password = std.fmt.bufPrintZ(&final_password_buf, "md5{s}", .{final_hash_str_output}) catch unreachable; + + var response = protocol.PasswordMessage{ + .password = .{ + .temporary = final_password, + }, + }; + + this.authentication_state = .{ .md5 = {} }; + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + this.flushData(); + }, + else => { debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))}); + this.fail("TODO: support authentication method: {s}", error.UNSUPPORTED_AUTHENTICATION_METHOD); }, } }, @@ -2371,19 +2693,12 @@ pub const PostgresSQLConnection = struct { var err: protocol.ErrorResponse = undefined; try err.decodeInternal(Context, reader); - if (this.status == .connecting) { - this.status = .failed; + if (this.status == .connecting or this.status == .sent_startup_message) { defer { err.deinit(); - this.poll_ref.unref(this.globalObject.bunVM()); - this.updateHasPendingActivity(); } - const on_connect = this.on_connect.swap(); - if (on_connect == .zero) return; - const js_value = this.js_value; - js_value.ensureStillAlive(); - this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ err.toJS(this.globalObject), js_value }); + this.failWithJSValue(err.toJS(this.globalObject)); // it shouldn't enqueue any requests while connecting bun.assert(this.requests.count == 0); @@ -2426,7 +2741,7 @@ pub const PostgresSQLConnection = struct { try reader.eatMessage(protocol.CloseComplete); var request = this.current() orelse return error.ExpectedRequest; _ = this.requests.discard(1); - request.onSuccess("CLOSECOMPLETE", this.globalObject); + request.onSuccess("CLOSECOMPLETE", this.globalObject, this.getQueriesArray()); }, .CopyInResponse => { debug("TODO CopyInResponse", .{}); @@ -2443,7 +2758,7 @@ pub const PostgresSQLConnection = struct { var request = this.current() orelse return error.ExpectedRequest; _ = this.requests.discard(1); this.updateRef(); - request.onSuccess("", this.globalObject); + request.onSuccess("", this.globalObject, this.getQueriesArray()); }, .CopyOutResponse => { debug("TODO CopyOutResponse", .{}); @@ -2467,25 +2782,21 @@ pub const PostgresSQLConnection = struct { } } - pub fn doFlush(this: *PostgresSQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { - _ = callframe; - _ = globalObject; - _ = this; - - return .undefined; - } - - pub fn createQuery(this: *PostgresSQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { - _ = callframe; - _ = globalObject; - _ = this; - - return .undefined; - } - pub fn getConnected(this: *PostgresSQLConnection, _: *JSC.JSGlobalObject) JSValue { return JSValue.jsBoolean(this.status == Status.connected); } + + pub fn consumeOnConnectCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { + const on_connect = PostgresSQLConnection.onconnectGetCached(this.js_value) orelse return null; + PostgresSQLConnection.onconnectSetCached(this.js_value, globalObject, .zero); + return on_connect; + } + + pub fn consumeOnCloseCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { + const on_close = PostgresSQLConnection.oncloseGetCached(this.js_value) orelse return null; + PostgresSQLConnection.oncloseSetCached(this.js_value, globalObject, .zero); + return on_close; + } }; pub const PostgresSQLStatement = struct { @@ -2723,7 +3034,7 @@ const Signature = struct { .float8 => try name.appendSlice(".float8"), .float4 => try name.appendSlice(".float4"), .numeric => try name.appendSlice(".numeric"), - .json => try name.appendSlice(".json"), + .json, .jsonb => try name.appendSlice(".json"), .bool => try name.appendSlice(".bool"), .timestamp => try name.appendSlice(".timestamp"), .timestamptz => try name.appendSlice(".timestamptz"), diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig index 4aee1791f9..60eeaf9f9d 100644 --- a/src/sql/postgres/postgres_protocol.zig +++ b/src/sql/postgres/postgres_protocol.zig @@ -15,7 +15,7 @@ const int4 = postgres.int4; const int8 = postgres.int8; const PostgresInt64 = postgres.PostgresInt64; const types = postgres.types; - +const AnyPostgresError = postgres.AnyPostgresError; pub const ArrayList = struct { array: *std.ArrayList(u8), @@ -23,11 +23,11 @@ pub const ArrayList = struct { return this.array.items.len; } - pub fn write(this: @This(), bytes: []const u8) anyerror!void { + pub fn write(this: @This(), bytes: []const u8) AnyPostgresError!void { try this.array.appendSlice(bytes); } - pub fn pwrite(this: @This(), bytes: []const u8, i: usize) anyerror!void { + pub fn pwrite(this: @This(), bytes: []const u8, i: usize) AnyPostgresError!void { @memcpy(this.array.items[i..][0..bytes.len], bytes); } @@ -71,7 +71,7 @@ pub const StackReader = struct { pub fn ensureCapacity(this: StackReader, count: usize) bool { return this.buffer.len >= (this.offset.* + count); } - pub fn read(this: StackReader, count: usize) anyerror!Data { + pub fn read(this: StackReader, count: usize) AnyPostgresError!Data { const offset = this.offset.*; if (!this.ensureCapacity(count)) { return error.ShortRead; @@ -82,7 +82,7 @@ pub const StackReader = struct { .temporary = this.buffer[offset..this.offset.*], }; } - pub fn readZ(this: StackReader) anyerror!Data { + pub fn readZ(this: StackReader) AnyPostgresError!Data { const remaining = this.peek(); if (bun.strings.indexOfChar(remaining, 0)) |zero| { this.skip(zero + 1); @@ -98,8 +98,8 @@ pub const StackReader = struct { pub fn NewWriterWrap( comptime Context: type, comptime offsetFn_: (fn (ctx: Context) usize), - comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) anyerror!void), - comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) anyerror!void), + comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) AnyPostgresError!void), + comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) AnyPostgresError!void), ) type { return struct { wrapped: Context, @@ -111,7 +111,7 @@ pub fn NewWriterWrap( pub const WrappedWriter = @This(); - pub inline fn write(this: @This(), data: []const u8) anyerror!void { + pub inline fn write(this: @This(), data: []const u8) AnyPostgresError!void { try writeFn(this.wrapped, data); } @@ -119,16 +119,16 @@ pub fn NewWriterWrap( index: usize, context: WrappedWriter, - pub fn write(this: LengthWriter) anyerror!void { + pub fn write(this: LengthWriter) AnyPostgresError!void { try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index); } - pub fn writeExcludingSelf(this: LengthWriter) anyerror!void { + pub fn writeExcludingSelf(this: LengthWriter) AnyPostgresError!void { try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index); } }; - pub inline fn length(this: @This()) anyerror!LengthWriter { + pub inline fn length(this: @This()) AnyPostgresError!LengthWriter { const i = this.offset(); try this.int4(0); return LengthWriter{ @@ -141,7 +141,7 @@ pub fn NewWriterWrap( return offsetFn(this.wrapped); } - pub inline fn pwrite(this: @This(), data: []const u8, i: usize) anyerror!void { + pub inline fn pwrite(this: @This(), data: []const u8, i: usize) AnyPostgresError!void { try pwriteFn(this.wrapped, data, i); } @@ -208,81 +208,81 @@ pub fn NewWriterWrap( pub const FieldType = enum(u8) { /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a localized translation of one of these. Always present. - S = 'S', + severity = 'S', /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). This is identical to the S field except that the contents are never localized. This is present only in messages generated by PostgreSQL versions 9.6 and later. - V = 'V', + localized_severity = 'V', /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. - C = 'C', + code = 'C', /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). Always present. - M = 'M', + message = 'M', /// Detail: an optional secondary error message carrying more detail about the problem. Might run to multiple lines. - D = 'D', + detail = 'D', /// Hint: an optional suggestion what to do about the problem. This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. Might run to multiple lines. - H = 'H', + hint = 'H', /// Position: the field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. The first character has index 1, and positions are measured in characters not bytes. - P = 'P', + position = 'P', /// Internal position: this is defined the same as the P field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. The q field will always appear when this field appears. - p = 'p', + internal_position = 'p', /// Internal query: the text of a failed internally-generated command. This could be, for example, an SQL query issued by a PL/pgSQL function. - q = 'q', + internal = 'q', /// Where: an indication of the context in which the error occurred. Presently this includes a call stack traceback of active procedural language functions and internally-generated queries. The trace is one entry per line, most recent first. - W = 'W', + where = 'W', /// Schema name: if the error was associated with a specific database object, the name of the schema containing that object, if any. - s = 's', + schema = 's', /// Table name: if the error was associated with a specific table, the name of the table. (Refer to the schema name field for the name of the table's schema.) - t = 't', + table = 't', /// Column name: if the error was associated with a specific table column, the name of the column. (Refer to the schema and table name fields to identify the table.) - c = 'c', + column = 'c', /// Data type name: if the error was associated with a specific data type, the name of the data type. (Refer to the schema name field for the name of the data type's schema.) - d = 'd', + datatype = 'd', /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. Refer to fields listed above for the associated table or domain. (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) - n = 'n', + constraint = 'n', /// File: the file name of the source-code location where the error was reported. - F = 'F', + file = 'F', /// Line: the line number of the source-code location where the error was reported. - L = 'L', + line = 'L', /// Routine: the name of the source-code routine reporting the error. - R = 'R', + routine = 'R', _, }; pub const FieldMessage = union(FieldType) { - S: String, - V: String, - C: String, - M: String, - D: String, - H: String, - P: String, - p: String, - q: String, - W: String, - s: String, - t: String, - c: String, - d: String, - n: String, - F: String, - L: String, - R: String, + severity: String, + localized_severity: String, + code: String, + message: String, + detail: String, + hint: String, + position: String, + internal_position: String, + internal: String, + where: String, + schema: String, + table: String, + column: String, + datatype: String, + constraint: String, + file: String, + line: String, + routine: String, pub fn format(this: FieldMessage, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { switch (this) { @@ -319,24 +319,25 @@ pub const FieldMessage = union(FieldType) { pub fn init(tag: FieldType, message: []const u8) !FieldMessage { return switch (tag) { - .S => FieldMessage{ .S = String.createUTF8(message) }, - .V => FieldMessage{ .V = String.createUTF8(message) }, - .C => FieldMessage{ .C = String.createUTF8(message) }, - .M => FieldMessage{ .M = String.createUTF8(message) }, - .D => FieldMessage{ .D = String.createUTF8(message) }, - .H => FieldMessage{ .H = String.createUTF8(message) }, - .P => FieldMessage{ .P = String.createUTF8(message) }, - .p => FieldMessage{ .p = String.createUTF8(message) }, - .q => FieldMessage{ .q = String.createUTF8(message) }, - .W => FieldMessage{ .W = String.createUTF8(message) }, - .s => FieldMessage{ .s = String.createUTF8(message) }, - .t => FieldMessage{ .t = String.createUTF8(message) }, - .c => FieldMessage{ .c = String.createUTF8(message) }, - .d => FieldMessage{ .d = String.createUTF8(message) }, - .n => FieldMessage{ .n = String.createUTF8(message) }, - .F => FieldMessage{ .F = String.createUTF8(message) }, - .L => FieldMessage{ .L = String.createUTF8(message) }, - .R => FieldMessage{ .R = String.createUTF8(message) }, + .severity => FieldMessage{ .severity = String.createUTF8(message) }, + // Ignore this one for now. + // .localized_severity => FieldMessage{ .localized_severity = String.createUTF8(message) }, + .code => FieldMessage{ .code = String.createUTF8(message) }, + .message => FieldMessage{ .message = String.createUTF8(message) }, + .detail => FieldMessage{ .detail = String.createUTF8(message) }, + .hint => FieldMessage{ .hint = String.createUTF8(message) }, + .position => FieldMessage{ .position = String.createUTF8(message) }, + .internal_position => FieldMessage{ .internal_position = String.createUTF8(message) }, + .internal => FieldMessage{ .internal = String.createUTF8(message) }, + .where => FieldMessage{ .where = String.createUTF8(message) }, + .schema => FieldMessage{ .schema = String.createUTF8(message) }, + .table => FieldMessage{ .table = String.createUTF8(message) }, + .column => FieldMessage{ .column = String.createUTF8(message) }, + .datatype => FieldMessage{ .datatype = String.createUTF8(message) }, + .constraint => FieldMessage{ .constraint = String.createUTF8(message) }, + .file => FieldMessage{ .file = String.createUTF8(message) }, + .line => FieldMessage{ .line = String.createUTF8(message) }, + .routine => FieldMessage{ .routine = String.createUTF8(message) }, else => error.UnknownFieldType, }; } @@ -348,8 +349,8 @@ pub fn NewReaderWrap( comptime peekFn_: (fn (ctx: Context) []const u8), comptime skipFn_: (fn (ctx: Context, count: usize) void), comptime ensureCapacityFn_: (fn (ctx: Context, count: usize) bool), - comptime readFunction_: (fn (ctx: Context, count: usize) anyerror!Data), - comptime readZ_: (fn (ctx: Context) anyerror!Data), + comptime readFunction_: (fn (ctx: Context, count: usize) AnyPostgresError!Data), + comptime readZ_: (fn (ctx: Context) AnyPostgresError!Data), ) type { return struct { wrapped: Context, @@ -366,11 +367,11 @@ pub fn NewReaderWrap( markMessageStartFn(this.wrapped); } - pub inline fn read(this: @This(), count: usize) anyerror!Data { + pub inline fn read(this: @This(), count: usize) AnyPostgresError!Data { return try readFn(this.wrapped, count); } - pub inline fn eatMessage(this: @This(), comptime msg_: anytype) anyerror!void { + pub inline fn eatMessage(this: @This(), comptime msg_: anytype) AnyPostgresError!void { const msg = msg_[1..]; try this.ensureCapacity(msg.len); @@ -380,7 +381,7 @@ pub fn NewReaderWrap( return error.InvalidMessage; } - pub fn skip(this: @This(), count: usize) anyerror!void { + pub fn skip(this: @This(), count: usize) AnyPostgresError!void { skipFn(this.wrapped, count); } @@ -388,11 +389,11 @@ pub fn NewReaderWrap( return peekFn(this.wrapped); } - pub inline fn readZ(this: @This()) anyerror!Data { + pub inline fn readZ(this: @This()) AnyPostgresError!Data { return try readZFn(this.wrapped); } - pub inline fn ensureCapacity(this: @This(), count: usize) anyerror!void { + pub inline fn ensureCapacity(this: @This(), count: usize) AnyPostgresError!void { if (!ensureCapacityFn(this.wrapped, count)) { return error.ShortRead; } @@ -457,7 +458,7 @@ pub fn NewWriter(comptime Context: type) type { fn decoderWrap(comptime Container: type, comptime decodeFn: anytype) type { return struct { - pub fn decode(this: *Container, context: anytype) anyerror!void { + pub fn decode(this: *Container, context: anytype) AnyPostgresError!void { const Context = @TypeOf(context); try decodeFn(this, Context, NewReader(Context){ .wrapped = context }); } @@ -466,7 +467,7 @@ fn decoderWrap(comptime Container: type, comptime decodeFn: anytype) type { fn writeWrap(comptime Container: type, comptime writeFn: anytype) type { return struct { - pub fn write(this: *Container, context: anytype) anyerror!void { + pub fn write(this: *Container, context: anytype) AnyPostgresError!void { const Context = @TypeOf(context); try writeFn(this, Context, NewWriter(Context){ .wrapped = context }); } @@ -538,9 +539,6 @@ pub const Authentication = union(enum) { }, 5 => { if (message_length != 12) return error.InvalidMessageLength; - if (!try reader.expectInt(u32, 5)) { - return error.InvalidMessage; - } var salt_data = try reader.bytes(4); defer salt_data.deinit(); this.* = .{ @@ -722,23 +720,117 @@ pub const ErrorResponse = struct { var b = bun.StringBuilder{}; defer b.deinit(bun.default_allocator); - for (this.messages.items) |msg| { - b.cap += switch (msg) { + // Pre-calculate capacity to avoid reallocations + for (this.messages.items) |*msg| { + b.cap += switch (msg.*) { inline else => |m| m.utf8ByteLength(), } + 1; } b.allocate(bun.default_allocator) catch {}; - for (this.messages.items) |msg| { - var str = switch (msg) { - inline else => |m| m.toUTF8(bun.default_allocator), - }; - defer str.deinit(); - _ = b.append(str.slice()); - _ = b.append("\n"); + // Build a more structured error message + var severity: String = String.dead; + var code: String = String.dead; + var message: String = String.dead; + var detail: String = String.dead; + var hint: String = String.dead; + var position: String = String.dead; + var where: String = String.dead; + var schema: String = String.dead; + var table: String = String.dead; + var column: String = String.dead; + var datatype: String = String.dead; + var constraint: String = String.dead; + var file: String = String.dead; + var line: String = String.dead; + var routine: String = String.dead; + + for (this.messages.items) |*msg| { + switch (msg.*) { + .severity => |str| severity = str, + .code => |str| code = str, + .message => |str| message = str, + .detail => |str| detail = str, + .hint => |str| hint = str, + .position => |str| position = str, + .where => |str| where = str, + .schema => |str| schema = str, + .table => |str| table = str, + .column => |str| column = str, + .datatype => |str| datatype = str, + .constraint => |str| constraint = str, + .file => |str| file = str, + .line => |str| line = str, + .routine => |str| routine = str, + else => {}, + } } - return globalObject.createSyntaxErrorInstance("Postgres error occurred\n{s}", .{b.allocatedSlice()[0..b.len]}); + var needs_newline = false; + construct_message: { + if (!message.isEmpty()) { + _ = b.appendStr(message); + needs_newline = true; + break :construct_message; + } + if (!detail.isEmpty()) { + if (needs_newline) { + _ = b.append("\n"); + } else { + _ = b.append(" "); + } + needs_newline = true; + _ = b.appendStr(detail); + } + if (!hint.isEmpty()) { + if (needs_newline) { + _ = b.append("\n"); + } else { + _ = b.append(" "); + } + needs_newline = true; + _ = b.appendStr(hint); + } + } + + const possible_fields = .{ + .{ "detail", detail, void }, + .{ "hint", hint, void }, + .{ "column", column, void }, + .{ "constraint", constraint, void }, + .{ "datatype", datatype, void }, + .{ "errno", code, i32 }, + .{ "position", position, i32 }, + .{ "schema", schema, void }, + .{ "table", table, void }, + .{ "where", where, void }, + }; + + const error_code: JSC.Error = + // https://www.postgresql.org/docs/8.1/errcodes-appendix.html + if (code.toInt32() orelse 0 == 42601) + JSC.Error.ERR_POSTGRES_SYNTAX_ERROR + else + JSC.Error.ERR_POSTGRES_SERVER_ERROR; + const err = error_code.fmt(globalObject, "{s}", .{b.allocatedSlice()[0..b.len]}); + + inline for (possible_fields) |field| { + if (!field.@"1".isEmpty()) { + const value = brk: { + if (field.@"2" == i32) { + if (field.@"1".toInt32()) |val| { + break :brk JSC.JSValue.jsNumberFromInt32(val); + } + } + + break :brk field.@"1".toJS(globalObject); + }; + + err.put(globalObject, JSC.ZigString.static(field.@"0"), value); + } + } + + return err; } }; @@ -847,7 +939,7 @@ pub const FormatCode = enum { pub const null_int4 = 4294967295; pub const DataRow = struct { - pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) anyerror!bool) anyerror!void { + pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) AnyPostgresError!bool) AnyPostgresError!void { var remaining_bytes = try reader.length(); remaining_bytes -|= 4; @@ -885,7 +977,7 @@ pub const FieldDescription = struct { this.name.deinit(); } - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) AnyPostgresError!void { var name = try reader.readZ(); errdefer { name.deinit(); @@ -1355,6 +1447,29 @@ pub const NoticeResponse = struct { } } pub const decode = decoderWrap(NoticeResponse, decodeInternal).decode; + + pub fn toJS(this: NoticeResponse, globalObject: *JSC.JSGlobalObject) JSValue { + var b = bun.StringBuilder{}; + defer b.deinit(bun.default_allocator); + + for (this.messages.items) |msg| { + b.cap += switch (msg) { + inline else => |m| m.utf8ByteLength(), + } + 1; + } + b.allocate(bun.default_allocator) catch {}; + + for (this.messages.items) |msg| { + var str = switch (msg) { + inline else => |m| m.toUTF8(bun.default_allocator), + }; + defer str.deinit(); + _ = b.append(str.slice()); + _ = b.append("\n"); + } + + return JSC.ZigString.init(b.allocatedSlice()[0..b.len]).toJS(globalObject); + } }; pub const CopyFail = struct { diff --git a/src/sql/postgres/postgres_types.zig b/src/sql/postgres/postgres_types.zig index 498b7f1d0a..74e8ff104b 100644 --- a/src/sql/postgres/postgres_types.zig +++ b/src/sql/postgres/postgres_types.zig @@ -12,6 +12,7 @@ const JSValue = JSC.JSValue; const JSC = bun.JSC; const short = postgres.short; const int4 = postgres.int4; +const AnyPostgresError = postgres.AnyPostgresError; // select b.typname, b.oid, b.typarray // from pg_catalog.pg_type a @@ -169,8 +170,17 @@ pub const Tag = enum(short) { bit_array = 1561, varbit_array = 1563, numeric_array = 1231, + jsonb = 3802, + jsonb_array = 3807, + // Not really sure what this is. + jsonpath = 4072, + jsonpath_array = 4073, _, + pub fn name(this: Tag) ?[]const u8 { + return std.enums.tagName(Tag, this); + } + pub fn isBinaryFormatSupported(this: Tag) bool { return switch (this) { // TODO: .int2_array, .float8_array, @@ -282,7 +292,7 @@ pub const Tag = enum(short) { globalObject: *JSC.JSGlobalObject, comptime Type: type, value: Type, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { switch (tag) { .numeric => { return numeric.toJS(globalObject, value); @@ -292,7 +302,7 @@ pub const Tag = enum(short) { return numeric.toJS(globalObject, value); }, - .json => { + .json, .jsonb => { return json.toJS(globalObject, value); }, @@ -326,7 +336,7 @@ pub const Tag = enum(short) { tag: Tag, globalObject: *JSC.JSGlobalObject, value: anytype, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { return toJSWithType(tag, globalObject, @TypeOf(value), value); } @@ -363,16 +373,16 @@ pub const Tag = enum(short) { // Ban these types: if (tag == .NumberObject) { - return error.JSError; + return globalObject.ERR_INVALID_ARG_TYPE("Number object is ambiguous and cannot be used as a PostgreSQL type", .{}).throw(); } if (tag == .BooleanObject) { - return error.JSError; + return globalObject.ERR_INVALID_ARG_TYPE("Boolean object is ambiguous and cannot be used as a PostgreSQL type", .{}).throw(); } // It's something internal if (!tag.isIndexable()) { - return error.JSError; + return globalObject.ERR_INVALID_ARG_TYPE("Unknown object is not a valid PostgreSQL type", .{}).throw(); } // We will JSON.stringify anything else. @@ -414,7 +424,7 @@ pub const string = struct { globalThis: *JSC.JSGlobalObject, comptime Type: type, value: Type, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { switch (comptime Type) { [:0]u8, []u8, []const u8, [:0]const u8 => { var str = String.fromUTF8(value); @@ -456,7 +466,7 @@ pub const numeric = struct { pub fn toJS( _: *JSC.JSGlobalObject, value: anytype, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { return JSValue.jsNumber(value); } }; @@ -468,12 +478,12 @@ pub const json = struct { pub fn toJS( globalObject: *JSC.JSGlobalObject, value: *Data, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { defer value.deinit(); var str = bun.String.fromUTF8(value.slice()); defer str.deref(); const parse_result = JSValue.parse(str.toJS(globalObject), globalObject); - if (parse_result.isAnyError()) { + if (parse_result.AnyPostgresError()) { return globalObject.throwValue(parse_result); } @@ -488,7 +498,7 @@ pub const @"bool" = struct { pub fn toJS( _: *JSC.JSGlobalObject, value: bool, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { return JSValue.jsBoolean(value); } }; @@ -548,7 +558,7 @@ pub const bytea = struct { pub fn toJS( globalObject: *JSC.JSGlobalObject, value: *Data, - ) anyerror!JSValue { + ) AnyPostgresError!JSValue { defer value.deinit(); // var slice = value.slice()[@min(1, value.len)..]; diff --git a/src/string.zig b/src/string.zig index 30bc054a53..4f874a9864 100644 --- a/src/string.zig +++ b/src/string.zig @@ -323,6 +323,12 @@ pub const String = extern struct { extern fn BunString__fromUTF16ToLatin1(bytes: [*]const u16, len: usize) String; extern fn BunString__fromLatin1Unitialized(len: usize) String; extern fn BunString__fromUTF16Unitialized(len: usize) String; + extern fn BunString__toInt32(this: String) i64; + pub fn toInt32(this: String) ?i32 { + const val = BunString__toInt32(this); + if (val > std.math.maxInt(i32)) return null; + return @intCast(val); + } pub fn ascii(bytes: []const u8) String { return String{ .tag = .ZigString, .value = .{ .ZigString = ZigString.init(bytes) } }; diff --git a/src/string_builder.zig b/src/string_builder.zig index e5e3d0bd47..0178663a4f 100644 --- a/src/string_builder.zig +++ b/src/string_builder.zig @@ -89,6 +89,12 @@ pub fn appendZ(this: *StringBuilder, slice: string) [:0]const u8 { return result; } +pub fn appendStr(this: *StringBuilder, str: bun.String) string { + const slice = str.toUTF8(bun.default_allocator); + defer slice.deinit(); + return this.append(slice.slice()); +} + pub fn append(this: *StringBuilder, slice: string) string { if (comptime Environment.allow_assert) { assert(this.len <= this.cap); // didn't count everything diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index 8c0089c760..92fd82931b 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -1,5 +1,5 @@ import { postgres, sql } from "bun:sql"; -import { expect, test } from "bun:test"; +import { expect, test, mock } from "bun:test"; import { $ } from "bun"; import { bunExe, isCI, withoutAggressiveGC } from "harness"; import path from "path"; @@ -13,18 +13,20 @@ if (!isCI) { // local all postgres trust // local all bun_sql_test_scram scram-sha-256 // local all bun_sql_test trust - // + // local all bun_sql_test_md5 md5 + // # IPv4 local connections: // host all ${USERNAME} 127.0.0.1/32 trust // host all postgres 127.0.0.1/32 trust // host all bun_sql_test_scram 127.0.0.1/32 scram-sha-256 // host all bun_sql_test 127.0.0.1/32 trust + // host all bun_sql_test_md5 127.0.0.1/32 md5 // # IPv6 local connections: // host all ${USERNAME} ::1/128 trust // host all postgres ::1/128 trust // host all bun_sql_test ::1/128 trust // host all bun_sql_test_scram ::1/128 scram-sha-256 - // + // host all bun_sql_test_md5 ::1/128 md5 // # Allow replication connections from localhost, by a user with the // # replication privilege. // local replication all trust @@ -33,9 +35,6 @@ if (!isCI) { // --- Expected pg_hba.conf --- process.env.DATABASE_URL = "postgres://bun_sql_test@localhost:5432/bun_sql_test"; - const delay = ms => Bun.sleep(ms); - const rel = x => new URL(x, import.meta.url); - const login = { username: "bun_sql_test", }; @@ -54,8 +53,8 @@ if (!isCI) { db: "bun_sql_test", username: login.username, password: login.password, - idle_timeout: 1, - connect_timeout: 1, + idle_timeout: 0, + connect_timeout: 0, max: 1, }; @@ -67,6 +66,97 @@ if (!isCI) { expect(result).toBe(1); }); + test("Connection timeout works", async () => { + const onclose = mock(); + const onconnect = mock(); + await using sql = postgres({ + ...options, + hostname: "unreachable_host", + connection_timeout: 1, + onconnect, + onclose, + }); + let error: any; + try { + await sql`select pg_sleep(2)`; + } catch (e) { + error = e; + } + expect(error.code).toBe(`ERR_POSTGRES_CONNECTION_TIMEOUT`); + expect(error.message).toContain("Connection timeout after 1ms"); + expect(onconnect).not.toHaveBeenCalled(); + expect(onclose).toHaveBeenCalledTimes(1); + }); + + test("Idle timeout works at start", async () => { + const onclose = mock(); + const onconnect = mock(); + await using sql = postgres({ + ...options, + idle_timeout: 1, + onconnect, + onclose, + }); + let error: any; + try { + await sql`select pg_sleep(2)`; + } catch (e) { + error = e; + } + expect(error.code).toBe(`ERR_POSTGRES_IDLE_TIMEOUT`); + expect(onconnect).toHaveBeenCalled(); + expect(onclose).toHaveBeenCalledTimes(1); + }); + + test("Idle timeout is reset when a query is run", async () => { + const onClosePromise = Promise.withResolvers(); + const onclose = mock(err => { + onClosePromise.resolve(err); + }); + const onconnect = mock(); + await using sql = postgres({ + ...options, + idle_timeout: 100, + onconnect, + onclose, + }); + expect(await sql`select 123 as x`).toEqual([{ x: 123 }]); + expect(onconnect).toHaveBeenCalledTimes(1); + expect(onclose).not.toHaveBeenCalled(); + const err = await onClosePromise.promise; + expect(err.code).toBe(`ERR_POSTGRES_IDLE_TIMEOUT`); + }); + + test("Max lifetime works", async () => { + const onClosePromise = Promise.withResolvers(); + const onclose = mock(err => { + onClosePromise.resolve(err); + }); + const onconnect = mock(); + const sql = postgres({ + ...options, + max_lifetime: 64, + onconnect, + onclose, + }); + let error: any; + expect(await sql`select 1 as x`).toEqual([{ x: 1 }]); + expect(onconnect).toHaveBeenCalledTimes(1); + try { + while (true) { + for (let i = 0; i < 100; i++) { + await sql`select pg_sleep(1)`; + } + } + } catch (e) { + error = e; + } + + expect(onclose).toHaveBeenCalledTimes(1); + + expect(error.code).toBe(`ERR_POSTGRES_LIFETIME_TIMEOUT`); + }); + test("Uses default database without slash", async () => { const sql = postgres("postgres://localhost"); expect(sql.options.username).toBe(sql.options.database); @@ -145,10 +235,9 @@ if (!isCI) { expect(x).toEqual({ a: "hello", b: 42 }); }); - // It's treating as a string. - test.todo("implicit jsonb", async () => { + test("implicit jsonb", async () => { const x = (await sql`select ${{ a: "hello", b: 42 }}::jsonb as x`)[0].x; - expect([x.a, x.b].join(",")).toBe("hello,42"); + expect(x).toEqual({ a: "hello", b: 42 }); }); test("bulk insert nested sql()", async () => { @@ -428,9 +517,11 @@ if (!isCI) { test("Null sets to null", async () => expect((await sql`select ${null} as x`)[0].x).toBeNull()); // Add code property. - test.todo("Throw syntax error", async () => { - const code = await sql`wat 1`.catch(x => x); - console.log({ code }); + test("Throw syntax error", async () => { + const err = await sql`wat 1`.catch(x => x); + expect(err.code).toBe("ERR_POSTGRES_SYNTAX_ERROR"); + expect(err.errno).toBe(42601); + expect(err).toBeInstanceOf(SyntaxError); }); // t('Connect using uri', async() => @@ -502,13 +593,26 @@ if (!isCI) { // return [1, (await sql`select 1 as x`)[0].x] // }) - // t('Login without password', async() => { - // return [true, (await postgres({ ...options, ...login })`select true as x`)[0].x] - // }) + test("Login without password", async () => { + await using sql = postgres({ ...options, ...login }); + expect((await sql`select true as x`)[0].x).toBe(true); + }); - // t('Login using MD5', async() => { - // return [true, (await postgres({ ...options, ...login_md5 })`select true as x`)[0].x] - // }) + test("Login using MD5", async () => { + await using sql = postgres({ ...options, ...login_md5 }); + expect(await sql`select true as x`).toEqual([{ x: true }]); + }); + + test("Login with bad credentials propagates error from server", async () => { + const sql = postgres({ ...options, ...login_md5, username: "bad_user", password: "bad_password" }); + let err; + try { + await sql`select true as x`; + } catch (e) { + err = e; + } + expect(err.code).toBe("ERR_POSTGRES_SERVER_ERROR"); + }); test("Login using scram-sha-256", async () => { await using sql = postgres({ ...options, ...login_scram }); @@ -1159,9 +1263,10 @@ if (!isCI) { // ] // }) - // t('dynamic column name', async() => { - // return ['!not_valid', Object.keys((await sql`select 1 as ${ sql('!not_valid') }`)[0])[0]] - // }) + test.todo("dynamic column name", async () => { + const result = await sql`select 1 as ${"\\!not_valid"}`; + expect(Object.keys(result[0])[0]).toBe("!not_valid"); + }); // t('dynamic select as', async() => { // return ['2', (await sql`select ${ sql({ a: 1, b: 2 }) }`)[0].b] @@ -1178,12 +1283,12 @@ if (!isCI) { // return ['the answer', (await sql`insert into test ${ sql(x) } returning *`)[0].b, await sql`drop table test`] // }) - // t('dynamic insert pluck', async() => { - // await sql`create table test (a int, b text)` - // const x = { a: 42, b: 'the answer' } - - // return [null, (await sql`insert into test ${ sql(x, 'a') } returning *`)[0].b, await sql`drop table test`] - // }) + // test.todo("dynamic insert pluck", async () => { + // await sql`create table test (a int, b text)`; + // const x = { a: 42, b: "the answer" }; + // const [{ b }] = await sql`insert into test ${sql(x, "a")} returning *`; + // expect(b).toBe("the answer"); + // }); // t('dynamic in with empty array', async() => { // await sql`create table test (a int)`