From 9bfd9db78bb349c1a2053bae7c0c52814dfe276c Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Mon, 20 Jan 2025 16:58:37 -0800 Subject: [PATCH] more(sql) type fixes and tests (#16512) --- packages/bun-types/bun.d.ts | 55 ++- src/bun.js/api/postgres.classes.ts | 4 + src/bun.js/bindings/BunObject.cpp | 3 +- src/bun.js/bindings/SQLClient.cpp | 186 +++++---- src/js/bun/sql.ts | 503 +++++++++++++++++++----- src/sql/postgres.zig | 232 ++++++++--- src/sql/postgres/postgres_protocol.zig | 11 +- src/sql/postgres/postgres_types.zig | 24 +- test/js/sql/docker/Dockerfile | 69 ++++ test/js/sql/docker/pg_hba.conf | 98 +++++ test/js/sql/sql.test.ts | 510 ++++++++++++++++--------- test/js/sql/tls-sql.test.ts | 239 +++++++++++- 12 files changed, 1473 insertions(+), 461 deletions(-) create mode 100644 test/js/sql/docker/Dockerfile create mode 100644 test/js/sql/docker/pg_hba.conf diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index ec3529efc1..178f959324 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -2019,7 +2019,7 @@ declare module "bun" { /** Database server port number */ port: number | string; /** Database user for authentication */ - user: string; + username: string; /** Database password for authentication */ password: string; /** Name of the database to connect to */ @@ -2040,6 +2040,8 @@ declare module "bun" { onclose: (client: SQL) => void; /** Maximum number of connections in the pool */ max: number; + /** By default values outside i32 range are returned as strings. If this is true, values outside i32 range are returned as BigInts. 
*/ + bigint: boolean; }; /** @@ -2065,7 +2067,12 @@ declare module "bun" { * Callback function type for transaction contexts * @param sql Function to execute SQL queries within the transaction */ - type SQLContextCallback = (sql: (strings: string, ...values: any[]) => SQLQuery | Array) => Promise; + type SQLTransactionContextCallback = (sql: TransactionSQL) => Promise | Array; + /** + * Callback function type for savepoint contexts + * @param sql Function to execute SQL queries within the savepoint + */ + type SQLSavepointContextCallback = (sql: SavepointSQL) => Promise | Array; /** * Main SQL client interface providing connection and transaction management @@ -2091,7 +2098,13 @@ declare module "bun" { * @example * const [user] = await sql`select * from users where id = ${1}`; */ - (strings: string, ...values: any[]): SQLQuery; + (strings: string | TemplateStringsArray, ...values: any[]): SQLQuery; + /** + * Helper function to allow easy use to insert values into a query + * @example + * const result = await sql`insert into users ${sql(users)} RETURNING *`; + */ + (obj: any): SQLQuery; /** Commits a distributed transaction also know as prepared transaction in postgres or XA transaction in MySQL * @example * await sql.commitDistributed("my_distributed_transaction"); @@ -2107,12 +2120,12 @@ declare module "bun" { * await sql.connect(); */ connect(): Promise; - /** Closes the database connection with optional timeout in seconds + /** Closes the database connection with optional timeout in seconds. If timeout is 0, it will close immediately, if is not provided it will wait for all queries to finish before closing. * @example * await sql.close({ timeout: 1 }); */ close(options?: { timeout?: number }): Promise; - /** Closes the database connection with optional timeout in seconds + /** Closes the database connection with optional timeout in seconds. If timeout is 0, it will close immediately, if is not provided it will wait for all queries to finish before closing. 
* @alias close * @example * await sql.end({ timeout: 1 }); @@ -2166,7 +2179,7 @@ declare module "bun" { * return [user, account] * }) */ - begin(fn: SQLContextCallback): Promise; + begin(fn: SQLTransactionContextCallback): Promise; /** Begins a new transaction with options * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.begin will resolve with the returned value from the callback function. * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. @@ -2191,7 +2204,7 @@ declare module "bun" { * return [user, account] * }) */ - begin(options: string, fn: SQLContextCallback): Promise; + begin(options: string, fn: SQLTransactionContextCallback): Promise; /** Alternative method to begin a transaction * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.transaction will resolve with the returned value from the callback function. * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. @@ -2217,7 +2230,7 @@ declare module "bun" { * return [user, account] * }) */ - transaction(fn: SQLContextCallback): Promise; + transaction(fn: SQLTransactionContextCallback): Promise; /** Alternative method to begin a transaction with options * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.transaction will resolve with the returned value from the callback function. * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. 
@@ -2243,7 +2256,7 @@ declare module "bun" { * return [user, account] * }) */ - transaction(options: string, fn: SQLContextCallback): Promise; + transaction(options: string, fn: SQLTransactionContextCallback): Promise; /** Begins a distributed transaction * Also know as Two-Phase Commit, in a distributed transaction, Phase 1 involves the coordinator preparing nodes by ensuring data is written and ready to commit, while Phase 2 finalizes with nodes committing or rolling back based on the coordinator's decision, ensuring durability and releasing locks. * In PostgreSQL and MySQL distributed transactions persist beyond the original session, allowing privileged users or coordinators to commit/rollback them, ensuring support for distributed transactions, recovery, and administrative tasks. @@ -2259,13 +2272,23 @@ declare module "bun" { * await sql.commitDistributed("numbers"); * // or await sql.rollbackDistributed("numbers"); */ - beginDistributed(name: string, fn: SQLContextCallback): Promise; + beginDistributed(name: string, fn: SQLTransactionContextCallback): Promise; /** Alternative method to begin a distributed transaction * @alias beginDistributed */ - distributed(name: string, fn: SQLContextCallback): Promise; + distributed(name: string, fn: SQLTransactionContextCallback): Promise; + /**If you know what you're doing, you can use unsafe to pass any string you'd like. + * Please note that this can lead to SQL injection if you're not careful. + * You can also nest sql.unsafe within a safe sql expression. This is useful if only part of your fraction has unsafe elements. 
+ * @example + * const result = await sql.unsafe(`select ${danger} from users where id = ${dragons}`) + */ + unsafe(string: string, values?: any[]): SQLQuery; + /** Current client options */ options: SQLOptions; + + [Symbol.asyncDispose](): Promise; } /** @@ -2275,6 +2298,7 @@ declare module "bun" { interface ReservedSQL extends SQL { /** Releases the client back to the connection pool */ release(): void; + [Symbol.dispose](): void; } /** @@ -2283,10 +2307,17 @@ declare module "bun" { */ interface TransactionSQL extends SQL { /** Creates a savepoint within the current transaction */ - savepoint(name: string, fn: SQLContextCallback): Promise; + savepoint(name: string, fn: SQLSavepointContextCallback): Promise; + savepoint(fn: SQLSavepointContextCallback): Promise; } + /** + * Represents a savepoint within a transaction + */ + interface SavepointSQL extends SQL {} var sql: SQL; + var postgres: SQL; + var SQL: SQL; /** * This lets you use macros as regular imports diff --git a/src/bun.js/api/postgres.classes.ts b/src/bun.js/api/postgres.classes.ts index 40664ecb2b..de2ac1cb83 100644 --- a/src/bun.js/api/postgres.classes.ts +++ b/src/bun.js/api/postgres.classes.ts @@ -68,6 +68,10 @@ export default [ fn: "doDone", length: 0, }, + setMode: { + fn: "setMode", + length: 1, + }, }, values: ["pendingValue", "target", "columns", "binding"], estimatedSize: true, diff --git a/src/bun.js/bindings/BunObject.cpp b/src/bun.js/bindings/BunObject.cpp index cfe3b9fe33..787a0ba878 100644 --- a/src/bun.js/bindings/BunObject.cpp +++ b/src/bun.js/bindings/BunObject.cpp @@ -755,7 +755,8 @@ JSC_DEFINE_HOST_FUNCTION(functionFileURLToPath, (JSC::JSGlobalObject * globalObj revision constructBunRevision ReadOnly|DontDelete|PropertyCallback semver BunObject_getter_wrap_semver ReadOnly|DontDelete|PropertyCallback s3 BunObject_callback_s3 DontDelete|Function 1 - sql defaultBunSQLObject DontDelete|PropertyCallback + sql defaultBunSQLObject DontDelete|PropertyCallback + postgres defaultBunSQLObject 
DontDelete|PropertyCallback SQL constructBunSQLObject DontDelete|PropertyCallback serve BunObject_callback_serve DontDelete|Function 1 sha BunObject_callback_sha DontDelete|Function 1 diff --git a/src/bun.js/bindings/SQLClient.cpp b/src/bun.js/bindings/SQLClient.cpp index 2077cb29b5..af6737872c 100644 --- a/src/bun.js/bindings/SQLClient.cpp +++ b/src/bun.js/bindings/SQLClient.cpp @@ -32,6 +32,11 @@ typedef struct DataCellArray { unsigned length; } DataCellArray; +typedef struct DataCellRaw { + void* ptr; + uint64_t length; +} DataCellRaw; + typedef struct TypedArrayDataCell { void* headPtr; void* data; @@ -53,6 +58,7 @@ typedef union DataCellValue { WTF::StringImpl* json; DataCellArray array; TypedArrayDataCell typed_array; + DataCellRaw raw; } DataCellValue; enum class DataCellTag : uint8_t { @@ -68,6 +74,13 @@ enum class DataCellTag : uint8_t { Json = 9, Array = 10, TypedArray = 11, + Raw = 12, +}; + +enum class BunResultMode : uint8_t { + Objects = 0, + Values = 1, + Raw = 2, }; typedef struct DataCell { @@ -102,9 +115,24 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel case DataCellTag::Null: return jsNull(); break; + case DataCellTag::Raw: { + Zig::GlobalObject* zigGlobal = jsCast(globalObject); + auto* subclassStructure = zigGlobal->JSBufferSubclassStructure(); + auto* uint8Array = JSC::JSUint8Array::createUninitialized(globalObject, subclassStructure, cell.value.raw.length); + if (UNLIKELY(uint8Array == nullptr)) { + return {}; + } + + if (cell.value.raw.length > 0) { + memcpy(uint8Array->vector(), reinterpret_cast(cell.value.raw.ptr), cell.value.raw.length); + } + return uint8Array; + } case DataCellTag::String: { - return jsString(vm, WTF::String(cell.value.string)); - break; + if (cell.value.string) { + return jsString(vm, WTF::String(cell.value.string)); + } + return jsEmptyString(vm); } case DataCellTag::Double: return jsDoubleNumber(cell.value.number); @@ -137,10 +165,12 @@ static JSC::JSValue toJS(JSC::VM& vm, 
JSC::JSGlobalObject* globalObject, DataCel return uint8Array; } case DataCellTag::Json: { - auto str = WTF::String(cell.value.string); - JSC::JSValue json = JSC::JSONParse(globalObject, str); - return json; - break; + if (cell.value.json) { + auto str = WTF::String(cell.value.json); + JSC::JSValue json = JSC::JSONParse(globalObject, str); + return json; + } + return jsNull(); } case DataCellTag::Array: { MarkedArgumentBuffer args; @@ -250,79 +280,104 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel } } -static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, unsigned count, JSC::JSGlobalObject* globalObject, Bun::BunStructureFlags flags) +static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, unsigned count, JSC::JSGlobalObject* globalObject, Bun::BunStructureFlags flags, BunResultMode result_mode) { auto& vm = globalObject->vm(); - auto* object = JSC::constructEmptyObject(vm, structure); auto scope = DECLARE_THROW_SCOPE(vm); - // TODO: once we have more tests for this, let's add another branch for - // "only mixed names and mixed indexed columns, no duplicates" - // then we cna remove this sort and instead do two passes. - if (flags.hasIndexedColumns() && flags.hasNamedColumns()) { - // sort the cells by if they're named or indexed, put named first. - // this is to conform to the Structure offsets from earlier. 
- std::sort(cells, cells + count, [](DataCell& a, DataCell& b) { - return a.isNamedColumn() && !b.isNamedColumn(); - }); - } + switch (result_mode) { + case BunResultMode::Objects: // objects - // Fast path: named columns only, no duplicate columns - if (flags.hasNamedColumns() && !flags.hasDuplicateColumns() && !flags.hasIndexedColumns()) { - for (unsigned i = 0; i < count; i++) { - auto& cell = cells[i]; - JSValue value = toJS(vm, globalObject, cell); - RETURN_IF_EXCEPTION(scope, {}); - ASSERT(!cell.isDuplicateColumn()); - ASSERT(!cell.isIndexedColumn()); - ASSERT(cell.isNamedColumn()); - object->putDirectOffset(vm, i, value); + { + auto* object = JSC::constructEmptyObject(vm, structure); + + // TODO: once we have more tests for this, let's add another branch for + // "only mixed names and mixed indexed columns, no duplicates" + // then we cna remove this sort and instead do two passes. + if (flags.hasIndexedColumns() && flags.hasNamedColumns()) { + // sort the cells by if they're named or indexed, put named first. + // this is to conform to the Structure offsets from earlier. 
+ std::sort(cells, cells + count, [](DataCell& a, DataCell& b) { + return a.isNamedColumn() && !b.isNamedColumn(); + }); } - } else if (flags.hasIndexedColumns() && !flags.hasNamedColumns() && !flags.hasDuplicateColumns()) { - for (unsigned i = 0; i < count; i++) { - auto& cell = cells[i]; - JSValue value = toJS(vm, globalObject, cell); - RETURN_IF_EXCEPTION(scope, {}); - ASSERT(!cell.isDuplicateColumn()); - ASSERT(cell.isIndexedColumn()); - ASSERT(!cell.isNamedColumn()); - // cell.index can be > count - // for example: - // select 1 as "8", 2 as "2", 3 as "3" - // -> { "8": 1, "2": 2, "3": 3 } - // 8 > count - object->putDirectIndex(globalObject, cell.index, value); - } - } else { - unsigned structureOffsetIndex = 0; - // slow path: named columns with duplicate columns or indexed columns - for (unsigned i = 0; i < count; i++) { - auto& cell = cells[i]; - if (cell.isIndexedColumn()) { + + // Fast path: named columns only, no duplicate columns + if (flags.hasNamedColumns() && !flags.hasDuplicateColumns() && !flags.hasIndexedColumns()) { + for (unsigned i = 0; i < count; i++) { + auto& cell = cells[i]; JSValue value = toJS(vm, globalObject, cell); RETURN_IF_EXCEPTION(scope, {}); - ASSERT(cell.index < count); - ASSERT(!cell.isNamedColumn()); ASSERT(!cell.isDuplicateColumn()); - object->putDirectIndex(globalObject, cell.index, value); - } else if (cell.isNamedColumn()) { - JSValue value = toJS(vm, globalObject, cell); - RETURN_IF_EXCEPTION(scope, {}); ASSERT(!cell.isIndexedColumn()); + ASSERT(cell.isNamedColumn()); + object->putDirectOffset(vm, i, value); + } + } else if (flags.hasIndexedColumns() && !flags.hasNamedColumns() && !flags.hasDuplicateColumns()) { + for (unsigned i = 0; i < count; i++) { + auto& cell = cells[i]; + JSValue value = toJS(vm, globalObject, cell); + RETURN_IF_EXCEPTION(scope, {}); ASSERT(!cell.isDuplicateColumn()); - ASSERT(cell.index < count); - object->putDirectOffset(vm, structureOffsetIndex++, value); - } else if (cell.isDuplicateColumn()) { 
- // skip it! + ASSERT(cell.isIndexedColumn()); + ASSERT(!cell.isNamedColumn()); + // cell.index can be > count + // for example: + // select 1 as "8", 2 as "2", 3 as "3" + // -> { "8": 1, "2": 2, "3": 3 } + // 8 > count + object->putDirectIndex(globalObject, cell.index, value); + } + } else { + unsigned structureOffsetIndex = 0; + // slow path: named columns with duplicate columns or indexed columns + for (unsigned i = 0; i < count; i++) { + auto& cell = cells[i]; + if (cell.isIndexedColumn()) { + JSValue value = toJS(vm, globalObject, cell); + RETURN_IF_EXCEPTION(scope, {}); + ASSERT(cell.index < count); + ASSERT(!cell.isNamedColumn()); + ASSERT(!cell.isDuplicateColumn()); + object->putDirectIndex(globalObject, cell.index, value); + } else if (cell.isNamedColumn()) { + JSValue value = toJS(vm, globalObject, cell); + RETURN_IF_EXCEPTION(scope, {}); + ASSERT(!cell.isIndexedColumn()); + ASSERT(!cell.isDuplicateColumn()); + ASSERT(cell.index < count); + object->putDirectOffset(vm, structureOffsetIndex++, value); + } else if (cell.isDuplicateColumn()) { + // skip it! 
+ } } } + return object; } - return object; -} + case BunResultMode::Raw: // raw is just array mode with raw values + case BunResultMode::Values: // values + { + auto* array = JSC::constructEmptyArray(globalObject, static_cast(nullptr), count); + RETURN_IF_EXCEPTION(scope, {}); -static JSC::JSValue toJS(JSC::JSArray* array, JSC::Structure* structure, DataCell* cells, unsigned count, JSC::JSGlobalObject* globalObject, Bun::BunStructureFlags flags) + for (unsigned i = 0; i < count; i++) { + auto& cell = cells[i]; + JSValue value = toJS(vm, globalObject, cell); + RETURN_IF_EXCEPTION(scope, {}); + array->putDirectIndex(globalObject, i, value); + } + return array; + } + + default: + // not a valid result mode + ASSERT_NOT_REACHED(); + return jsUndefined(); + } +} +static JSC::JSValue toJS(JSC::JSArray* array, JSC::Structure* structure, DataCell* cells, unsigned count, JSC::JSGlobalObject* globalObject, Bun::BunStructureFlags flags, BunResultMode result_mode) { - JSValue value = toJS(structure, cells, count, globalObject, flags); + JSValue value = toJS(structure, cells, count, globalObject, flags, result_mode); if (value.isEmpty()) return {}; @@ -342,14 +397,13 @@ static JSC::JSValue toJS(JSC::JSArray* array, JSC::Structure* structure, DataCel extern "C" EncodedJSValue JSC__constructObjectFromDataCell( JSC::JSGlobalObject* globalObject, EncodedJSValue encodedArrayValue, - EncodedJSValue encodedStructureValue, DataCell* cells, unsigned count, unsigned flags) + EncodedJSValue encodedStructureValue, DataCell* cells, unsigned count, unsigned flags, uint8_t result_mode) { JSValue arrayValue = JSValue::decode(encodedArrayValue); JSValue structureValue = JSValue::decode(encodedStructureValue); auto* array = arrayValue ? 
jsDynamicCast(arrayValue) : nullptr; auto* structure = jsDynamicCast(structureValue); - - return JSValue::encode(toJS(array, structure, cells, count, globalObject, Bun::BunStructureFlags(flags))); + return JSValue::encode(toJS(array, structure, cells, count, globalObject, Bun::BunStructureFlags(flags), BunResultMode(result_mode))); } typedef struct ExternColumnIdentifier { diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index 1a7a8c7f65..bf5604104d 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -5,6 +5,7 @@ const enum QueryStatus { cancelled = 1 << 2, error = 1 << 3, executed = 1 << 4, + invalidHandle = 1 << 5, } const cmds = ["", "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "MOVE", "FETCH", "COPY"]; @@ -22,6 +23,14 @@ function connectionClosedError() { } hideFromStack(connectionClosedError); +enum SQLQueryResultMode { + objects = 0, + values = 1, + raw = 2, +} +const escapeIdentifier = function escape(str) { + return '"' + str.replaceAll('"', '""').replaceAll(".", '"."') + '"'; +}; class SQLResultArray extends PublicArray { static [Symbol.toStringTag] = "SQLResults"; @@ -29,16 +38,16 @@ class SQLResultArray extends PublicArray { command; count; } - -const rawMode_values = 1; -const rawMode_objects = 2; - const _resolve = Symbol("resolve"); const _reject = Symbol("reject"); const _handle = Symbol("handle"); const _run = Symbol("run"); const _queryStatus = Symbol("status"); const _handler = Symbol("handler"); +const _strings = Symbol("strings"); +const _values = Symbol("values"); +const _poolSize = Symbol("poolSize"); +const _flags = Symbol("flags"); const PublicPromise = Promise; type TransactionCallback = (sql: (strings: string, ...values: any[]) => Query) => Promise; @@ -61,6 +70,7 @@ function normalizeSSLMode(value: string): SSLMode { case "prefer": return SSLMode.prefer; case "require": + case "required": return SSLMode.require; case "verify-ca": case "verify_ca": @@ -76,12 +86,38 @@ function normalizeSSLMode(value: string): SSLMode { throw 
$ERR_INVALID_ARG_VALUE("sslmode", value); } +enum SQLQueryFlags { + none = 0, + allowUnsafeTransaction = 1 << 0, + unsafe = 1 << 1, + bigint = 1 << 2, +} +function getQueryHandle(query) { + let handle = query[_handle]; + if (!handle) { + try { + query[_handle] = handle = doCreateQuery( + query[_strings], + query[_values], + query[_flags] & SQLQueryFlags.allowUnsafeTransaction, + query[_poolSize], + query[_flags] & SQLQueryFlags.bigint, + ); + } catch (err) { + query[_queryStatus] |= QueryStatus.error | QueryStatus.invalidHandle; + query.reject(err); + } + } + return handle; +} class Query extends PublicPromise { [_resolve]; [_reject]; [_handle]; [_handler]; [_queryStatus] = 0; + [_strings]; + [_values]; [Symbol.for("nodejs.util.inspect.custom")]() { const status = this[_queryStatus]; @@ -92,7 +128,7 @@ class Query extends PublicPromise { return `PostgresQuery { ${active ? "active" : ""} ${cancelled ? "cancelled" : ""} ${executed ? "executed" : ""} ${error ? "error" : ""} }`; } - constructor(handle, handler) { + constructor(strings, values, allowUnsafeTransaction, poolSize, handler) { var resolve_, reject_; super((resolve, reject) => { resolve_ = resolve; @@ -100,21 +136,29 @@ class Query extends PublicPromise { }); this[_resolve] = resolve_; this[_reject] = reject_; - this[_handle] = handle; + this[_handle] = null; this[_handler] = handler; - this[_queryStatus] = handle ? 
0 : QueryStatus.cancelled; + this[_queryStatus] = 0; + this[_poolSize] = poolSize; + this[_strings] = strings; + this[_values] = values; + this[_flags] = allowUnsafeTransaction; } async [_run]() { - const { [_handle]: handle, [_handler]: handler, [_queryStatus]: status } = this; + const { [_handler]: handler, [_queryStatus]: status } = this; - if (status & (QueryStatus.executed | QueryStatus.error | QueryStatus.cancelled)) { + if (status & (QueryStatus.executed | QueryStatus.error | QueryStatus.cancelled | QueryStatus.invalidHandle)) { return; } + const handle = getQueryHandle(this); + if (!handle) return this; + this[_queryStatus] |= QueryStatus.executed; // this avoids a infinite loop await 1; + return handler(this, handle); } @@ -141,14 +185,20 @@ class Query extends PublicPromise { resolve(x) { this[_queryStatus] &= ~QueryStatus.active; - this[_handle].done(); + const handle = getQueryHandle(this); + if (!handle) return this; + handle.done(); return this[_resolve](x); } reject(x) { this[_queryStatus] &= ~QueryStatus.active; this[_queryStatus] |= QueryStatus.error; - this[_handle].done(); + if (!(this[_queryStatus] & QueryStatus.invalidHandle)) { + const handle = getQueryHandle(this); + if (!handle) return this[_reject](x); + handle.done(); + } return this[_reject](x); } @@ -161,7 +211,8 @@ class Query extends PublicPromise { this[_queryStatus] |= QueryStatus.cancelled; if (status & QueryStatus.executed) { - this[_handle].cancel(); + const handle = getQueryHandle(this); + handle.cancel(); } return this; @@ -173,12 +224,16 @@ class Query extends PublicPromise { } raw() { - this[_handle].raw = rawMode_objects; + const handle = getQueryHandle(this); + if (!handle) return this; + handle.setMode(SQLQueryResultMode.raw); return this; } values() { - this[_handle].raw = rawMode_values; + const handle = getQueryHandle(this); + if (!handle) return this; + handle.setMode(SQLQueryResultMode.values); return this; } @@ -259,6 +314,7 @@ enum PooledConnectionFlags { /// 
preReserved is used to indicate that the connection will be reserved in the future when queryCount drops to 0 preReserved = 1 << 2, } + class PooledConnection { pool: ConnectionPool; connection: ReturnType; @@ -267,7 +323,6 @@ class PooledConnection { queries: Set<(err: Error) => void> = new Set(); onFinish: ((err: Error | null) => void) | null = null; connectionInfo: any; - flags: number = 0; /// queryCount is used to indicate the number of queries using the connection, if a connection is reserved or if its a transaction queryCount will be 1 independently of the number of queries queryCount: number = 0; @@ -468,42 +523,34 @@ class ConnectionPool { this.onAllQueriesFinished(); } } + if (connection.state !== PooledConnectionState.connected) { // connection is not ready + if (connection.storedError) { + // this connection got a error but maybe we can wait for another + + if (this.hasConnectionsAvailable()) { + return; + } + + const waitingQueue = this.waitingQueue; + const reservedQueue = this.reservedQueue; + + this.waitingQueue = []; + this.reservedQueue = []; + // we have no connections available so lets fails + for (const pending of waitingQueue) { + pending(connection.storedError, connection); + } + for (const pending of reservedQueue) { + pending(connection.storedError, connection); + } + } return; } + if (was_reserved) { - if (this.waitingQueue.length > 0) { - if (connection.storedError) { - // this connection got a error but maybe we can wait for another - - if (this.hasConnectionsAvailable()) { - return; - } - - // we have no connections available so lets fails - let pending; - while ((pending = this.waitingQueue.shift())) { - pending.onConnected(connection.storedError, connection); - } - return; - } - const pendingReserved = this.reservedQueue.shift(); - if (pendingReserved) { - connection.flags |= PooledConnectionFlags.reserved; - connection.queryCount++; - // we have a connection waiting for a reserved connection lets prioritize it - 
pendingReserved(connection.storedError, connection); - return; - } - this.flushConcurrentQueries(); - } else { - // connection is ready, lets add it back to the ready connections - this.readyConnections.add(connection); - } - } else { - if (connection.queryCount == 0) { - // ok we can actually bind reserved queries to it + if (this.waitingQueue.length > 0 || this.reservedQueue.length > 0) { const pendingReserved = this.reservedQueue.shift(); if (pendingReserved) { connection.flags |= PooledConnectionFlags.reserved; @@ -515,9 +562,24 @@ class ConnectionPool { } this.readyConnections.add(connection); - this.flushConcurrentQueries(); + return; } + if (connection.queryCount === 0) { + // ok we can actually bind reserved queries to it + const pendingReserved = this.reservedQueue.shift(); + if (pendingReserved) { + connection.flags |= PooledConnectionFlags.reserved; + connection.queryCount++; + // we have a connection waiting for a reserved connection lets prioritize it + pendingReserved(connection.storedError, connection); + return; + } + } + + this.readyConnections.add(connection); + + this.flushConcurrentQueries(); } hasConnectionsAvailable() { @@ -600,6 +662,7 @@ class ConnectionPool { const { promise, resolve } = Promise.withResolvers(); connection.onFinish = resolve; promises.push(promise); + connection.connection.close(); } break; case PooledConnectionState.connected: @@ -622,7 +685,7 @@ class ConnectionPool { } async close(options?: { timeout?: number }) { if (this.closed) { - return Promise.reject(connectionClosedError()); + return; } let timeout = options?.timeout; if (timeout) { @@ -631,26 +694,40 @@ class ConnectionPool { throw $ERR_INVALID_ARG_VALUE("options.timeout", timeout, "must be a non-negative integer less than 2^31"); } this.closed = true; - if (timeout > 0 && this.hasPendingQueries()) { - const { promise, resolve } = Promise.withResolvers(); - const timer = setTimeout(() => { - // timeout is reached, lets close and probably fail some queries - 
this.#close().finally(resolve); - }, timeout * 1000); - timer.unref(); // dont block the event loop - this.onAllQueriesFinished = () => { - clearTimeout(timer); - // everything is closed, lets close the pool - this.#close().finally(resolve); - }; - - return promise; + if (timeout === 0 || !this.hasPendingQueries()) { + // close immediately + await this.#close(); + return; } + + const { promise, resolve } = Promise.withResolvers(); + const timer = setTimeout(() => { + // timeout is reached, lets close and probably fail some queries + this.#close().finally(resolve); + }, timeout * 1000); + timer.unref(); // dont block the event loop + this.onAllQueriesFinished = () => { + clearTimeout(timer); + // everything is closed, lets close the pool + this.#close().finally(resolve); + }; + + return promise; } else { this.closed = true; + if (!this.hasPendingQueries()) { + // close immediately + await this.#close(); + return; + } + // gracefully close the pool + const { promise, resolve } = Promise.withResolvers(); + this.onAllQueriesFinished = () => { + // everything is closed, lets close the pool + this.#close().finally(resolve); + }; + return promise; } - - await this.#close(); } /** @@ -721,8 +798,11 @@ class ConnectionPool { this.poolStarted = true; const pollSize = this.connections.length; // pool is always at least 1 connection - this.connections[0] = new PooledConnection(this.connectionInfo, this); - this.connections[0].flags |= PooledConnectionFlags.preReserved; // lets pre reserve the first connection + const firstConnection = new PooledConnection(this.connectionInfo, this); + this.connections[0] = firstConnection; + if (reserved) { + firstConnection.flags |= PooledConnectionFlags.preReserved; // lets pre reserve the first connection + } for (let i = 1; i < pollSize; i++) { this.connections[i] = new PooledConnection(this.connectionInfo, this); } @@ -732,7 +812,7 @@ class ConnectionPool { let connectionWithLeastQueries: PooledConnection | null = null; let leastQueries = 
Infinity; for (const connection of this.readyConnections) { - if (connection.flags & PooledConnectionFlags.reserved || connection.flags & PooledConnectionFlags.preReserved) + if (connection.flags & PooledConnectionFlags.preReserved || connection.flags & PooledConnectionFlags.reserved) continue; const queryCount = connection.queryCount; if (queryCount > 0) { @@ -802,8 +882,10 @@ function createConnection( var hasSQLArrayParameter = false; function normalizeStrings(strings, values) { hasSQLArrayParameter = false; - if ($isJSArray(strings)) { + + if ($isArray(strings)) { const count = strings.length; + if (count === 0) { return ""; } @@ -852,6 +934,93 @@ function normalizeStrings(strings, values) { return strings + ""; } +function hasQuery(value: any) { + return value instanceof Query; +} +function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint) { + let sqlString; + let final_values: Array; + if ($isArray(strings) && values.some(hasQuery)) { + // we need to handle fragments of queries + final_values = []; + const final_strings = []; + let strings_idx = 0; + + for (let i = 0; i < values.length; i++) { + const value = values[i]; + if (value instanceof Query) { + let sub_strings = value[_strings]; + var is_unsafe = value[_flags] & SQLQueryFlags.unsafe; + + if (typeof sub_strings === "string") { + if (!is_unsafe) { + // identifier + sub_strings = escapeIdentifier(sub_strings); + } + //@ts-ignore + final_strings.push(strings[strings_idx] + sub_strings + strings[strings_idx + 1]); + strings_idx += 2; // we merged 2 strings into 1 + // in this case we dont have values to merge + } else { + // complex fragment, we need to merge values + const sub_values = value[_values]; + + if (final_strings.length > 0) { + // complex not the first + const current_idx = final_strings.length - 1; + final_strings[current_idx] = final_strings[current_idx] + sub_strings[0]; + + if (sub_strings.length > 1) { + final_strings.push(...sub_strings.slice(1)); + } + 
final_values.push(...sub_values); + } else { + // complex the first + final_strings.push(strings[strings_idx] + sub_strings[0]); + strings_idx += 1; + final_values.push(...sub_values); + if (sub_strings.length > 1) { + final_strings.push(...sub_strings.slice(1)); + } + } + } + } else { + // for each value we have 2 strings + //@ts-ignore + final_strings.push(strings[strings_idx]); + strings_idx += 1; + if (strings_idx + 1 < strings.length) { + //@ts-ignore + final_strings.push(strings[strings_idx + 1]); + strings_idx += 1; + } + + final_values.push(value); + } + } + + sqlString = normalizeStrings(final_strings, final_values); + } else { + sqlString = normalizeStrings(strings, values); + final_values = values; + } + let columns; + if (hasSQLArrayParameter) { + hasSQLArrayParameter = false; + const v = final_values[0]; + columns = v.columns; + final_values = v.value; + } + if (!allowUnsafeTransaction) { + if (poolSize !== 1) { + const upperCaseSqlString = sqlString.toUpperCase().trim(); + if (upperCaseSqlString.startsWith("BEGIN") || upperCaseSqlString.startsWith("START TRANSACTION")) { + throw $ERR_POSTGRES_UNSAFE_TRANSACTION("Only use sql.begin, sql.reserved or max: 1"); + } + } + } + return createQuery(sqlString, final_values, new SQLResultArray(), columns, !!bigint); +} class SQLArrayParameter { value: any; @@ -901,8 +1070,9 @@ function loadOptions(o) { maxLifetime, onconnect, onclose, - max; - const env = Bun.env; + max, + bigint; + const env = Bun.env || {}; var sslMode: SSLMode = SSLMode.disable; if (o === undefined || (typeof o === "string" && o.length === 0)) { @@ -962,7 +1132,7 @@ function loadOptions(o) { } query = query.trim(); } - + o ||= {}; hostname ||= o.hostname || o.host || env.PGHOST || "localhost"; port ||= Number(o.port || env.PGPORT || 5432); username ||= o.username || o.user || env.PGUSERNAME || env.PGUSER || env.USER || env.USERNAME || "postgres"; @@ -978,6 +1148,7 @@ function loadOptions(o) { connectionTimeout ??= o.connection_timeout; 
maxLifetime ??= o.maxLifetime; maxLifetime ??= o.max_lifetime; + bigint ??= o.bigint; onconnect ??= o.onconnect; onclose ??= o.onclose; @@ -1002,6 +1173,7 @@ function loadOptions(o) { "must be a non-negative integer less than 2^31", ); } + idleTimeout *= 1000; } if (connectionTimeout != null) { @@ -1013,6 +1185,7 @@ function loadOptions(o) { "must be a non-negative integer less than 2^31", ); } + connectionTimeout *= 1000; } if (maxLifetime != null) { @@ -1024,6 +1197,7 @@ function loadOptions(o) { "must be a non-negative integer less than 2^31", ); } + maxLifetime *= 1000; } if (max != null) { @@ -1078,6 +1252,8 @@ function loadOptions(o) { } ret.max = max || 10; + ret.bigint = bigint; + return ret; } @@ -1091,6 +1267,7 @@ function assertValidTransactionName(name: string) { throw Error(`Distributed transaction name cannot contain single quotes.`); } } + function SQL(o, e = {}) { if (typeof o === "string" || o instanceof URL) { o = { ...e, url: o }; @@ -1098,26 +1275,6 @@ function SQL(o, e = {}) { var connectionInfo = loadOptions(o); var pool = new ConnectionPool(connectionInfo); - function doCreateQuery(strings, values, allowUnsafeTransaction) { - const sqlString = normalizeStrings(strings, values); - let columns; - if (hasSQLArrayParameter) { - hasSQLArrayParameter = false; - const v = values[0]; - columns = v.columns; - values = v.value; - } - if (!allowUnsafeTransaction) { - if (connectionInfo.max !== 1) { - const upperCaseSqlString = sqlString.toUpperCase().trim(); - if (upperCaseSqlString.startsWith("BEGIN") || upperCaseSqlString.startsWith("START TRANSACTION")) { - throw $ERR_POSTGRES_UNSAFE_TRANSACTION("Only use sql.begin, sql.reserved or max: 1"); - } - } - } - return createQuery(sqlString, values, new SQLResultArray(), columns); - } - function onQueryDisconnected(err) { // connection closed mid query this will not be called if the query finishes first const query = this; @@ -1159,7 +1316,27 @@ function SQL(o, e = {}) { } function queryFromPool(strings, 
values) { try { - return new Query(doCreateQuery(strings, values, false), queryFromPoolHandler); + return new Query( + strings, + values, + connectionInfo.bigint ? SQLQueryFlags.bigint : SQLQueryFlags.none, + connectionInfo.max, + queryFromPoolHandler, + ); + } catch (err) { + return Promise.reject(err); + } + } + + function unsafeQuery(strings, values) { + try { + return new Query( + strings, + values, + connectionInfo.bigint ? SQLQueryFlags.bigint | SQLQueryFlags.unsafe : SQLQueryFlags.unsafe, + connectionInfo.max, + queryFromPoolHandler, + ); } catch (err) { return Promise.reject(err); } @@ -1187,7 +1364,12 @@ function SQL(o, e = {}) { function queryFromTransaction(strings, values, pooledConnection, transactionQueries) { try { const query = new Query( - doCreateQuery(strings, values, true), + strings, + values, + connectionInfo.bigint + ? SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.bigint + : SQLQueryFlags.allowUnsafeTransaction, + connectionInfo.max, queryFromTransactionHandler.bind(pooledConnection, transactionQueries), ); transactionQueries.add(query); @@ -1196,6 +1378,24 @@ function SQL(o, e = {}) { return Promise.reject(err); } } + function unsafeQueryFromTransaction(strings, values, pooledConnection, transactionQueries) { + try { + const query = new Query( + strings, + values, + connectionInfo.bigint + ? 
SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe | SQLQueryFlags.bigint + : SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe, + connectionInfo.max, + queryFromTransactionHandler.bind(pooledConnection, transactionQueries), + ); + transactionQueries.add(query); + return query; + } catch (err) { + return Promise.reject(err); + } + } + function onTransactionDisconnected(err) { const reject = this.reject; this.connectionState |= ReservedConnectionState.closed; @@ -1232,12 +1432,23 @@ function SQL(o, e = {}) { ) { return Promise.reject(connectionClosedError()); } - if ($isJSArray(strings) && strings[0] && typeof strings[0] === "object") { - return new SQLArrayParameter(strings, values); + if ($isArray(strings)) { + if (strings[0] && typeof strings[0] === "object") { + return new SQLArrayParameter(strings, values); + } + } else if ( + typeof strings === "object" && + !(strings instanceof Query) && + !(strings instanceof SQLArrayParameter) + ) { + return new SQLArrayParameter([strings], values); } // we use the same code path as the transaction sql return queryFromTransaction(strings, values, pooledConnection, state.queries); } + reserved_sql.unsafe = (string, args = []) => { + return unsafeQueryFromTransaction(string, args, pooledConnection, state.queries); + }; reserved_sql.connect = () => { if (state.connectionState & ReservedConnectionState.closed) { return Promise.reject(connectionClosedError()); @@ -1344,7 +1555,7 @@ function SQL(o, e = {}) { state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.reject(connectionClosedError()); + return Promise.resolve(undefined); } state.connectionState &= ~ReservedConnectionState.acceptQueries; let timeout = options?.timeout; @@ -1454,8 +1665,8 @@ function SQL(o, e = {}) { let transactionSavepoints = new Set(); const adapter = connectionInfo.adapter; let BEGIN_COMMAND: string = "BEGIN"; - let ROLLBACK_COMMAND: string = 
"COMMIT"; - let COMMIT_COMMAND: string = "ROLLBACK"; + let ROLLBACK_COMMAND: string = "ROLLBACK"; + let COMMIT_COMMAND: string = "COMMIT"; let SAVEPOINT_COMMAND: string = "SAVEPOINT"; let RELEASE_SAVEPOINT_COMMAND: string | null = "RELEASE SAVEPOINT"; let ROLLBACK_TO_SAVEPOINT_COMMAND: string = "ROLLBACK TO SAVEPOINT"; @@ -1546,12 +1757,23 @@ function SQL(o, e = {}) { ) { return Promise.reject(connectionClosedError()); } - if ($isJSArray(strings) && strings[0] && typeof strings[0] === "object") { - return new SQLArrayParameter(strings, values); + if ($isArray(strings)) { + if (strings[0] && typeof strings[0] === "object") { + return new SQLArrayParameter(strings, values); + } + } else if ( + typeof strings === "object" && + !(strings instanceof Query) && + !(strings instanceof SQLArrayParameter) + ) { + return new SQLArrayParameter([strings], values); } return queryFromTransaction(strings, values, pooledConnection, state.queries); } + transaction_sql.unsafe = (string, args = []) => { + return unsafeQueryFromTransaction(string, args, pooledConnection, state.queries); + }; // reserve is allowed to be called inside transaction connection but will return a new reserved connection from the pool and will not be part of the transaction // this matchs the behavior of the postgres package transaction_sql.reserve = () => sql.reserve(); @@ -1622,7 +1844,7 @@ function SQL(o, e = {}) { state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.reject(connectionClosedError()); + return Promise.resolve(undefined); } state.connectionState &= ~ReservedConnectionState.acceptQueries; const transactionQueries = state.queries; @@ -1774,13 +1996,21 @@ function SQL(o, e = {}) { * ] * sql`insert into users ${sql(users)}` */ - if ($isJSArray(strings) && strings[0] && typeof strings[0] === "object") { - return new SQLArrayParameter(strings, values); + if ($isArray(strings)) { + if (strings[0] && typeof 
strings[0] === "object") { + return new SQLArrayParameter(strings, values); + } + } else if (typeof strings === "object" && !(strings instanceof Query) && !(strings instanceof SQLArrayParameter)) { + return new SQLArrayParameter([strings], values); } return queryFromPool(strings, values); } + sql.unsafe = (string, args = []) => { + return unsafeQuery(string, args); + }; + sql.reserve = () => { if (pool.closed) { return Promise.reject(connectionClosedError()); @@ -1911,8 +2141,15 @@ var lazyDefaultSQL; function resetDefaultSQL(sql) { lazyDefaultSQL = sql; - Object.assign(defaultSQLObject, lazyDefaultSQL); - exportsObject.default = exportsObject.sql = lazyDefaultSQL; + // this will throw "attempt to assign to readonly property" + // Object.assign(defaultSQLObject, lazyDefaultSQL); + // exportsObject.default = exportsObject.sql = lazyDefaultSQL; +} + +function ensureDefaultSQL() { + if (!lazyDefaultSQL) { + resetDefaultSQL(SQL(undefined)); + } } var initialDefaultSQL; @@ -1926,6 +2163,62 @@ var defaultSQLObject = (initialDefaultSQL = function sql(strings, ...values) { return lazyDefaultSQL(strings, ...values); }); +defaultSQLObject.reserve = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.reserve(...args); +}; +defaultSQLObject.commitDistributed = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.commitDistributed(...args); +}; +defaultSQLObject.rollbackDistributed = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.rollbackDistributed(...args); +}; +defaultSQLObject.distributed = defaultSQLObject.beginDistributed = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.beginDistributed(...args); +}; + +defaultSQLObject.connect = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.connect(...args); +}; + +defaultSQLObject.unsafe = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.unsafe(...args); +}; + +defaultSQLObject.transaction = defaultSQLObject.begin = (...args) => { + ensureDefaultSQL(); + return 
lazyDefaultSQL.begin(...args); +}; + +defaultSQLObject.end = defaultSQLObject.close = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.close(...args); +}; +defaultSQLObject.flush = (...args) => { + ensureDefaultSQL(); + return lazyDefaultSQL.flush(...args); +}; +//define lazy properties +Object.defineProperties(defaultSQLObject, { + options: { + get: () => { + ensureDefaultSQL(); + return lazyDefaultSQL.options; + }, + }, + [Symbol.asyncDispose]: { + get: () => { + ensureDefaultSQL(); + return lazyDefaultSQL[Symbol.asyncDispose]; + }, + }, +}); + var exportsObject = { sql: defaultSQLObject, default: defaultSQLObject, diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 445ad4e654..874575f8ae 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -58,6 +58,8 @@ pub const Data = union(enum) { temporary: []const u8, empty: void, + pub const Empty: Data = .{ .empty = {} }; + pub fn toOwned(this: @This()) !bun.ByteList { return switch (this) { .owned => this.owned, @@ -202,7 +204,11 @@ pub const PostgresSQLContext = struct { } } }; - +pub const PostgresSQLQueryResultMode = enum(u8) { + objects = 0, + values = 1, + raw = 2, +}; pub const PostgresSQLQuery = struct { statement: ?*PostgresSQLStatement = null, query: bun.String = bun.String.empty, @@ -212,9 +218,15 @@ pub const PostgresSQLQuery = struct { thisValue: JSValue = .undefined, status: Status = Status.pending, - is_done: bool = false, + ref_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(1), - binary: bool = false, + + flags: packed struct { + is_done: bool = false, + binary: bool = false, + bigint: bool = false, + result_mode: PostgresSQLQueryResultMode = .objects, + } = .{}, pub usingnamespace JSC.Codegen.JSPostgresSQLQuery; const log = bun.Output.scoped(.PostgresSQLQuery, false); @@ -474,7 +486,7 @@ pub const PostgresSQLQuery = struct { consumePendingValue(thisValue, globalObject) orelse .undefined, tag.toJSTag(globalObject), tag.toJSNumber(), - 
PostgresSQLConnection.queriesGetCached(connection) orelse .undefined, + if (connection == .zero) .undefined else PostgresSQLConnection.queriesGetCached(connection) orelse .undefined, }); } @@ -489,7 +501,7 @@ pub const PostgresSQLQuery = struct { } pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { - const arguments = callframe.arguments_old(4).slice(); + const arguments = callframe.arguments_old(5).slice(); var args = JSC.Node.ArgumentsSlice.init(globalThis.bunVM(), arguments); defer args.deinit(); const query = args.nextEat() orelse { @@ -509,6 +521,8 @@ pub const PostgresSQLQuery = struct { const pending_value = args.nextEat() orelse .undefined; const columns = args.nextEat() orelse .undefined; + const js_bigint = args.nextEat() orelse .false; + const bigint = js_bigint.isBoolean() and js_bigint.asBoolean(); if (!pending_value.jsType().isArrayLike()) { return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); @@ -522,6 +536,9 @@ pub const PostgresSQLQuery = struct { ptr.* = .{ .query = query.toBunString(globalThis), .thisValue = this_value, + .flags = .{ + .bigint = bigint, + }, }; ptr.query.ref(); @@ -541,16 +558,28 @@ pub const PostgresSQLQuery = struct { pub fn doDone(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { _ = globalObject; - this.is_done = true; + this.flags.is_done = true; return .undefined; } + pub fn setMode(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + const js_mode = callframe.argument(0); + if (js_mode.isEmptyOrUndefinedOrNull() or !js_mode.isNumber()) { + return globalObject.throwInvalidArgumentType("setMode", "mode", "Number"); + } + const mode = js_mode.coerce(i32, globalObject); + this.flags.result_mode = std.meta.intToEnum(PostgresSQLQueryResultMode, mode) catch { + return globalObject.throwInvalidArgumentTypeValue("mode", "Number", js_mode); + }; + return 
.undefined; + } pub fn doRun(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { var arguments_ = callframe.arguments_old(2); const arguments = arguments_.slice(); const connection: *PostgresSQLConnection = arguments[0].as(PostgresSQLConnection) orelse { return globalObject.throw("connection must be a PostgresSQLConnection", .{}); }; + connection.poll_ref.ref(globalObject.bunVM()); var query = arguments[1]; if (!query.isObject()) { @@ -563,14 +592,13 @@ pub const PostgresSQLQuery = struct { defer query_str.deinit(); const columns_value = PostgresSQLQuery.columnsGetCached(this_value) orelse .undefined; - var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value) catch |err| { + var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value, connection.prepared_statement_id) catch |err| { if (!globalObject.hasException()) return globalObject.throwError(err, "failed to generate signature"); return error.JSError; }; var writer = connection.writer(); - const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { signature.deinit(); return globalObject.throwError(err, "failed to allocate statement"); @@ -589,7 +617,7 @@ pub const PostgresSQLQuery = struct { // if it has params, we need to wait for ParamDescription to be received before we can write the data } else { - this.binary = this.statement.?.fields.len > 0; + this.flags.binary = this.statement.?.fields.len > 0; log("bindAndExecute", .{}); PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { if (!globalObject.hasException()) @@ -604,6 +632,8 @@ pub const PostgresSQLQuery = struct { // If it does not have params, we can write and execute immediately in one go if (!has_params) { + log("prepareAndQueryWithSignature", .{}); + 
PostgresRequest.prepareAndQueryWithSignature(globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, writer, &signature) catch |err| { signature.deinit(); if (!globalObject.hasException()) @@ -612,7 +642,9 @@ pub const PostgresSQLQuery = struct { }; did_write = true; } else { - PostgresRequest.writeQuery(query_str.slice(), signature.name, signature.fields, PostgresSQLConnection.Writer, writer) catch |err| { + log("writeQuery", .{}); + + PostgresRequest.writeQuery(query_str.slice(), signature.prepared_statement_name, signature.fields, PostgresSQLConnection.Writer, writer) catch |err| { signature.deinit(); if (!globalObject.hasException()) return globalObject.throwError(err, "failed to write query"); @@ -630,7 +662,7 @@ pub const PostgresSQLQuery = struct { const stmt = bun.default_allocator.create(PostgresSQLStatement) catch |err| { return globalObject.throwError(err, "failed to allocate statement"); }; - + connection.prepared_statement_id += 1; stmt.* = .{ .signature = signature, .ref_count = 2, .status = PostgresSQLStatement.Status.parsing }; this.statement = stmt; entry.value_ptr.* = stmt; @@ -868,11 +900,11 @@ pub const PostgresRequest = struct { writer: protocol.NewWriter(Context), signature: *Signature, ) AnyPostgresError!void { - try writeQuery(query, signature.name, signature.fields, Context, writer); - try writeBind(signature.name, bun.String.empty, globalObject, array_value, .zero, &.{}, &.{}, Context, writer); + try writeQuery(query, signature.prepared_statement_name, signature.fields, Context, writer); + try writeBind(signature.prepared_statement_name, bun.String.empty, globalObject, array_value, .zero, &.{}, &.{}, Context, writer); var exec = protocol.Execute{ .p = .{ - .prepared_statement = signature.name, + .prepared_statement = signature.prepared_statement_name, }, }; try exec.writeInternal(Context, writer); @@ -889,10 +921,10 @@ pub const PostgresRequest = struct { comptime Context: type, writer: protocol.NewWriter(Context), ) 
!void { - try writeBind(statement.signature.name, bun.String.empty, globalObject, array_value, columns_value, statement.parameters, statement.fields, Context, writer); + try writeBind(statement.signature.prepared_statement_name, bun.String.empty, globalObject, array_value, columns_value, statement.parameters, statement.fields, Context, writer); var exec = protocol.Execute{ .p = .{ - .prepared_statement = statement.signature.name, + .prepared_statement = statement.signature.prepared_statement_name, }, }; try exec.writeInternal(Context, writer); @@ -979,6 +1011,7 @@ pub const PostgresSQLConnection = struct { globalObject: *JSC.JSGlobalObject, statements: PreparedStatementsMap, + prepared_statement_id: u64 = 0, pending_activity_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(0), js_value: JSValue = JSValue.undefined, @@ -1343,7 +1376,7 @@ pub const PostgresSQLConnection = struct { this.ref(); defer this.deref(); - if (!this.socket.isClosed()) this.socket.close(); + this.refAndClose(); const on_close = this.consumeOnCloseCallback(this.globalObject) orelse return; const loop = this.globalObject.bunVM().eventLoop(); @@ -1415,6 +1448,8 @@ pub const PostgresSQLConnection = struct { const loop = vm.eventLoop(); loop.enter(); defer loop.exit(); + this.poll_ref.unref(this.globalObject.bunVM()); + this.fail("Connection closed", error.ConnectionClosed); } @@ -1428,7 +1463,7 @@ pub const PostgresSQLConnection = struct { .options = Data{ .temporary = this.options }, }; msg.writeInternal(Writer, this.writer()) catch |err| { - this.socket.close(); + this.refAndClose(); this.fail("Failed to write startup message", err); }; } @@ -1906,13 +1941,21 @@ pub const PostgresSQLConnection = struct { bun.default_allocator.destroy(this); } + fn refAndClose(this: *@This()) void { + if (!this.socket.isClosed()) { + // event loop need to be alive to close the socket + this.poll_ref.ref(this.globalObject.bunVM()); + // will unref on socket close + this.socket.close(); + } + } + pub fn 
disconnect(this: *@This()) void { this.stopTimers(); if (this.status == .connected) { this.status = .disconnected; - this.poll_ref.disable(); - this.socket.close(); + this.refAndClose(); } } @@ -2019,11 +2062,12 @@ pub const PostgresSQLConnection = struct { json = 9, array = 10, typed_array = 11, + raw = 12, }; pub const Value = extern union { null: u8, - string: bun.WTF.StringImpl, + string: ?bun.WTF.StringImpl, float8: f64, int4: i32, int8: i64, @@ -2031,9 +2075,10 @@ pub const PostgresSQLConnection = struct { date: f64, date_with_time_zone: f64, bytea: [2]usize, - json: bun.WTF.StringImpl, + json: ?bun.WTF.StringImpl, array: Array, typed_array: TypedArray, + raw: Raw, }; pub const Array = extern struct { @@ -2045,6 +2090,10 @@ pub const PostgresSQLConnection = struct { return ptr[0..this.len]; } }; + pub const Raw = extern struct { + ptr: ?[*]const u8 = null, + len: u64, + }; pub const TypedArray = extern struct { head_ptr: ?[*]u8 = null, ptr: ?[*]u8 = null, @@ -2068,10 +2117,14 @@ pub const PostgresSQLConnection = struct { switch (this.tag) { .string => { - this.value.string.deref(); + if (this.value.string) |str| { + str.deref(); + } }, .json => { - this.value.json.deref(); + if (this.value.json) |str| { + str.deref(); + } }, .bytea => { if (this.value.bytea[1] == 0) return; @@ -2091,8 +2144,21 @@ pub const PostgresSQLConnection = struct { else => {}, } } - - pub fn fromBytes(binary: bool, oid: int4, bytes: []const u8, globalObject: *JSC.JSGlobalObject) !DataCell { + pub fn raw(optional_bytes: ?*Data) DataCell { + if (optional_bytes) |bytes| { + const bytes_slice = bytes.slice(); + return DataCell{ + .tag = .raw, + .value = .{ .raw = .{ .ptr = @ptrCast(bytes_slice.ptr), .len = bytes_slice.len } }, + }; + } + // TODO: check empty and null fields + return DataCell{ + .tag = .null, + .value = .{ .null = 0 }, + }; + } + pub fn fromBytes(binary: bool, bigint: bool, oid: int4, bytes: []const u8, globalObject: *JSC.JSGlobalObject) !DataCell { switch (@as(types.Tag, 
@enumFromInt(@as(short, @intCast(oid))))) { // TODO: .int2_array, .float8_array inline .int4_array, .float4_array => |tag| { @@ -2121,13 +2187,13 @@ pub const PostgresSQLConnection = struct { .ptr = null, .len = 0, .byte_len = 0, - .type = tag.toJSTypedArrayType(), + .type = try tag.toJSTypedArrayType(), }, }, }; } - const elements = tag.pgArrayType().init(bytes).slice(); + const elements = (try tag.pgArrayType()).init(bytes).slice(); return DataCell{ .tag = .typed_array, @@ -2137,13 +2203,13 @@ pub const PostgresSQLConnection = struct { .ptr = if (elements.len > 0) @ptrCast(elements.ptr) else null, .len = @truncate(elements.len), .byte_len = @truncate(bytes.len), - .type = tag.toJSTypedArrayType(), + .type = try tag.toJSTypedArrayType(), }, }, }; } else { // TODO: - return fromBytes(false, @intFromEnum(types.Tag.bytea), bytes, globalObject); + return fromBytes(false, bigint, @intFromEnum(types.Tag.bytea), bytes, globalObject); } }, .int4 => { @@ -2153,6 +2219,15 @@ pub const PostgresSQLConnection = struct { return DataCell{ .tag = .int4, .value = .{ .int4 = bun.fmt.parseInt(i32, bytes, 0) catch 0 } }; } }, + // postgres when reading bigint as int8 it returns a string unless type: { bigint: postgres.BigInt is set + .int8 => { + if (bigint) { + // .int8 is a 64-bit integer always string + return DataCell{ .tag = .int8, .value = .{ .int8 = bun.fmt.parseInt(i64, bytes, 0) catch 0 } }; + } else { + return DataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) bun.String.createUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; + } + }, .float8 => { if (binary and bytes.len == 8) { return DataCell{ .tag = .float8, .value = .{ .float8 = try parseBinary(.float8, f64, bytes) } }; @@ -2170,7 +2245,7 @@ pub const PostgresSQLConnection = struct { } }, .jsonb, .json => { - return DataCell{ .tag = .json, .value = .{ .json = String.createUTF8(bytes).value.WTFStringImpl }, .free_value = 1 }; + return DataCell{ .tag = .json, .value = .{ .json = if 
(bytes.len > 0) String.createUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; }, .bool => { if (binary) { @@ -2218,7 +2293,7 @@ pub const PostgresSQLConnection = struct { } }, else => { - return DataCell{ .tag = .string, .value = .{ .string = bun.String.createUTF8(bytes).value.WTFStringImpl }, .free_value = 1 }; + return DataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) bun.String.createUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; }, } } @@ -2309,6 +2384,7 @@ pub const PostgresSQLConnection = struct { list: []DataCell, fields: []const protocol.FieldDescription, binary: bool = false, + bigint: bool = false, count: usize = 0, globalObject: *JSC.JSGlobalObject, @@ -2319,26 +2395,31 @@ pub const PostgresSQLConnection = struct { [*]DataCell, u32, Flags, + u8, // result_mode ) JSValue; - pub fn toJS(this: *Putter, globalObject: *JSC.JSGlobalObject, array: JSValue, structure: JSValue, flags: Flags) JSValue { - return JSC__constructObjectFromDataCell(globalObject, array, structure, this.list.ptr, @truncate(this.fields.len), flags); + pub fn toJS(this: *Putter, globalObject: *JSC.JSGlobalObject, array: JSValue, structure: JSValue, flags: Flags, result_mode: PostgresSQLQueryResultMode) JSValue { + return JSC__constructObjectFromDataCell(globalObject, array, structure, this.list.ptr, @truncate(this.fields.len), flags, @intFromEnum(result_mode)); } - pub fn put(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { + fn putImpl(this: *Putter, index: u32, optional_bytes: ?*Data, comptime is_raw: bool) !bool { const field = &this.fields[index]; const oid = field.type_oid; debug("index: {d}, oid: {d}", .{ index, oid }); const cell: *DataCell = &this.list[index]; - cell.* = if (optional_bytes) |data| - try DataCell.fromBytes(this.binary, oid, data.slice(), this.globalObject) - else - DataCell{ - .tag = .null, - .value = .{ - .null = 0, - }, - }; + if (is_raw) { + cell.* = DataCell.raw(optional_bytes); + } else { + cell.* = if 
(optional_bytes) |data| + try DataCell.fromBytes(this.binary, this.bigint, oid, data.slice(), this.globalObject) + else + DataCell{ + .tag = .null, + .value = .{ + .null = 0, + }, + }; + } this.count += 1; cell.index = switch (field.name_or_index) { // The indexed columns can be out of order. @@ -2357,6 +2438,13 @@ pub const PostgresSQLConnection = struct { }; return true; } + + pub fn putRaw(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { + return this.putImpl(index, optional_bytes, true); + } + pub fn put(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { + return this.putImpl(index, optional_bytes, false); + } }; }; @@ -2402,7 +2490,7 @@ pub const PostgresSQLConnection = struct { continue; }; req.status = .binding; - req.binary = stmt.fields.len > 0; + req.flags.binary = stmt.fields.len > 0; any = true; } else { break; @@ -2429,12 +2517,24 @@ pub const PostgresSQLConnection = struct { .DataRow => { const request = this.current() orelse return error.ExpectedRequest; var statement = request.statement orelse return error.ExpectedStatement; - statement.checkForDuplicateFields(); + var structure: JSValue = .undefined; + // explict use switch without else so if new modes are added, we don't forget to check for duplicate fields + switch (request.flags.result_mode) { + .objects => { + // check for duplicate fields + statement.checkForDuplicateFields(); + structure = statement.structure(this.js_value, this.globalObject); + }, + .raw, .values => { + // no need to check for duplicate fields or structure + }, + } var putter = DataCell.Putter{ .list = &.{}, .fields = statement.fields, - .binary = request.binary, + .binary = request.flags.binary, + .bigint = request.flags.bigint, .globalObject = this.globalObject, }; @@ -2452,18 +2552,29 @@ pub const PostgresSQLConnection = struct { cells = try bun.default_allocator.alloc(DataCell, statement.fields.len); free_cells = true; } + // make sure all cells are reseted if reader short breaks the fields will just be 
null with is better than undefined behavior + @memset(cells, DataCell{ .tag = .null, .value = .{ .null = 0 } }); putter.list = cells; - try protocol.DataRow.decode( - &putter, - Context, - reader, - DataCell.Putter.put, - ); + if (request.flags.result_mode == .raw) { + try protocol.DataRow.decode( + &putter, + Context, + reader, + DataCell.Putter.putRaw, + ); + } else { + try protocol.DataRow.decode( + &putter, + Context, + reader, + DataCell.Putter.put, + ); + } const pending_value = if (request.thisValue == .zero) .zero else PostgresSQLQuery.pendingValueGetCached(request.thisValue) orelse .zero; pending_value.ensureStillAlive(); - const result = putter.toJS(this.globalObject, pending_value, statement.structure(this.js_value, this.globalObject), statement.fields_flags); + const result = putter.toJS(this.globalObject, pending_value, structure, statement.fields_flags, request.flags.result_mode); if (pending_value == .zero) { PostgresSQLQuery.pendingValueSetCached(request.thisValue, this.globalObject, result); @@ -2837,13 +2948,18 @@ pub const PostgresSQLConnection = struct { } pub fn consumeOnConnectCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { + debug("consumeOnConnectCallback", .{}); const on_connect = PostgresSQLConnection.onconnectGetCached(this.js_value) orelse return null; + debug("consumeOnConnectCallback exists", .{}); + PostgresSQLConnection.onconnectSetCached(this.js_value, globalObject, .zero); return on_connect; } pub fn consumeOnCloseCallback(this: *const PostgresSQLConnection, globalObject: *JSC.JSGlobalObject) ?JSC.JSValue { + debug("consumeOnCloseCallback", .{}); const on_close = PostgresSQLConnection.oncloseGetCached(this.js_value) orelse return null; + debug("consumeOnCloseCallback exists", .{}); PostgresSQLConnection.oncloseSetCached(this.js_value, globalObject, .zero); return on_close; } @@ -3102,8 +3218,11 @@ const Signature = struct { fields: []const int4, name: []const u8, query: []const u8, + 
prepared_statement_name: []const u8, + const log = bun.Output.scoped(.PostgresSignature, false); pub fn deinit(this: *Signature) void { + bun.default_allocator.free(this.prepared_statement_name); bun.default_allocator.free(this.fields); bun.default_allocator.free(this.name); bun.default_allocator.free(this.query); @@ -3116,7 +3235,7 @@ const Signature = struct { return hasher.final(); } - pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue) !Signature { + pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue, prepared_statement_id: u64) !Signature { var fields = std.ArrayList(int4).init(bun.default_allocator); var name = try std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); @@ -3170,8 +3289,11 @@ const Signature = struct { if (iter.anyFailed()) { return error.InvalidQueryBinding; } + // max u64 length is 20, max prepared_statement_name length is 63 + const prepared_statement_name = try bun.fmt.allocPrint(bun.default_allocator, "P{s}${d}", .{ name.items[0..@min(40, name.items.len)], prepared_statement_id }); return Signature{ + .prepared_statement_name = prepared_statement_name, .name = name.items, .fields = fields.items, .query = try bun.default_allocator.dupe(u8, query), diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig index 2abeff0787..6d36812f30 100644 --- a/src/sql/postgres/postgres_protocol.zig +++ b/src/sql/postgres/postgres_protocol.zig @@ -799,16 +799,16 @@ pub const ErrorResponse = struct { .{ "column", column, void }, .{ "constraint", constraint, void }, .{ "datatype", datatype, void }, - .{ "errno", code, i32 }, + // in the past this was set to i32 but postgres returns a strings lets keep it compatible + .{ "errno", code, void }, .{ "position", position, i32 }, .{ "schema", schema, void }, .{ "table", table, void }, .{ "where", where, void }, }; - const error_code: JSC.Error = // 
https://www.postgresql.org/docs/8.1/errcodes-appendix.html - if (code.toInt32() orelse 0 == 42601) + if (code.eqlComptime("42601")) JSC.Error.ERR_POSTGRES_SYNTAX_ERROR else JSC.Error.ERR_POSTGRES_SERVER_ERROR; @@ -948,7 +948,10 @@ pub const DataRow = struct { for (0..remaining_fields) |index| { const byte_length = try reader.int4(); switch (byte_length) { - 0 => break, + 0 => { + var empty = Data.Empty; + if (!try forEach(context, @intCast(index), &empty)) break; + }, null_int4 => { if (!try forEach(context, @intCast(index), null)) break; }, diff --git a/src/sql/postgres/postgres_types.zig b/src/sql/postgres/postgres_types.zig index 74e8ff104b..a6fa23ff73 100644 --- a/src/sql/postgres/postgres_types.zig +++ b/src/sql/postgres/postgres_types.zig @@ -253,38 +253,28 @@ pub const Tag = enum(short) { }; } - pub fn toJSTypedArrayType(comptime T: Tag) JSValue.JSType { + pub fn toJSTypedArrayType(comptime T: Tag) !JSValue.JSType { return comptime switch (T) { .int4_array => .Int32Array, // .int2_array => .Uint2Array, .float4_array => .Float32Array, // .float8_array => .Float64Array, - else => @compileError("TODO: not implemented"), + else => error.UnsupportedArrayType, }; } - pub fn byteArrayType(comptime T: Tag) type { + pub fn byteArrayType(comptime T: Tag) !type { return comptime switch (T) { .int4_array => i32, // .int2_array => i16, .float4_array => f32, // .float8_array => f64, - else => @compileError("TODO: not implemented"), + else => error.UnsupportedArrayType, }; } - pub fn unsignedByteArrayType(comptime T: Tag) type { - return comptime switch (T) { - .int4_array => u32, - // .int2_array => u16, - .float4_array => f32, - // .float8_array => f64, - else => @compileError("TODO: not implemented"), - }; - } - - pub fn pgArrayType(comptime T: Tag) type { - return PostgresBinarySingleDimensionArray(byteArrayType(T)); + pub fn pgArrayType(comptime T: Tag) !type { + return PostgresBinarySingleDimensionArray(try byteArrayType(T)); } fn toJSWithType( @@ -397,7 +387,7 @@ 
pub const Tag = enum(short) { if (value.isAnyInt()) { const int = value.toInt64(); - if (int >= std.math.minInt(u32) and int <= std.math.maxInt(u32)) { + if (int >= std.math.minInt(i32) and int <= std.math.maxInt(i32)) { return .int4; } diff --git a/test/js/sql/docker/Dockerfile b/test/js/sql/docker/Dockerfile new file mode 100644 index 0000000000..923a232e9f --- /dev/null +++ b/test/js/sql/docker/Dockerfile @@ -0,0 +1,69 @@ +# Dockerfile +FROM postgres:15 + +# Create initialization script +RUN echo '#!/bin/bash\n\ +set -e\n\ +\n\ +# Wait for PostgreSQL to start\n\ +until pg_isready; do\n\ + echo "Waiting for PostgreSQL to start..."\n\ + sleep 1\n\ +done\n\ +\n\ +dropdb --if-exists bun_sql_test\n\ +\n\ +# Drop and recreate users with different auth methods\n\ +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL\n\ + DROP USER IF EXISTS bun_sql_test;\n\ + CREATE USER bun_sql_test;\n\ + \n\ + ALTER SYSTEM SET password_encryption = '"'"'md5'"'"';\n\ + SELECT pg_reload_conf();\n\ + DROP USER IF EXISTS bun_sql_test_md5;\n\ + CREATE USER bun_sql_test_md5 WITH PASSWORD '"'"'bun_sql_test_md5'"'"';\n\ + \n\ + ALTER SYSTEM SET password_encryption = '"'"'scram-sha-256'"'"';\n\ + SELECT pg_reload_conf();\n\ + DROP USER IF EXISTS bun_sql_test_scram;\n\ + CREATE USER bun_sql_test_scram WITH PASSWORD '"'"'bun_sql_test_scram'"'"';\n\ +EOSQL\n\ +\n\ +# Create database and set permissions\n\ +createdb bun_sql_test\n\ +\n\ +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL\n\ + GRANT ALL ON DATABASE bun_sql_test TO bun_sql_test;\n\ + ALTER DATABASE bun_sql_test OWNER TO bun_sql_test;\n\ +EOSQL\n\ +' > /docker-entrypoint-initdb.d/init-users-db.sh + +# Make the script executable +RUN chmod +x /docker-entrypoint-initdb.d/init-users-db.sh + +# Create pg_hba.conf +RUN mkdir -p /etc/postgresql && touch /etc/postgresql/pg_hba.conf && \ + echo "local all postgres trust" >> /etc/postgresql/pg_hba.conf && \ + echo "local all 
bun_sql_test trust" >> /etc/postgresql/pg_hba.conf && \ + echo "local all bun_sql_test_md5 md5" >> /etc/postgresql/pg_hba.conf && \ + echo "local all bun_sql_test_scram scram-sha-256" >> /etc/postgresql/pg_hba.conf && \ + echo "host all postgres 127.0.0.1/32 trust" >> /etc/postgresql/pg_hba.conf && \ + echo "host all bun_sql_test 127.0.0.1/32 trust" >> /etc/postgresql/pg_hba.conf && \ + echo "host all bun_sql_test_md5 127.0.0.1/32 md5" >> /etc/postgresql/pg_hba.conf && \ + echo "host all bun_sql_test_scram 127.0.0.1/32 scram-sha-256" >> /etc/postgresql/pg_hba.conf && \ + echo "host all postgres ::1/128 trust" >> /etc/postgresql/pg_hba.conf && \ + echo "host all bun_sql_test ::1/128 trust" >> /etc/postgresql/pg_hba.conf && \ + echo "host all bun_sql_test_md5 ::1/128 md5" >> /etc/postgresql/pg_hba.conf && \ + echo "host all bun_sql_test_scram ::1/128 scram-sha-256" >> /etc/postgresql/pg_hba.conf && \ + echo "local replication all trust" >> /etc/postgresql/pg_hba.conf && \ + echo "host replication all 127.0.0.1/32 trust" >> /etc/postgresql/pg_hba.conf && \ + echo "host replication all ::1/128 trust" >> /etc/postgresql/pg_hba.conf +RUN mkdir -p /docker-entrypoint-initdb.d && \ + echo "ALTER SYSTEM SET max_prepared_transactions = '100';" > /docker-entrypoint-initdb.d/configure-postgres.sql + +# Set environment variables +ENV POSTGRES_HOST_AUTH_METHOD=trust +ENV POSTGRES_USER=postgres + +# Expose PostgreSQL port +EXPOSE 5432 \ No newline at end of file diff --git a/test/js/sql/docker/pg_hba.conf b/test/js/sql/docker/pg_hba.conf new file mode 100644 index 0000000000..0079ef3214 --- /dev/null +++ b/test/js/sql/docker/pg_hba.conf @@ -0,0 +1,98 @@ +# PostgreSQL Client Authentication Configuration File +# =================================================== +# +# Refer to the "Client Authentication" section in the PostgreSQL +# documentation for a complete description of this file. A short +# synopsis follows. 
+# +# This file controls: which hosts are allowed to connect, how clients +# are authenticated, which PostgreSQL user names they can use, which +# databases they can access. Records take one of these forms: +# +# local DATABASE USER METHOD [OPTIONS] +# host DATABASE USER ADDRESS METHOD [OPTIONS] +# hostssl DATABASE USER ADDRESS METHOD [OPTIONS] +# hostnossl DATABASE USER ADDRESS METHOD [OPTIONS] +# hostgssenc DATABASE USER ADDRESS METHOD [OPTIONS] +# hostnogssenc DATABASE USER ADDRESS METHOD [OPTIONS] +# +# (The uppercase items must be replaced by actual values.) +# +# The first field is the connection type: +# - "local" is a Unix-domain socket +# - "host" is a TCP/IP socket (encrypted or not) +# - "hostssl" is a TCP/IP socket that is SSL-encrypted +# - "hostnossl" is a TCP/IP socket that is not SSL-encrypted +# - "hostgssenc" is a TCP/IP socket that is GSSAPI-encrypted +# - "hostnogssenc" is a TCP/IP socket that is not GSSAPI-encrypted +# +# DATABASE can be "all", "sameuser", "samerole", "replication", a +# database name, or a comma-separated list thereof. The "all" +# keyword does not match "replication". Access to replication +# must be enabled in a separate record (see example below). +# +# USER can be "all", a user name, a group name prefixed with "+", or a +# comma-separated list thereof. In both the DATABASE and USER fields +# you can also write a file name prefixed with "@" to include names +# from a separate file. +# +# ADDRESS specifies the set of hosts the record matches. It can be a +# host name, or it is made up of an IP address and a CIDR mask that is +# an integer (between 0 and 32 (IPv4) or 128 (IPv6) inclusive) that +# specifies the number of significant bits in the mask. A host name +# that starts with a dot (.) matches a suffix of the actual host name. +# Alternatively, you can write an IP address and netmask in separate +# columns to specify the set of hosts. 
Instead of a CIDR-address, you +# can write "samehost" to match any of the server's own IP addresses, +# or "samenet" to match any address in any subnet that the server is +# directly connected to. +# +# METHOD can be "trust", "reject", "md5", "password", "scram-sha-256", +# "gss", "sspi", "ident", "peer", "pam", "ldap", "radius" or "cert". +# Note that "password" sends passwords in clear text; "md5" or +# "scram-sha-256" are preferred since they send encrypted passwords. +# +# OPTIONS are a set of options for the authentication in the format +# NAME=VALUE. The available options depend on the different +# authentication methods -- refer to the "Client Authentication" +# section in the documentation for a list of which options are +# available for which authentication methods. +# +# Database and user names containing spaces, commas, quotes and other +# special characters must be quoted. Quoting one of the keywords +# "all", "sameuser", "samerole" or "replication" makes the name lose +# its special character, and just match a database or username with +# that name. +# +# This file is read on server startup and when the server receives a +# SIGHUP signal. If you edit the file on a running system, you have to +# SIGHUP the server for the changes to take effect, run "pg_ctl reload", +# or execute "SELECT pg_reload_conf()". +# +# Put your actual configuration here +# ---------------------------------- +# +# If you want to allow non-local connections, you need to add more +# "host" records. In that case you will also need to make PostgreSQL +# listen on a non-local interface via the listen_addresses +# configuration parameter, or via the -i or -h command line switches. + +# CAUTION: Configuring the system for local "trust" authentication +# allows any local user to connect as any PostgreSQL user, including +# the database superuser. If you do not trust all your local users, +# use another authentication method. 
+ + +# TYPE DATABASE USER ADDRESS METHOD + +# "local" is for Unix domain socket connections only +local all all trust +# IPv4 local connections: +host all all 127.0.0.1/32 trust +# IPv6 local connections: +host all all ::1/128 trust +# Allow replication connections from localhost, by a user with the +# replication privilege. +local replication all trust +host replication all 127.0.0.1/32 trust +host replication all ::1/128 trust diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index 4d4fd38c61..f79c769278 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -1,13 +1,109 @@ -import { sql } from "bun"; +import { sql, SQL } from "bun"; const postgres = (...args) => new sql(...args); -import { expect, test, mock } from "bun:test"; +import { expect, test, mock, beforeAll, afterAll } from "bun:test"; import { $ } from "bun"; -import { bunExe, isCI, withoutAggressiveGC } from "harness"; +import { bunExe, isCI, withoutAggressiveGC, isLinux } from "harness"; import path from "path"; -const hasPsql = Bun.which("psql"); -if (!isCI && hasPsql) { - require("./bootstrap.js"); +import { exec, execSync } from "child_process"; +import { promisify } from "util"; + +const execAsync = promisify(exec); +import net from "net"; +const dockerCLI = Bun.which("docker") as string; + +async function findRandomPort() { + return new Promise((resolve, reject) => { + // Create a server to listen on a random port + const server = net.createServer(); + server.listen(0, () => { + const port = server.address().port; + server.close(() => resolve(port)); + }); + server.on("error", reject); + }); +} +async function waitForPostgres(port) { + for (let i = 0; i < 3; i++) { + try { + const sql = new SQL(`postgres://postgres@localhost:${port}/postgres`, { + idle_timeout: 20, + max_lifetime: 60 * 30, + }); + + await sql`SELECT 1`; + await sql.end(); + console.log("PostgreSQL is ready!"); + return true; + } catch (error) { + console.log(`Waiting for PostgreSQL... 
(${i + 1}/3)`); + await new Promise(resolve => setTimeout(resolve, 1000)); + } + } + throw new Error("PostgreSQL failed to start"); +} + +async function startContainer(): Promise<{ port: number; containerName: string }> { + try { + // Build the Docker image + console.log("Building Docker image..."); + const dockerfilePath = path.join(import.meta.dir, "docker", "Dockerfile"); + await execAsync(`${dockerCLI} build --pull --rm -f "${dockerfilePath}" -t custom-postgres .`, { + cwd: path.join(import.meta.dir, "docker"), + }); + const port = await findRandomPort(); + const containerName = `postgres-test-${port}`; + // Check if container exists and remove it + try { + await execAsync(`${dockerCLI} rm -f ${containerName}`); + } catch (error) { + // Container might not exist, ignore error + } + + // Start the container + await execAsync( + `${dockerCLI} run -d --name ${containerName} -p ${port}:5432 custom-postgres`, + ); + + // Wait for PostgreSQL to be ready + await waitForPostgres(port); + return { + port, + containerName, + }; + } catch (error) { + console.error("Error:", error); + process.exit(1); + } +} + +function isDockerEnabled(): boolean { + if (!dockerCLI) { + return false; + } + + // TODO: investigate why its not starting on Linux arm64 + if (isLinux && process.arch === "arm64") { + return false; + } + + try { + const info = execSync(`${dockerCLI} info`, { stdio: ["ignore", "pipe", "inherit"] }); + return info.toString().indexOf("Server Version:") !== -1; + } catch { + return false; + } +} +if (isDockerEnabled()) { + const container: { port: number; containerName: string } = await startContainer(); + afterAll(async () => { + try { + await execAsync(`${dockerCLI} stop -t 0 ${container.containerName}`); + await execAsync(`${dockerCLI} rm -f ${container.containerName}`); + } catch (error) {} + }); + + // require("./bootstrap.js"); // macOS location: /opt/homebrew/var/postgresql@14/pg_hba.conf // --- Expected pg_hba.conf --- @@ -35,33 +131,36 @@ if (!isCI && 
hasPsql) { // host replication all 127.0.0.1/32 trust // host replication all ::1/128 trust // --- Expected pg_hba.conf --- - process.env.DATABASE_URL = "postgres://bun_sql_test@localhost:5432/bun_sql_test"; + process.env.DATABASE_URL = `postgres://bun_sql_test@localhost:${container.port}/bun_sql_test`; const login = { username: "bun_sql_test", + port: container.port, }; const login_md5 = { username: "bun_sql_test_md5", password: "bun_sql_test_md5", + port: container.port, }; const login_scram = { username: "bun_sql_test_scram", password: "bun_sql_test_scram", + port: container.port, }; const options = { db: "bun_sql_test", username: login.username, password: login.password, - idle_timeout: 0, - connect_timeout: 0, + port: container.port, max: 1, }; test("Connects with no options", async () => { - const sql = postgres({ max: 1 }); + // we need at least the usename and port + await using sql = postgres({ max: 1, port: container.port, username: login.username }); const result = (await sql`select 1 as x`)[0].x; sql.close(); @@ -73,19 +172,20 @@ if (!isCI && hasPsql) { const onconnect = mock(); await using sql = postgres({ ...options, - hostname: "unreachable_host", - connection_timeout: 1, + hostname: "example.com", + connection_timeout: 4, onconnect, onclose, + max: 1, }); let error: any; try { - await sql`select pg_sleep(2)`; + await sql`select pg_sleep(8)`; } catch (e) { error = e; } expect(error.code).toBe(`ERR_POSTGRES_CONNECTION_TIMEOUT`); - expect(error.message).toContain("Connection timeout after 1ms"); + expect(error.message).toContain("Connection timeout after 4s"); expect(onconnect).not.toHaveBeenCalled(); expect(onclose).toHaveBeenCalledTimes(1); }); @@ -118,7 +218,7 @@ if (!isCI && hasPsql) { const onconnect = mock(); await using sql = postgres({ ...options, - idle_timeout: 100, + idle_timeout: 1, onconnect, onclose, }); @@ -137,7 +237,7 @@ if (!isCI && hasPsql) { const onconnect = mock(); const sql = postgres({ ...options, - max_lifetime: 64, + 
max_lifetime: 1, onconnect, onclose, }); @@ -250,8 +350,8 @@ if (!isCI && hasPsql) { expect((await sql`select ${null} as x`)[0].x).toBeNull(); }); - test.todo("Unsigned Integer", async () => { - expect((await sql`select ${0x7fffffff + 2} as x`)[0].x).toBe(0x7fffffff + 2); + test("Unsigned Integer", async () => { + expect((await sql`select ${0x7fffffff + 2} as x`)[0].x).toBe("2147483649"); }); test("Signed Integer", async () => { @@ -326,15 +426,17 @@ if (!isCI && hasPsql) { // ['c', (await sql`select ${ sql.array(['a', 'b', 'c']) } as x`)[0].x[2]] // ) - // t('Array of Date', async() => { - // const now = new Date() - // return [now.getTime(), (await sql`select ${ sql.array([now, now, now]) } as x`)[0].x[2].getTime()] - // }) + // test("Array of Date", async () => { + // const now = new Date(); + // const result = await sql`select ${sql.array([now, now, now])} as x`; + // expect(result[0].x[2].getTime()).toBe(now.getTime()); + // }); - // t.only("Array of Box", async () => [ - // "(3,4),(1,2);(6,7),(4,5)", - // (await sql`select ${"{(1,2),(3,4);(4,5),(6,7)}"}::box[] as x`)[0].x.join(";"), - // ]); + test.todo("Array of Box", async () => { + const result = await sql`select ${"{(1,2),(3,4);(4,5),(6,7)}"}::box[] as x`; + console.log(result); + expect(result[0].x.join(";")).toBe("(1,2);(3,4);(4,5);(6,7)"); + }); // t('Nested array n2', async() => // ['4', (await sql`select ${ sql.array([[1, 2], [3, 4]]) } as x`)[0].x[1][1]] @@ -348,9 +450,9 @@ if (!isCI && hasPsql) { // ['Hello "you",c:\\windows', (await sql`select ${ sql.array(['Hello "you"', 'c:\\windows']) } as x`)[0].x.join(',')] // ) - // t.only("Escapes", async () => { - // expect(Object.keys((await sql`select 1 as ${sql('hej"hej')}`)[0])[0]).toBe('hej"hej'); - // }); + test("Escapes", async () => { + expect(Object.keys((await sql`select 1 as ${sql('hej"hej')}`)[0])[0]).toBe('hej"hej'); + }); // t.only( // "big query body", @@ -399,7 +501,7 @@ if (!isCI && hasPsql) { await sql`insert into test values('hej')`; }) 
.catch(e => e.errno), - ).toBe(22); + ).toBe("22P02"); } finally { await sql`drop table test`; } @@ -540,13 +642,13 @@ if (!isCI && hasPsql) { await sql .begin(sql => [sql`select wat`, sql`select current_setting('bun_sql.test') as x, ${1} as a`]) .catch(e => e.errno), - ).toBe(42703); + ).toBe("42703"); }); - // test.only("Fragments in transactions", async () => { - // const sql = postgres({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); - // expect((await sql.begin(sql => sql`select true as x where ${sql`1=1`}`))[0].x).toBe(true); - // }); + test("Fragments in transactions", async () => { + const sql = postgres({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); + expect((await sql.begin(sql => sql`select true as x where ${sql`1=1`}`))[0].x).toBe(true); + }); test("Transaction rejects with rethrown error", async () => { await using sql = postgres({ ...options }); @@ -618,21 +720,25 @@ if (!isCI && hasPsql) { } }); - // t('Helpers in Transaction', async() => { - // return ['1', (await sql.begin(async sql => - // await sql`select ${ sql({ x: 1 }) }` - // ))[0].x] - // }) + test("Helpers in Transaction", async () => { + const result = await sql.begin(async sql => await sql`select ${sql.unsafe("1 as x")}`); + expect(result[0].x).toBe(1); + }); - // t('Undefined values throws', async() => { - // let error + test("Undefined values throws", async () => { + // in bun case undefined is null should we fix this? 
null is a better DX - // await sql` - // select ${ undefined } as x - // `.catch(x => error = x.code) + // let error; - // return ['UNDEFINED_VALUE', error] - // }) + // await sql` + // select ${undefined} as x + // `.catch(x => (error = x.code)); + + // expect(error).toBe("UNDEFINED_VALUE"); + + const result = await sql`select ${undefined} as x`; + expect(result[0].x).toBeNull(); + }); // t('Transform undefined', async() => { // const sql = postgres({ ...options, transform: { undefined: null } }) @@ -648,20 +754,29 @@ if (!isCI && hasPsql) { // Add code property. test("Throw syntax error", async () => { + await using sql = postgres({ ...options, max: 1 }); const err = await sql`wat 1`.catch(x => x); + expect(err.errno).toBe("42601"); expect(err.code).toBe("ERR_POSTGRES_SYNTAX_ERROR"); - expect(err.errno).toBe(42601); expect(err).toBeInstanceOf(SyntaxError); }); - // t('Connect using uri', async() => - // [true, await new Promise((resolve, reject) => { - // const sql = postgres('postgres://' + login.user + ':' + (login.pass || '') + '@localhost:5432/' + options.db, { - // idle_timeout - // }) - // sql`select 1`.then(() => resolve(true), reject) - // })] - // ) + test("Connect using uri", async () => [ + true, + await new Promise((resolve, reject) => { + const sql = postgres( + "postgres://" + + login_md5.username + + ":" + + (login_md5.password || "") + + "@localhost:" + + container.port + + "/" + + options.db, + ); + sql`select 1`.then(() => resolve(true), reject); + }), + ]); // t('Options from uri with special characters in user and pass', async() => { // const opt = postgres({ user: 'öla', pass: 'pass^word' }).options @@ -754,7 +869,7 @@ if (!isCI && hasPsql) { }); // Promise.all on multiple values in-flight doesn't work currently due to pendingValueGetcached pointing to the wrong value. 
- test.todo("Parallel connections using scram-sha-256", async () => { + test("Parallel connections using scram-sha-256", async () => { await using sql = postgres({ ...options, ...login_scram }); return [ true, @@ -852,51 +967,51 @@ if (!isCI && hasPsql) { // return ['hello', result[0].x] // }) - // t('Connection ended promise', async() => { - // const sql = postgres(options) + test("Connection ended promise", async () => { + const sql = postgres(options); - // await sql.end() + await sql.end(); - // return [undefined, await sql.end()] - // }) + expect(await sql.end()).toBeUndefined(); + }); - // t('Connection ended timeout', async() => { - // const sql = postgres(options) + test("Connection ended timeout", async () => { + const sql = postgres(options); - // await sql.end({ timeout: 10 }) + await sql.end({ timeout: 10 }); - // return [undefined, await sql.end()] - // }) + expect(await sql.end()).toBeUndefined(); + }); - // t('Connection ended error', async() => { - // const sql = postgres(options) - // await sql.end() - // return ['CONNECTION_ENDED', (await sql``.catch(x => x.code))] - // }) + test("Connection ended error", async () => { + const sql = postgres(options); + await sql.end(); + return expect(await sql``.catch(x => x.code)).toBe("ERR_POSTGRES_CONNECTION_CLOSED"); + }); - // t('Connection end does not cancel query', async() => { - // const sql = postgres(options) + test("Connection end does not cancel query", async () => { + const sql = postgres(options); - // const promise = sql`select 1 as x`.execute() + const promise = sql`select pg_sleep(0.2) as x`.execute(); + // we await 1 to start the query + await 1; + await sql.end(); + return expect(await promise).toEqual([{ x: "" }]); + }); - // await sql.end() + test("Connection destroyed", async () => { + const sql = postgres(options); + process.nextTick(() => sql.end({ timeout: 0 })); + expect(await sql``.catch(x => x.code)).toBe("ERR_POSTGRES_CONNECTION_CLOSED"); + }); - // return [1, (await promise)[0].x] 
- // }) + test("Connection destroyed with query before", async () => { + const sql = postgres(options); + const error = sql`select pg_sleep(0.2)`.catch(err => err.code); - // t('Connection destroyed', async() => { - // const sql = postgres(options) - // process.nextTick(() => sql.end({ timeout: 0 })) - // return ['CONNECTION_DESTROYED', await sql``.catch(x => x.code)] - // }) - - // t('Connection destroyed with query before', async() => { - // const sql = postgres(options) - // , error = sql`select pg_sleep(0.2)`.catch(err => err.code) - - // sql.end({ timeout: 0 }) - // return ['CONNECTION_DESTROYED', await error] - // }) + sql.end({ timeout: 0 }); + return expect(await error).toBe("ERR_POSTGRES_CONNECTION_CLOSED"); + }); // t('transform column', async() => { // const sql = postgres({ @@ -1020,14 +1135,18 @@ if (!isCI && hasPsql) { // ] // }) - // t('unsafe', async() => { - // await sql`create table test (x int)` - // return [1, (await sql.unsafe('insert into test values ($1) returning *', [1]))[0].x, await sql`drop table test`] - // }) + test("unsafe", async () => { + await sql`create table test (x int)`; + try { + expect(await sql.unsafe("insert into test values ($1) returning *", [1])).toEqual([{ x: 1 }]); + } finally { + await sql`drop table test`; + } + }); - // t('unsafe simple', async() => { - // return [1, (await sql.unsafe('select 1 as x'))[0].x] - // }) + test("unsafe simple", async () => { + expect(await sql.unsafe("select 1 as x")).toEqual([{ x: 1 }]); + }); // t('unsafe simple includes columns', async() => { // return ['x', (await sql.unsafe('select 1 as x').values()).columns[0].name] @@ -1045,12 +1164,14 @@ if (!isCI && hasPsql) { // ] // }) - // t('simple query using unsafe with multiple statements', async() => { - // return [ - // '1,2', - // (await sql.unsafe('select 1 as x;select 2 as x')).map(x => x[0].x).join() - // ] - // }) + test.todo("simple query using unsafe with multiple statements", async () => { + // bun always uses prepared 
statements, so this is not supported + // PostgresError: cannot insert multiple commands into a prepared statement + // errno: "42601", + // code: "ERR_POSTGRES_SYNTAX_ERROR" + expect(await sql.unsafe("select 1 as x;select 2 as x")).toEqual([{ x: 1 }, { x: 2 }]); + // return ["1,2", (await sql.unsafe("select 1 as x;select 2 as x")).map(x => x[0].x).join()]; + }); // t('simple query using simple() with multiple statements', async() => { // return [ @@ -1291,9 +1412,9 @@ if (!isCI && hasPsql) { { timeout: 1000000 }, ); - // t('only allows one statement', async() => - // ['42601', await sql`select 1; select 2`.catch(e => e.code)] - // ) + test("only allows one statement", async () => { + expect(await sql`select 1; select 2`.catch(e => e.errno)).toBe("42601"); + }); // t('await sql() throws not tagged error', async() => { // let error @@ -1348,53 +1469,49 @@ if (!isCI && hasPsql) { } }); - // t('Connection errors are caught using begin()', { - // timeout: 2 - // }, async() => { - // let error - // try { - // const sql = postgres({ host: 'localhost', port: 1 }) + test("Connection errors are caught using begin()", async () => { + let error; + try { + const sql = postgres({ host: "localhost", port: 1 }); - // await sql.begin(async(sql) => { - // await sql`insert into test (label, value) values (${1}, ${2})` - // }) - // } catch (err) { - // error = err - // } + await sql.begin(async sql => { + await sql`insert into test (label, value) values (${1}, ${2})`; + }); + } catch (err) { + error = err; + } + expect(error.code).toBe("ERR_POSTGRES_CONNECTION_CLOSED"); + }); - // return [ - // true, - // error.code === 'ECONNREFUSED' || - // error.message === 'Connection refused (os error 61)' - // ] - // }) + test("dynamic table name", async () => { + await sql`create table test(a int)`; + try { + return expect((await sql`select * from ${sql("test")}`).length).toBe(0); + } finally { + await sql`drop table test`; + } + }); - // t('dynamic table name', async() => { - // await 
sql`create table test(a int)` - // return [ - // 0, (await sql`select * from ${ sql('test') }`).count, - // await sql`drop table test` - // ] - // }) + test("dynamic schema name", async () => { + await sql`create table test(a int)`; + try { + return expect((await sql`select * from ${sql("public")}.test`).length).toBe(0); + } finally { + await sql`drop table test`; + } + }); - // t('dynamic schema name', async() => { - // await sql`create table test(a int)` - // return [ - // 0, (await sql`select * from ${ sql('public') }.test`).count, - // await sql`drop table test` - // ] - // }) + test("dynamic schema and table name", async () => { + await sql`create table test(a int)`; + try { + return expect((await sql`select * from ${sql("public.test")}`).length).toBe(0); + } finally { + await sql`drop table test`; + } + }); - // t('dynamic schema and table name', async() => { - // await sql`create table test(a int)` - // return [ - // 0, (await sql`select * from ${ sql('public.test') }`).count, - // await sql`drop table test` - // ] - // }) - - test.todo("dynamic column name", async () => { - const result = await sql`select 1 as ${"\\!not_valid"}`; + test("dynamic column name", async () => { + const result = await sql`select 1 as ${sql("!not_valid")}`; expect(Object.keys(result[0])[0]).toBe("!not_valid"); }); @@ -1406,19 +1523,27 @@ if (!isCI && hasPsql) { // return [undefined, (await sql`select ${ sql({ a: 1, b: 2 }, 'a') }`)[0].b] // }) - // t('dynamic insert', async() => { - // await sql`create table test (a int, b text)` - // const x = { a: 42, b: 'the answer' } + test("dynamic insert", async () => { + await sql`create table test (a int, b text)`; + try { + const x = { a: 42, b: "the answer" }; + expect((await sql`insert into test ${sql(x)} returning *`)[0].b).toBe("the answer"); + } finally { + await sql`drop table test`; + } + }); - // return ['the answer', (await sql`insert into test ${ sql(x) } returning *`)[0].b, await sql`drop table test`] - // }) - - // 
test.todo("dynamic insert pluck", async () => { - // await sql`create table test (a int, b text)`; - // const x = { a: 42, b: "the answer" }; - // const [{ b }] = await sql`insert into test ${sql(x, "a")} returning *`; - // expect(b).toBe("the answer"); - // }); + test("dynamic insert pluck", async () => { + try { + await sql`create table test2 (a int, b text)`; + const x = { a: 42, b: "the answer" }; + const [{ b, a }] = await sql`insert into test2 ${sql(x, "a")} returning *`; + expect(b).toBeNull(); + expect(a).toBe(42); + } finally { + await sql`drop table test2`; + } + }); // t('dynamic in with empty array', async() => { // await sql`create table test (a int)` @@ -1871,17 +1996,20 @@ if (!isCI && hasPsql) { // return [1, (await sql`select 1 as x`)[0].x] // }) - test("Big result", async () => { - const result = await sql`select * from generate_series(1, 100000)`; - expect(result.count).toBe(100000); - let i = 1; + test.skipIf(isCI)( + "Big result", + async () => { + await using sql = postgres(options); + const result = await sql`select * from generate_series(1, 100000)`; + expect(result.count).toBe(100000); + let i = 1; - for (const row of result) { - if (row.generate_series !== i++) { - throw new Error(`Row out of order at index ${i - 1}`); + for (const row of result) { + expect(row.generate_series).toBe(i++); } - } - }); + }, + 10000, + ); // t('Debug', async() => { // let result @@ -1895,10 +2023,17 @@ if (!isCI && hasPsql) { // return ['select 1', result] // }) - // t('bigint is returned as String', async() => [ - // 'string', - // typeof (await sql`select 9223372036854777 as x`)[0].x - // ]) + test("bigint is returned as String", async () => { + expect(typeof (await sql`select 9223372036854777 as x`)[0].x).toBe("string"); + }); + + test("bigint is returned as BigInt", async () => { + await using sql = postgres({ + ...options, + bigint: true, + }); + expect((await sql`select 9223372036854777 as x`)[0].x).toBe(9223372036854777n); + }); test("int is returned 
as Number", async () => { expect((await sql`select 123 as x`)[0].x).toBe(123); @@ -2865,25 +3000,22 @@ if (!isCI && hasPsql) { // return ['12233445566778', xs.sort().join('')] // }) - // t('reserve connection', async() => { - // const reserved = await sql.reserve() + test("reserve connection", async () => { + const sql = postgres({ ...options, max: 1 }); + const reserved = await sql.reserve(); - // setTimeout(() => reserved.release(), 510) + setTimeout(() => reserved.release(), 510); - // const xs = await Promise.all([ - // reserved`select 1 as x`.then(([{ x }]) => ({ time: Date.now(), x })), - // sql`select 2 as x`.then(([{ x }]) => ({ time: Date.now(), x })), - // reserved`select 3 as x`.then(([{ x }]) => ({ time: Date.now(), x })) - // ]) + const xs = await Promise.all([ + reserved`select 1 as x`.then(([{ x }]) => ({ time: Date.now(), x })), + sql`select 2 as x`.then(([{ x }]) => ({ time: Date.now(), x })), + reserved`select 3 as x`.then(([{ x }]) => ({ time: Date.now(), x })), + ]); - // if (xs[1].time - xs[2].time < 500) - // throw new Error('Wrong time') + if (xs[1].time - xs[2].time < 500) throw new Error("Wrong time"); - // return [ - // '123', - // xs.map(x => x.x).join('') - // ] - // }) + expect(xs.map(x => x.x).join("")).toBe("123"); + }); test("keeps process alive when it should", async () => { const file = path.posix.join(__dirname, "sql-fixture-ref.ts"); diff --git a/test/js/sql/tls-sql.test.ts b/test/js/sql/tls-sql.test.ts index 78bd4d0daa..bb573a2976 100644 --- a/test/js/sql/tls-sql.test.ts +++ b/test/js/sql/tls-sql.test.ts @@ -1,25 +1,240 @@ -import { test, expect } from "bun:test"; +import { test, expect, mock } from "bun:test"; import { getSecret } from "harness"; -import { sql as SQL } from "bun"; +import { SQL, sql, postgres } from "bun"; const TLS_POSTGRES_DATABASE_URL = getSecret("TLS_POSTGRES_DATABASE_URL"); +const options = { + url: TLS_POSTGRES_DATABASE_URL, + tls: true, + adapter: "postgresql", + max: 1, + bigint: true, +}; if 
(TLS_POSTGRES_DATABASE_URL) { + test("default sql", async () => { + expect(sql.reserve).toBeDefined(); + expect(sql.options).toBeDefined(); + expect(sql[Symbol.asyncDispose]).toBeDefined(); + expect(sql.begin).toBeDefined(); + expect(sql.beginDistributed).toBeDefined(); + expect(sql.distributed).toBeDefined(); + expect(sql.unsafe).toBeDefined(); + expect(sql.end).toBeDefined(); + expect(sql.close).toBeDefined(); + expect(sql.transaction).toBeDefined(); + expect(sql.distributed).toBeDefined(); + expect(sql.unsafe).toBeDefined(); + expect(sql.commitDistributed).toBeDefined(); + expect(sql.rollbackDistributed).toBeDefined(); + }); + test("default postgres", async () => { + expect(postgres.reserve).toBeDefined(); + expect(postgres.options).toBeDefined(); + expect(postgres[Symbol.asyncDispose]).toBeDefined(); + expect(postgres.begin).toBeDefined(); + expect(postgres.beginDistributed).toBeDefined(); + expect(postgres.distributed).toBeDefined(); + expect(postgres.unsafe).toBeDefined(); + expect(postgres.end).toBeDefined(); + expect(postgres.close).toBeDefined(); + expect(postgres.transaction).toBeDefined(); + expect(postgres.distributed).toBeDefined(); + expect(postgres.unsafe).toBeDefined(); + expect(postgres.commitDistributed).toBeDefined(); + expect(postgres.rollbackDistributed).toBeDefined(); + }); test("tls (explicit)", async () => { - const sql = new SQL({ - url: TLS_POSTGRES_DATABASE_URL!, - tls: true, - adapter: "postgresql", - }); - + await using sql = new SQL(options); const [{ one, two }] = await sql`SELECT 1 as one, '2' as two`; expect(one).toBe(1); expect(two).toBe("2"); + await sql.close(); }); - test("tls (implicit)", async () => { - const [{ one, two }] = await SQL`SELECT 1 as one, '2' as two`; - expect(one).toBe(1); - expect(two).toBe("2"); + test("Throws on illegal transactions", async () => { + await using sql = new SQL({ ...options, max: 2 }); + const error = await sql`BEGIN`.catch(e => e); + return 
expect(error.code).toBe("ERR_POSTGRES_UNSAFE_TRANSACTION");
+  });
+
+  test("Transaction throws", async () => {
+    await using sql = new SQL(options);
+    await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`;
+    expect(
+      await sql
+        .begin(async sql => {
+          await sql`insert into test values(1)`;
+          await sql`insert into test values('hej')`;
+        })
+        .catch(e => e.errno),
+    ).toBe("22P02");
+  });
+
+  test("Transaction rolls back", async () => {
+    await using sql = new SQL(options);
+    await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`;
+
+    await sql
+      .begin(async sql => {
+        await sql`insert into test values(1)`;
+        await sql`insert into test values('hej')`;
+      })
+      .catch(() => {
+        /* ignore */
+      });
+
+    expect((await sql`select a from test`).count).toBe(0);
+  });
+
+  test("Transaction throws on uncaught savepoint", async () => {
+    await using sql = new SQL(options);
+    await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`;
+    expect(
+      await sql
+        .begin(async sql => {
+          await sql`insert into test values(1)`;
+          await sql.savepoint(async sql => {
+            await sql`insert into test values(2)`;
+            throw new Error("fail");
+          });
+        })
+        .catch(err => err.message),
+    ).toBe("fail");
+  });
+
+  test("Transaction throws on uncaught named savepoint", async () => {
+    await using sql = new SQL(options);
+    await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`;
+    expect(
+      await sql
+        .begin(async sql => {
+          await sql`insert into test values(1)`;
+          await sql.savepoint("watpoint", async sql => {
+            await sql`insert into test values(2)`;
+            throw new Error("fail");
+          });
+        })
+        .catch(err => err.message),
+    ).toBe("fail");
+  });
+
+  test("Transaction succeeds on caught savepoint", async () => {
+    await using sql = new SQL(options);
+    const table_id = `test_random${Bun.randomUUIDv7().toString().replace(/-/g, "_")}`;
+    await sql`CREATE TABLE IF NOT EXISTS ${sql(table_id)} (a int)`;
+    try {
+      await sql.begin(async sql => {
+        await sql`insert into ${sql(table_id)} values(1)`;
+
await sql + .savepoint(async sql => { + await sql`insert into ${sql(table_id)} values(2)`; + throw new Error("please rollback"); + }) + .catch(() => { + /* ignore */ + }); + await sql`insert into ${sql(table_id)} values(3)`; + }); + expect((await sql`select count(1) from ${sql(table_id)}`)[0].count).toBe(2n); + } finally { + await sql`DROP TABLE IF EXISTS ${sql(table_id)}`; + } + }); + + test("Savepoint returns Result", async () => { + let result; + await using sql = new SQL(options); + await sql.begin(async t => { + result = await t.savepoint(s => s`select 1 as x`); + }); + expect(result[0]?.x).toBe(1); + }); + + test("Transaction requests are executed implicitly", async () => { + await using sql = new SQL({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); + expect( + ( + await sql.begin(sql => [ + sql`select set_config('bun_sql.test', 'testing', true)`, + sql`select current_setting('bun_sql.test') as x`, + ]) + )[1][0].x, + ).toBe("testing"); + }); + + test("Uncaught transaction request errors bubbles to transaction", async () => { + await using sql = new SQL({ ...options, debug: true, idle_timeout: 1, fetch_types: false, max: 10 }); + expect( + await sql + .begin(sql => [sql`select wat`, sql`select current_setting('bun_sql.test') as x, ${1} as a`]) + .catch(e => e.errno), + ).toBe("42703"); + }); + + test("Transaction rejects with rethrown error", async () => { + await using sql = new SQL(options); + expect( + await sql + .begin(async sql => { + try { + await sql`select exception`; + } catch (ex) { + throw new Error("WAT"); + } + }) + .catch(e => e.message), + ).toBe("WAT"); + }); + + test("Parallel transactions", async () => { + await using sql = new SQL({ ...options, max: 2 }); + + expect( + (await Promise.all([sql.begin(sql => sql`select 1 as count`), sql.begin(sql => sql`select 1 as count`)])) + .map(x => x[0].count) + .join(""), + ).toBe("11"); + }); + + test("Many transactions at beginning of connection", async () => { + await using sql = 
new SQL({ ...options, max: 10 }); + const xs = await Promise.all(Array.from({ length: 100 }, () => sql.begin(sql => sql`select 1`))); + return expect(xs.length).toBe(100); + }); + + test("Transactions array", async () => { + await using sql = new SQL(options); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; + expect( + (await sql.begin(sql => [sql`select 1 as count`, sql`select 1 as count`])).map(x => x[0].count).join(""), + ).toBe("11"); + }); + + test("Transaction waits", async () => { + await using sql = new SQL({ ...options, max: 10 }); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; + await sql.begin(async sql => { + await sql`insert into test values(1)`; + await sql + .savepoint(async sql => { + await sql`insert into test values(2)`; + throw new Error("please rollback"); + }) + .catch(() => { + /* ignore */ + }); + await sql`insert into test values(3)`; + }); + expect( + ( + await Promise.all([ + sql.begin(async sql => await sql`select 1 as count`), + sql.begin(async sql => await sql`select 1 as count`), + ]) + ) + .map(x => x[0].count) + .join(""), + ).toBe("11"); }); }