diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index 6e29b9d3b5..d13bc035f2 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -2066,6 +2066,8 @@ declare module "bun" { max?: number; /** By default values outside i32 range are returned as strings. If this is true, values outside i32 range are returned as BigInts. */ bigint?: boolean; + /** Automatic creation of prepared statements, defaults to true */ + prepare?: boolean; }; /** @@ -2079,6 +2081,8 @@ declare module "bun" { cancelled: boolean; /** Cancels the executing query */ cancel(): SQLQuery; + /** Execute as a simple query, no parameters are allowed but can execute multiple commands separated by semicolons */ + simple(): SQLQuery; /** Executes the query */ execute(): SQLQuery; /** Returns the raw query result */ diff --git a/src/bun.js/api/postgres.classes.ts b/src/bun.js/api/postgres.classes.ts index de2ac1cb83..136c0dd09b 100644 --- a/src/bun.js/api/postgres.classes.ts +++ b/src/bun.js/api/postgres.classes.ts @@ -72,6 +72,10 @@ export default [ fn: "setMode", length: 1, }, + setPendingValue: { + fn: "setPendingValue", + length: 1, + }, }, values: ["pendingValue", "target", "columns", "binding"], estimatedSize: true, diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index c0fad4208b..a1caf6ca7d 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -175,7 +175,6 @@ const errors: ErrorCodeMapping = [ ["ERR_POSTGRES_INVALID_TRANSACTION_STATE", Error, "PostgresError"], ["ERR_POSTGRES_QUERY_CANCELLED", Error, "PostgresError"], ["ERR_POSTGRES_UNSAFE_TRANSACTION", Error, "PostgresError"], - // S3 ["ERR_S3_MISSING_CREDENTIALS", Error], ["ERR_S3_INVALID_METHOD", Error], diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig index e1a9565cbc..743472018a 100644 --- a/src/bun.js/bindings/bindings.zig +++ b/src/bun.js/bindings/bindings.zig @@ -6540,10 +6540,10 @@ pub const CallFrame = opaque { /// arguments(n).mut() -> `var args = argumentsAsArray(n); &args` pub fn arguments_old(self: *const CallFrame, comptime max: usize) Arguments(max) { const slice = self.arguments(); - comptime bun.assert(max <= 13); + comptime bun.assert(max <= 14); return switch (@as(u4, @min(slice.len, max))) { 0 => .{ .ptr = undefined, .len = 0 }, - inline 1...13 => |count| Arguments(max).init(comptime @min(count, max), slice.ptr), + inline 1...14 => |count| Arguments(max).init(comptime @min(count, max), slice.ptr), else => unreachable, }; } diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index 52b9878045..3925d230d8 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -47,6 +47,7 @@ const _strings = Symbol("strings"); const _values = Symbol("values"); const _poolSize = Symbol("poolSize"); const _flags = Symbol("flags"); +const _results = Symbol("results"); const PublicPromise = Promise; type TransactionCallback = (sql: (strings: string, ...values: any[]) => Query) => Promise; @@ -90,6 +91,7 @@ enum SQLQueryFlags { allowUnsafeTransaction = 1 << 0, unsafe = 1 << 1, bigint = 1 << 2, + simple = 1 << 3, } function getQueryHandle(query) { @@ -102,6 +104,7 @@ function getQueryHandle(query) { query[_flags] & SQLQueryFlags.allowUnsafeTransaction, query[_poolSize], query[_flags] & SQLQueryFlags.bigint, + query[_flags] & SQLQueryFlags.simple, ); } catch (err) { query[_queryStatus] |= QueryStatus.error | QueryStatus.invalidHandle; @@ -144,6 +147,7 @@ class Query extends PublicPromise { this[_strings] = strings; this[_values] = values; this[_flags] = allowUnsafeTransaction; + this[_results] = null; } async [_run](async: boolean) { @@ -237,6 +241,11 @@ class Query extends PublicPromise { return this; } + simple() { + this[_flags] |= SQLQueryFlags.simple; + return this; + } + values() { const handle = getQueryHandle(this); if (!handle) return this; @@ -266,7 +275,51 @@ class Query extends PublicPromise { Object.defineProperty(Query, Symbol.species, { value: PublicPromise }); Object.defineProperty(Query, Symbol.toStringTag, { value: "Query" }); init( - function onResolvePostgresQuery(query, result, commandTag, count, queries) { + function onResolvePostgresQuery(query, result, commandTag, count, queries, is_last) { + /// simple queries + if (query[_flags] & SQLQueryFlags.simple) { + // simple can have multiple results or a single result + if (is_last) { + if (queries) { + const queriesIndex = queries.indexOf(query); + if (queriesIndex !== -1) { + queries.splice(queriesIndex, 1); + } + } + try { + query.resolve(query[_results]); + } catch (e) {} + return; + } + $assert(result instanceof SQLResultArray, "Invalid result array"); + // prepare for next query + query[_handle].setPendingValue(new SQLResultArray()); + + if (typeof commandTag === "string") { + if (commandTag.length > 0) { + result.command = commandTag; + } + } else { + result.command = cmds[commandTag]; + } + + result.count = count || 0; + const last_result = query[_results]; + + if (!last_result) { + query[_results] = result; + } else { + if (last_result instanceof SQLResultArray) { + // multiple results + query[_results] = [last_result, result]; + } else { + // 3 or more results + last_result.push(result); + } + } + return; + } + /// prepared statements $assert(result instanceof SQLResultArray, "Invalid result array"); if (typeof commandTag === "string") { if (commandTag.length > 0) { @@ -277,14 +330,12 @@ init( } result.count = count || 0; - if (queries) { const queriesIndex = queries.indexOf(query); if (queriesIndex !== -1) { queries.splice(queriesIndex, 1); } } - try { query.resolve(result); } catch (e) {} @@ -857,6 +908,7 @@ function createConnection( idleTimeout = 0, connectionTimeout = 30 * 1000, maxLifetime = 0, + prepare = true, }, onConnected, onClose, @@ -879,6 +931,7 @@ function createConnection( idleTimeout, connectionTimeout, maxLifetime, + !prepare, ); } @@ -1033,7 +1086,7 @@ function handleQueryFragment(strings, values) { return { final_strings, final_values }; } -function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint) { +function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint, simple) { let columns; let { final_strings, final_values } = handleQueryFragment(strings, values); @@ -1053,7 +1106,7 @@ function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint } } } - return createQuery(sqlString, final_values, new SQLResultArray(), columns, !!bigint); + return createQuery(sqlString, final_values, new SQLResultArray(), columns, !!bigint, !!simple); } class SQLArrayParameter { @@ -1112,6 +1165,7 @@ function loadOptions(o) { onclose, max, bigint; + let prepare = true; const env = Bun.env || {}; var sslMode: SSLMode = SSLMode.disable; @@ -1188,6 +1242,10 @@ function loadOptions(o) { maxLifetime ??= o.maxLifetime; maxLifetime ??= o.max_lifetime; bigint ??= o.bigint; + // we need to explicitly set prepare to false if it is false + if (o.prepare === false) { + prepare = false; + } onconnect ??= o.onconnect; onclose ??= o.onclose; @@ -1271,7 +1329,7 @@ function loadOptions(o) { default: throw new Error(`Unsupported adapter: ${adapter}. Only \"postgres\" is supported for now`); } - const ret: any = { hostname, port, username, password, database, tls, query, sslMode, adapter }; + const ret: any = { hostname, port, username, password, database, tls, query, sslMode, adapter, prepare, bigint }; if (idleTimeout != null) { ret.idleTimeout = idleTimeout; } @@ -1289,8 +1347,6 @@ function loadOptions(o) { } ret.max = max || 10; - ret.bigint = bigint; - return ret; } @@ -1368,13 +1424,11 @@ function SQL(o, e = {}) { function unsafeQuery(strings, values) { try { - return new Query( - strings, - values, - connectionInfo.bigint ? SQLQueryFlags.bigint | SQLQueryFlags.unsafe : SQLQueryFlags.unsafe, - connectionInfo.max, - queryFromPoolHandler, - ); + let flags = connectionInfo.bigint ? SQLQueryFlags.bigint | SQLQueryFlags.unsafe : SQLQueryFlags.unsafe; + if ((values?.length ?? 0) === 0) { + flags |= SQLQueryFlags.simple; + } + return new Query(strings, values, flags, connectionInfo.max, queryFromPoolHandler); } catch (err) { return Promise.reject(err); } @@ -1418,12 +1472,17 @@ function SQL(o, e = {}) { } function unsafeQueryFromTransaction(strings, values, pooledConnection, transactionQueries) { try { + let flags = connectionInfo.bigint + ? SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe | SQLQueryFlags.bigint + : SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe; + + if ((values?.length ?? 0) === 0) { + flags |= SQLQueryFlags.simple; + } const query = new Query( strings, values, - connectionInfo.bigint - ? SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe | SQLQueryFlags.bigint - : SQLQueryFlags.allowUnsafeTransaction | SQLQueryFlags.unsafe, + flags, connectionInfo.max, queryFromTransactionHandler.bind(pooledConnection, transactionQueries), ); @@ -2048,7 +2107,6 @@ function SQL(o, e = {}) { sql.unsafe = (string, args = []) => { return unsafeQuery(string, args); }; - sql.reserve = () => { if (pool.closed) { return Promise.reject(connectionClosedError()); diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index be9f749f2e..91d49721d8 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -330,17 +330,20 @@ pub const PostgresSQLQuery = struct { is_done: bool = false, binary: bool = false, bigint: bool = false, + simple: bool = false, result_mode: PostgresSQLQueryResultMode = .objects, } = .{}, pub usingnamespace JSC.Codegen.JSPostgresSQLQuery; - pub fn getTarget(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject) JSC.JSValue { + pub fn getTarget(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, clean_target: bool) JSC.JSValue { const thisValue = this.thisValue.get(); if (thisValue == .zero) { return .zero; } const target = PostgresSQLQuery.targetGetCached(thisValue) orelse return .zero; - PostgresSQLQuery.targetSetCached(thisValue, globalObject, .zero); + if (clean_target) { + PostgresSQLQuery.targetSetCached(thisValue, globalObject, .zero); + } return target; } @@ -351,6 +354,8 @@ pub const PostgresSQLQuery = struct { binding, /// The query is running running, + /// The query is waiting for a partial response + partial_response, /// The query was successful success, /// The query failed @@ -406,7 +411,7 @@ pub const PostgresSQLQuery = struct { this.status = .fail; const thisValue = this.thisValue.get(); defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject); + const targetValue = this.getTarget(globalObject, true); if (thisValue == .zero or targetValue == .zero) { return; } @@ -427,7 +432,7 @@ pub const PostgresSQLQuery = struct { const thisValue = this.thisValue.get(); defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject); + const targetValue = this.getTarget(globalObject, true); if (thisValue == .zero or targetValue == .zero) { return; } @@ -559,24 +564,29 @@ pub const PostgresSQLQuery = struct { return pending_value; } - pub fn onSuccess(this: *@This(), command_tag_str: []const u8, globalObject: *JSC.JSGlobalObject, connection: JSC.JSValue) void { - this.status = .success; + pub fn onResult(this: *@This(), command_tag_str: []const u8, globalObject: *JSC.JSGlobalObject, connection: JSC.JSValue, is_last: bool) void { this.ref(); defer this.deref(); const thisValue = this.thisValue.get(); - defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject); - defer allowGC(thisValue, globalObject); + const targetValue = this.getTarget(globalObject, is_last); + if (is_last) { + this.status = .success; + } else { + this.status = .partial_response; + } + defer if (is_last) { + allowGC(thisValue, globalObject); + this.thisValue.deinit(); + }; if (thisValue == .zero or targetValue == .zero) { return; } - const tag = CommandTag.init(command_tag_str); - const vm = JSC.VirtualMachine.get(); const function = vm.rareData().postgresql_context.onQueryResolveFn.get().?; const event_loop = vm.eventLoop(); + const tag = CommandTag.init(command_tag_str); event_loop.runCallback(function, globalObject, thisValue, &.{ targetValue, @@ -584,6 +594,7 @@ pub const PostgresSQLQuery = struct { tag.toJSTag(globalObject), tag.toJSNumber(), if (connection == .zero) .undefined else PostgresSQLConnection.queriesGetCached(connection) orelse .undefined, + JSValue.jsBoolean(is_last), }); } @@ -598,7 +609,7 @@ pub const PostgresSQLQuery = struct { } pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { - const arguments = callframe.arguments_old(5).slice(); + const arguments = callframe.arguments_old(6).slice(); var args = JSC.Node.ArgumentsSlice.init(globalThis.bunVM(), arguments); defer args.deinit(); const query = args.nextEat() orelse { @@ -619,8 +630,18 @@ pub const PostgresSQLQuery = struct { const pending_value = args.nextEat() orelse .undefined; const columns = args.nextEat() orelse .undefined; const js_bigint = args.nextEat() orelse .false; - const bigint = js_bigint.isBoolean() and js_bigint.asBoolean(); + const js_simple = args.nextEat() orelse .false; + const bigint = js_bigint.isBoolean() and js_bigint.asBoolean(); + const simple = js_simple.isBoolean() and js_simple.asBoolean(); + if (simple) { + if (values.getLength(globalThis) > 0) { + return globalThis.throwInvalidArguments("simple query cannot have parameters", .{}); + } + if (query.getLength(globalThis) >= std.math.maxInt(i32)) { + return globalThis.throwInvalidArguments("query is too long", .{}); + } + } if (!pending_value.jsType().isArrayLike()) { return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); } @@ -635,6 +656,7 @@ pub const PostgresSQLQuery = struct { .thisValue = JSRef.initWeak(this_value), .flags = .{ .bigint = bigint, + .simple = simple, }, }; @@ -657,6 +679,11 @@ pub const PostgresSQLQuery = struct { this.flags.is_done = true; return .undefined; } + pub fn setPendingValue(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + const result = callframe.argument(0); + PostgresSQLQuery.pendingValueSetCached(this.thisValue.get(), globalObject, result); + return .undefined; + } pub fn setMode(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { const js_mode = callframe.argument(0); if (js_mode.isEmptyOrUndefinedOrNull() or !js_mode.isNumber()) { @@ -669,6 +696,7 @@ pub const PostgresSQLQuery = struct { }; return .undefined; } + pub fn doRun(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { var arguments_ = callframe.arguments_old(2); const arguments = arguments_.slice(); @@ -687,56 +715,98 @@ pub const PostgresSQLQuery = struct { const binding_value = PostgresSQLQuery.bindingGetCached(this_value) orelse .zero; var query_str = this.query.toUTF8(bun.default_allocator); defer query_str.deinit(); + var writer = connection.writer(); + + if (this.flags.simple) { + debug("executeQuery", .{}); + + const can_execute = !connection.hasQueryRunning(); + if (can_execute) { + PostgresRequest.executeQuery(query_str.slice(), PostgresSQLConnection.Writer, writer) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to execute query", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + this.status = .running; + } else { + this.status = .pending; + } + const stmt = bun.default_allocator.create(PostgresSQLStatement) catch { + return globalObject.throwOutOfMemory(); + }; + // Query is simple and it's the only owner of the statement + stmt.* = .{ + .signature = Signature.empty(), + .ref_count = 1, + .status = .parsing, + }; + this.statement = stmt; + // We need a strong reference to the query so that it doesn't get GC'd + connection.requests.writeItem(this) catch return globalObject.throwOutOfMemory(); + this.ref(); + this.thisValue.upgrade(globalObject); + + PostgresSQLQuery.targetSetCached(this_value, globalObject, query); + if (this.status == .running) { + connection.flushDataAndResetTimeout(); + } else { + connection.resetConnectionTimeout(); + } + return .undefined; + } + const columns_value = PostgresSQLQuery.columnsGetCached(this_value) orelse .undefined; - var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value, connection.prepared_statement_id) catch |err| { + var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value, connection.prepared_statement_id, connection.flags.use_unnamed_prepared_statements) catch |err| { if (!globalObject.hasException()) return globalObject.throwError(err, "failed to generate signature"); return error.JSError; }; - var writer = connection.writer(); - const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { - signature.deinit(); - return globalObject.throwError(err, "failed to allocate statement"); - }; - const has_params = signature.fields.len > 0; var did_write = false; enqueue: { - if (entry.found_existing) { - this.statement = entry.value_ptr.*; - this.statement.?.ref(); - signature.deinit(); + var connection_entry_value: ?**PostgresSQLStatement = null; + if (!connection.flags.use_unnamed_prepared_statements) { + const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { + signature.deinit(); + return globalObject.throwError(err, "failed to allocate statement"); + }; + connection_entry_value = entry.value_ptr; + if (entry.found_existing) { + this.statement = connection_entry_value.?.*; + this.statement.?.ref(); + signature.deinit(); - switch (this.statement.?.status) { - .failed => { - // If the statement failed, we need to throw the error - return globalObject.throwValue(this.statement.?.error_response.?.toJS(globalObject)); - }, - .prepared => { - if (!connection.hasQueryRunning()) { - this.flags.binary = this.statement.?.fields.len > 0; - debug("bindAndExecute", .{}); + switch (this.statement.?.status) { + .failed => { + // If the statement failed, we need to throw the error + return globalObject.throwValue(this.statement.?.error_response.?.toJS(globalObject)); + }, + .prepared => { + if (!connection.hasQueryRunning()) { + this.flags.binary = this.statement.?.fields.len > 0; + debug("bindAndExecute", .{}); - // bindAndExecute will bind + execute, it will change to running after binding is complete - PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { - if (!globalObject.hasException()) - return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to bind and execute query", err)); - return error.JSError; - }; - connection.flags.is_ready_for_query = false; - this.status = .binding; + // bindAndExecute will bind + execute, it will change to running after binding is complete + PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwValue(postgresErrorToJS(globalObject, "failed to bind and execute query", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + this.status = .binding; - did_write = true; - } - }, - .parsing, .pending => {}, + did_write = true; + } + }, + .parsing, .pending => {}, + } + + break :enqueue; } - - break :enqueue; } - const can_execute = !connection.hasQueryRunning(); if (can_execute) { @@ -776,10 +846,16 @@ pub const PostgresSQLQuery = struct { const stmt = bun.default_allocator.create(PostgresSQLStatement) catch { return globalObject.throwOutOfMemory(); }; - connection.prepared_statement_id += 1; - stmt.* = .{ .signature = signature, .ref_count = 2, .status = if (can_execute) .parsing else .pending }; - this.statement = stmt; - entry.value_ptr.* = stmt; + // we only have connection_entry_value if we are using named prepared statements + if (connection_entry_value) |entry_value| { + connection.prepared_statement_id += 1; + stmt.* = .{ .signature = signature, .ref_count = 2, .status = if (can_execute) .parsing else .pending }; + this.statement = stmt; + entry_value.* = stmt; + } else { + stmt.* = .{ .signature = signature, .ref_count = 1, .status = if (can_execute) .parsing else .pending }; + this.statement = stmt; + } } } // We need a strong reference to the query so that it doesn't get GC'd @@ -1045,6 +1121,16 @@ pub const PostgresRequest = struct { try writer.write(&protocol.Sync); } + pub fn executeQuery( + query: []const u8, + comptime Context: type, + writer: protocol.NewWriter(Context), + ) !void { + try protocol.writeQuery(query, Context, writer); + try writer.write(&protocol.Flush); + try writer.write(&protocol.Sync); + } + pub fn onData( connection: *PostgresSQLConnection, comptime Context: type, @@ -1175,6 +1261,7 @@ pub const PostgresSQLConnection = struct { pub const ConnectionFlags = packed struct { is_ready_for_query: bool = false, is_processing_data: bool = false, + use_unnamed_prepared_statements: bool = false, }; pub const TLSStatus = union(enum) { @@ -1740,7 +1827,7 @@ pub const PostgresSQLConnection = struct { pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { var vm = globalObject.bunVM(); - const arguments = callframe.arguments_old(13).slice(); + const arguments = callframe.arguments_old(14).slice(); const hostname_str = arguments[0].toBunString(globalObject); defer hostname_str.deref(); const port = arguments[1].coerce(i32, globalObject); @@ -1842,6 +1929,7 @@ pub const PostgresSQLConnection = struct { const idle_timeout = arguments[10].toInt32(); const connection_timeout = arguments[11].toInt32(); const max_lifetime = arguments[12].toInt32(); + const use_unnamed_prepared_statements = arguments[13].asBoolean(); const ptr: *PostgresSQLConnection = try bun.default_allocator.create(PostgresSQLConnection); @@ -1863,6 +1951,9 @@ pub const PostgresSQLConnection = struct { .idle_timeout_interval_ms = @intCast(idle_timeout), .connection_timeout_ms = @intCast(connection_timeout), .max_lifetime_interval_ms = @intCast(max_lifetime), + .flags = .{ + .use_unnamed_prepared_statements = use_unnamed_prepared_statements, + }, }; ptr.updateHasPendingActivity(); @@ -2043,6 +2134,7 @@ pub const PostgresSQLConnection = struct { // in the middle of running .binding, .running, + .partial_response, => { if (js_reason) |reason| { request.onJSError(reason, this.globalObject); @@ -3223,49 +3315,40 @@ pub const PostgresSQLConnection = struct { var req: *PostgresSQLQuery = this.requests.peekItem(0); switch (req.status) { .pending => { - const stmt = req.statement orelse return error.ExpectedStatement; - - switch (stmt.status) { - .failed => { - bun.assert(stmt.error_response != null); - req.onError(stmt.error_response.?, this.globalObject); + if (req.flags.simple) { + debug("executeQuery", .{}); + var query_str = req.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + PostgresRequest.executeQuery(query_str.slice(), PostgresSQLConnection.Writer, this.writer()) catch |err| { + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); req.deref(); this.requests.discard(1); continue; - }, - .prepared => { - const thisValue = req.thisValue.get(); - bun.assert(thisValue != .zero); - const binding_value = PostgresSQLQuery.bindingGetCached(thisValue) orelse .zero; - const columns_value = PostgresSQLQuery.columnsGetCached(thisValue) orelse .zero; - req.flags.binary = stmt.fields.len > 0; + }; + this.flags.is_ready_for_query = false; + req.status = .running; + return; + } else { + const stmt = req.statement orelse return error.ExpectedStatement; - PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, columns_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + switch (stmt.status) { + .failed => { + bun.assert(stmt.error_response != null); + req.onError(stmt.error_response.?, this.globalObject); req.deref(); this.requests.discard(1); continue; - }; - this.flags.is_ready_for_query = false; - req.status = .binding; - return; - }, - .pending => { - // statement is pending, lets write/parse it - var query_str = req.query.toUTF8(bun.default_allocator); - defer query_str.deinit(); - const has_params = stmt.signature.fields.len > 0; - // If it does not have params, we can write and execute immediately in one go - if (!has_params) { + }, + .prepared => { const thisValue = req.thisValue.get(); bun.assert(thisValue != .zero); - // prepareAndQueryWithSignature will write + bind + execute, it will change to running after binding is complete const binding_value = PostgresSQLQuery.bindingGetCached(thisValue) orelse .zero; - PostgresRequest.prepareAndQueryWithSignature(this.globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, this.writer(), &stmt.signature) catch |err| { - stmt.status = .failed; - stmt.error_response = .{ .postgres_error = err }; + const columns_value = PostgresSQLQuery.columnsGetCached(thisValue) orelse .zero; + req.flags.binary = stmt.fields.len > 0; + + PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, columns_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { req.onWriteFail(err, this.globalObject, this.getQueriesArray()); req.deref(); this.requests.discard(1); @@ -3274,44 +3357,69 @@ pub const PostgresSQLConnection = struct { }; this.flags.is_ready_for_query = false; req.status = .binding; - stmt.status = .parsing; - return; - } - const connection_writer = this.writer(); - // write query and wait for it to be prepared - PostgresRequest.writeQuery(query_str.slice(), stmt.signature.prepared_statement_name, stmt.signature.fields, PostgresSQLConnection.Writer, connection_writer) catch |err| { - stmt.error_response = .{ .postgres_error = err }; - stmt.status = .failed; + }, + .pending => { + // statement is pending, lets write/parse it + var query_str = req.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + const has_params = stmt.signature.fields.len > 0; + // If it does not have params, we can write and execute immediately in one go + if (!has_params) { + const thisValue = req.thisValue.get(); + bun.assert(thisValue != .zero); + // prepareAndQueryWithSignature will write + bind + execute, it will change to running after binding is complete + const binding_value = PostgresSQLQuery.bindingGetCached(thisValue) orelse .zero; + PostgresRequest.prepareAndQueryWithSignature(this.globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, this.writer(), &stmt.signature) catch |err| { + stmt.status = .failed; + stmt.error_response = .{ .postgres_error = err }; + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); + continue; + }; + this.flags.is_ready_for_query = false; + req.status = .binding; + stmt.status = .parsing; - continue; - }; - connection_writer.write(&protocol.Sync) catch |err| { - stmt.error_response = .{ .postgres_error = err }; - stmt.status = .failed; + return; + } + const connection_writer = this.writer(); + // write query and wait for it to be prepared + PostgresRequest.writeQuery(query_str.slice(), stmt.signature.prepared_statement_name, stmt.signature.fields, PostgresSQLConnection.Writer, connection_writer) catch |err| { + stmt.error_response = .{ .postgres_error = err }; + stmt.status = .failed; - req.onWriteFail(err, this.globalObject, this.getQueriesArray()); - req.deref(); - this.requests.discard(1); + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); - continue; - }; - this.flags.is_ready_for_query = false; - stmt.status = .parsing; - return; - }, - .parsing => { - // we are still parsing, lets wait for it to be prepared or failed - return; - }, + continue; + }; + connection_writer.write(&protocol.Sync) catch |err| { + stmt.error_response = .{ .postgres_error = err }; + stmt.status = .failed; + + req.onWriteFail(err, this.globalObject, this.getQueriesArray()); + req.deref(); + this.requests.discard(1); + + continue; + }; + this.flags.is_ready_for_query = false; + stmt.status = .parsing; + return; + }, + .parsing => { + // we are still parsing, lets wait for it to be prepared or failed + return; + }, + } } }, - .running, .binding => { + .running, .binding, .partial_response => { // if we are binding it will switch to running immediately // if we are running, we need to wait for it to be success or fail return; @@ -3420,7 +3528,14 @@ pub const PostgresSQLConnection = struct { this.setStatus(.connected); this.flags.is_ready_for_query = true; this.socket.setTimeout(300); + defer this.updateRef(); + if (this.current()) |request| { + if (request.status == .partial_response) { + // if is a partial response, just signal that the query is now complete + request.onResult("", this.globalObject, this.js_value, true); + } + } try this.advance(); this.flushData(); @@ -3435,7 +3550,13 @@ pub const PostgresSQLConnection = struct { } debug("-> {s}", .{cmd.command_tag.slice()}); defer this.updateRef(); - request.onSuccess(cmd.command_tag.slice(), this.globalObject, this.js_value); + + if (request.flags.simple) { + // simple queries can have multiple commands + request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, false); + } else { + request.onResult(cmd.command_tag.slice(), this.globalObject, this.js_value, true); + } }, .BindComplete => { try reader.eatMessage(protocol.BindComplete); @@ -3713,7 +3834,12 @@ pub const PostgresSQLConnection = struct { .CloseComplete => { try reader.eatMessage(protocol.CloseComplete); var request = this.current() orelse return error.ExpectedRequest; - request.onSuccess("CLOSECOMPLETE", this.globalObject, this.getQueriesArray()); + defer this.updateRef(); + if (request.flags.simple) { + request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, false); + } else { + request.onResult("CLOSECOMPLETE", this.globalObject, this.js_value, true); + } }, .CopyInResponse => { debug("TODO CopyInResponse", .{}); @@ -3728,8 +3854,12 @@ pub const PostgresSQLConnection = struct { .EmptyQueryResponse => { try reader.eatMessage(protocol.EmptyQueryResponse); var request = this.current() orelse return error.ExpectedRequest; - this.updateRef(); - request.onSuccess("", this.globalObject, this.getQueriesArray()); + defer this.updateRef(); + if (request.flags.simple) { + request.onResult("", this.globalObject, this.js_value, false); + } else { + request.onResult("", this.globalObject, this.js_value, true); + } }, .CopyOutResponse => { debug("TODO CopyOutResponse", .{}); @@ -4057,12 +4187,29 @@ const Signature = struct { query: []const u8, prepared_statement_name: []const u8, + pub fn empty() Signature { + return Signature{ + .fields = &[_]int4{}, + .name = &[_]u8{}, + .query = &[_]u8{}, + .prepared_statement_name = &[_]u8{}, + }; + } + const log = bun.Output.scoped(.PostgresSignature, false); pub fn deinit(this: *Signature) void { - bun.default_allocator.free(this.prepared_statement_name); - bun.default_allocator.free(this.fields); - bun.default_allocator.free(this.name); - bun.default_allocator.free(this.query); + if (this.prepared_statement_name.len > 0) { + bun.default_allocator.free(this.prepared_statement_name); + } + if (this.name.len > 0) { + bun.default_allocator.free(this.name); + } + if (this.fields.len > 0) { + bun.default_allocator.free(this.fields); + } + if (this.query.len > 0) { + bun.default_allocator.free(this.query); + } } pub fn hash(this: *const Signature) u64 { @@ -4072,7 +4219,7 @@ const Signature = struct { return hasher.final(); } - pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue, prepared_statement_id: u64) !Signature { + pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue, prepared_statement_id: u64, unnamed: bool) !Signature { var fields = std.ArrayList(int4).init(bun.default_allocator); var name = try std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); @@ -4127,7 +4274,7 @@ const Signature = struct { return error.InvalidQueryBinding; } // max u64 length is 20, max prepared_statement_name length is 63 - const prepared_statement_name = try std.fmt.allocPrint(bun.default_allocator, "P{s}${d}", .{ name.items[0..@min(40, name.items.len)], prepared_statement_id }); + const prepared_statement_name = if (unnamed) "" else try std.fmt.allocPrint(bun.default_allocator, "P{s}${d}", .{ name.items[0..@min(40, name.items.len)], prepared_statement_id }); return Signature{ .prepared_statement_name = prepared_statement_name, @@ -4145,7 +4292,7 @@ pub fn createBinding(globalObject: *JSC.JSGlobalObject) JSValue { binding.put( globalObject, ZigString.static("createQuery"), - JSC.JSFunction.create(globalObject, "createQuery", PostgresSQLQuery.call, 2, .{}), + JSC.JSFunction.create(globalObject, "createQuery", PostgresSQLQuery.call, 6, .{}), ); binding.put( diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig index 6d36812f30..6582a214c1 100644 --- a/src/sql/postgres/postgres_protocol.zig +++ b/src/sql/postgres/postgres_protocol.zig @@ -1255,6 +1255,14 @@ pub const Flush = [_]u8{'H'} ++ toBytes(Int32(4)); pub const SSLRequest = toBytes(Int32(8)) ++ toBytes(Int32(80877103)); pub const NoData = [_]u8{'n'} ++ toBytes(Int32(4)); +pub fn writeQuery(query: []const u8, comptime Context: type, writer: NewWriter(Context)) !void { + const count: u32 = @sizeOf((u32)) + @as(u32, @intCast(query.len)) + 1; + const header = [_]u8{ + 'Q', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(query); +} pub const SASLInitialResponse = struct { mechanism: Data = .{ .empty = {} }, data: Data = .{ .empty = {} }, @@ -1410,30 +1418,6 @@ pub const Describe = struct { pub const write = writeWrap(@This(), writeInternal).write; }; -pub const Query = struct { - message: Data = .{ .empty = {} }, - - pub fn deinit(this: *@This()) void { - this.message.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const message = this.message.slice(); - const count: u32 = @sizeOf((u32)) + message.len + 1; - const header = [_]u8{ - 'Q', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(message); - } - - pub const write = writeWrap(@This(), writeInternal).write; -}; - pub const NegotiateProtocolVersion = struct { version: int4 = 0, unrecognized_options: std.ArrayListUnmanaged(String) = .{}, diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index 45fad4a3b8..6621fe1858 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -1288,6 +1288,14 @@ if (isDockerEnabled()) { expect(await sql.unsafe("select 1 as x")).toEqual([{ x: 1 }]); }); + test("simple query with multiple statements", async () => { + const result = await sql`select 1 as x;select 2 as x`.simple(); + expect(result).toBeDefined(); + expect(result.length).toEqual(2); + expect(result[0][0].x).toEqual(1); + expect(result[1][0].x).toEqual(2); + }); + // t('unsafe simple includes columns', async() => { // return ['x', (await sql.unsafe('select 1 as x').values()).columns[0].name] // }) @@ -1304,22 +1312,14 @@ if (isDockerEnabled()) { // ] // }) - test.todo("simple query using unsafe with multiple statements", async () => { - // bun always uses prepared statements, so this is not supported - // PostgresError: cannot insert multiple commands into a prepared statement - // errno: "42601", - // code: "ERR_POSTGRES_SYNTAX_ERROR" - expect(await sql.unsafe("select 1 as x;select 2 as x")).toEqual([{ x: 1 }, { x: 2 }]); - // return ["1,2", (await sql.unsafe("select 1 as x;select 2 as x")).map(x => x[0].x).join()]; + test("simple query using unsafe with multiple statements", async () => { + const result = await sql.unsafe("select 1 as x;select 2 as x"); + expect(result).toBeDefined(); + expect(result.length).toEqual(2); + expect(result[0][0].x).toEqual(1); + expect(result[1][0].x).toEqual(2); }); - // t('simple query using simple() with multiple statements', async() => { - // return [ - // '1,2', - // (await sql`select 1 as x;select 2 as x`.simple()).map(x => x[0].x).join() - // ] - // }) - // t('listen and notify', async() => { // const sql = postgres(options) // const channel = 'hello' diff --git a/test/js/sql/tls-sql.test.ts b/test/js/sql/tls-sql.test.ts index 2272ef0bd1..e39c8a034f 100644 --- a/test/js/sql/tls-sql.test.ts +++ b/test/js/sql/tls-sql.test.ts @@ -1,240 +1,271 @@ -import { test, expect, mock } from "bun:test"; +import { test, expect, describe } from "bun:test"; import { getSecret } from "harness"; -import { SQL, sql, postgres } from "bun"; +import { SQL, sql, postgres, randomUUIDv7 } from "bun"; const TLS_POSTGRES_DATABASE_URL = getSecret("TLS_POSTGRES_DATABASE_URL"); -const options = { - url: TLS_POSTGRES_DATABASE_URL, - tls: true, - adapter: "postgresql", - max: 1, - bigint: true, -}; +const PG_TRANSACTION_POOL_SUPABASE_URL = getSecret("PG_TRANSACTION_POOL_SUPABASE_URL"); -if (TLS_POSTGRES_DATABASE_URL) { - test("default sql", async () => { - expect(sql.reserve).toBeDefined(); - expect(sql.options).toBeDefined(); - expect(sql[Symbol.asyncDispose]).toBeDefined(); - expect(sql.begin).toBeDefined(); - expect(sql.beginDistributed).toBeDefined(); - expect(sql.distributed).toBeDefined(); - expect(sql.unsafe).toBeDefined(); - expect(sql.end).toBeDefined(); - expect(sql.close).toBeDefined(); - expect(sql.transaction).toBeDefined(); - expect(sql.distributed).toBeDefined(); - expect(sql.unsafe).toBeDefined(); - expect(sql.commitDistributed).toBeDefined(); - expect(sql.rollbackDistributed).toBeDefined(); - }); - test("default postgres", async () => { - expect(postgres.reserve).toBeDefined(); - expect(postgres.options).toBeDefined(); - expect(postgres[Symbol.asyncDispose]).toBeDefined(); - expect(postgres.begin).toBeDefined(); - expect(postgres.beginDistributed).toBeDefined(); - expect(postgres.distributed).toBeDefined(); - expect(postgres.unsafe).toBeDefined(); - expect(postgres.end).toBeDefined(); - expect(postgres.close).toBeDefined(); - expect(postgres.transaction).toBeDefined(); - expect(postgres.distributed).toBeDefined(); - expect(postgres.unsafe).toBeDefined(); - expect(postgres.commitDistributed).toBeDefined(); - expect(postgres.rollbackDistributed).toBeDefined(); - }); - test("tls (explicit)", async () => { - await using sql = new SQL(options); - const [{ one, two }] = await sql`SELECT 1 as one, '2' as two`; - expect(one).toBe(1); - expect(two).toBe("2"); - await sql.close(); - }); +for (const options of [ + { + url: TLS_POSTGRES_DATABASE_URL, + tls: true, + adapter: "postgresql", + max: 1, + bigint: true, + prepare: true, + transactionPool: false, + }, + { + url: PG_TRANSACTION_POOL_SUPABASE_URL, + tls: true, + adapter: "postgresql", + max: 1, + bigint: true, + prepare: false, + transactionPool: true, + }, - test("Throws on illegal transactions", async () => { - await using sql = new SQL({ ...options, max: 2 }); - const error = await sql`BEGIN`.catch(e => e); - return expect(error.code).toBe("ERR_POSTGRES_UNSAFE_TRANSACTION"); - }); + { + url: TLS_POSTGRES_DATABASE_URL, + tls: true, + adapter: "postgresql", + max: 1, + bigint: true, + prepare: false, + transactionPool: false, + }, +]) { + describe(`${options.transactionPool ? "Transaction Pooling" : `Prepared Statements (${options.prepare ? "on" : "off"})`}`, () => { + test("default sql", async () => { + expect(sql.reserve).toBeDefined(); + expect(sql.options).toBeDefined(); + expect(sql[Symbol.asyncDispose]).toBeDefined(); + expect(sql.begin).toBeDefined(); + expect(sql.beginDistributed).toBeDefined(); + expect(sql.distributed).toBeDefined(); + expect(sql.unsafe).toBeDefined(); + expect(sql.end).toBeDefined(); + expect(sql.close).toBeDefined(); + expect(sql.transaction).toBeDefined(); + expect(sql.distributed).toBeDefined(); + expect(sql.unsafe).toBeDefined(); + expect(sql.commitDistributed).toBeDefined(); + expect(sql.rollbackDistributed).toBeDefined(); + }); + test("default postgres", async () => { + expect(postgres.reserve).toBeDefined(); + expect(postgres.options).toBeDefined(); + expect(postgres[Symbol.asyncDispose]).toBeDefined(); + expect(postgres.begin).toBeDefined(); + expect(postgres.beginDistributed).toBeDefined(); + expect(postgres.distributed).toBeDefined(); + expect(postgres.unsafe).toBeDefined(); + expect(postgres.end).toBeDefined(); + expect(postgres.close).toBeDefined(); + expect(postgres.transaction).toBeDefined(); + expect(postgres.distributed).toBeDefined(); + expect(postgres.unsafe).toBeDefined(); + expect(postgres.commitDistributed).toBeDefined(); + expect(postgres.rollbackDistributed).toBeDefined(); + }); + test("tls (explicit)", async () => { + await using sql = new SQL(options); + const [{ one, two }] = await sql`SELECT 1 as one, '2' as two`; + expect(one).toBe(1); + expect(two).toBe("2"); + await sql.close(); + }); + + test("Throws on illegal transactions", async () => { + await using sql = new SQL({ ...options, max: 2 }); + const error = await sql`BEGIN`.catch(e => e); + return expect(error.code).toBe("ERR_POSTGRES_UNSAFE_TRANSACTION"); + }); + + test.skipIf(options.transactionPool)("Transaction throws", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + expect( + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql`insert into ${sql(random_name)} values('hej')`; + }) + .catch(e => e.errno), + ).toBe("22P02"); + }); + + test.skipIf(options.transactionPool)("Transaction rolls back", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; - test("Transaction throws", async () => { - await using sql = new SQL(options); - await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; - expect( await sql .begin(async sql => { - await sql`insert into test values(1)`; - await sql`insert into test values('hej')`; + await sql`insert into ${sql(random_name)} values(1)`; + await sql`insert into ${sql(random_name)} values('hej')`; }) - .catch(e => e.errno), - ).toBe("22P02"); - }); + .catch(() => { + /* ignore */ + }); - test("Transaction rolls back", async () => { - await using sql = new SQL(options); - await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; + expect((await sql`select a from ${sql(random_name)}`).count).toBe(0); + }); - await sql - .begin(async sql => { - await sql`insert into test values(1)`; - await sql`insert into test values('hej')`; - }) - .catch(() => { - /* ignore */ + test.skipIf(options.transactionPool)("Transaction throws on uncaught savepoint", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + expect( + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql.savepoint(async sql => { + await sql`insert into ${sql(random_name)} values(2)`; + throw new Error("fail"); + }); + }) + .catch(err => err.message), + ).toBe("fail"); + }); + + test.skipIf(options.transactionPool)("Transaction throws on uncaught named savepoint", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + expect( + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql.savepoint("watpoint", async sql => { + await sql`insert into ${sql(random_name)} values(2)`; + throw new Error("fail"); + }); + }) + .catch(() => "fail"), + ).toBe("fail"); + }); + + test("Transaction succeeds on caught savepoint", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + try { + await sql.begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql + .savepoint(async sql => { + await sql`insert into ${sql(random_name)} values(2)`; + throw new Error("please rollback"); + }) + .catch(() => { + /* ignore */ + }); + await sql`insert into ${sql(random_name)} values(3)`; + }); + expect((await sql`select count(1) from ${sql(random_name)}`)[0].count).toBe(2n); + } finally { + await sql`DROP TABLE IF EXISTS ${sql(random_name)}`; + } + }); + + test("Savepoint returns Result", async () => { + let result; + await using sql = new SQL(options); + await sql.begin(async t => { + result = await t.savepoint(s => s`select 1 as x`); }); + expect(result[0]?.x).toBe(1); + }); - expect((await sql`select a from test`).count).toBe(0); - }); + test("Transaction requests are executed implicitly", async () => { + await using sql = new SQL(options); + expect( + ( + await sql.begin(sql => [ + sql`select set_config('bun_sql.test', 'testing', true)`, + sql`select current_setting('bun_sql.test') as x`, + ]) + )[1][0].x, + ).toBe("testing"); + }); - test("Transaction throws on uncaught savepoint", async () => { - await using sql = new SQL(options); - await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; - expect( - await sql - .begin(async sql => { - await sql`insert into test values(1)`; - await sql.savepoint(async sql => { - await sql`insert into test values(2)`; - throw new Error("fail"); - }); - }) - .catch(err => err.message), - ).toBe("fail"); - }); + test("Uncaught transaction request errors bubbles to transaction", async () => { + await using sql = new SQL(options); + expect( + await sql + .begin(sql => [sql`select wat`, sql`select current_setting('bun_sql.test') as x, ${1} as a`]) + .catch(e => e.errno || e), + ).toBe("42703"); + }); - test("Transaction throws on uncaught named savepoint", async () => { - await using sql = new SQL(options); - await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; - expect( - await sql - .begin(async sql => { - await sql`insert into test values(1)`; - await sql.savepoit("watpoint", async sql => { - await sql`insert into test values(2)`; - throw new Error("fail"); - }); - }) - .catch(() => "fail"), - ).toBe("fail"); - }); + test("Transaction rejects with rethrown error", async () => { + await using sql = new SQL(options); + expect( + await sql + .begin(async sql => { + try { + await sql`select exception`; + } catch (ex) { + throw new Error("WAT"); + } + }) + .catch(e => e.message), + ).toBe("WAT"); + }); - test("Transaction succeeds on caught savepoint", async () => { - await using sql = new SQL(options); - const table_id = `test_random${Bun.randomUUIDv7().toString().replace(/-/g, "_")}`; - await sql`CREATE TABLE IF NOT EXISTS ${sql(table_id)} (a int)`; - try { + test("Parallel transactions", async () => { + await using sql = new SQL({ ...options, max: 2 }); + + expect( + (await Promise.all([sql.begin(sql => sql`select 1 as count`), sql.begin(sql => sql`select 1 as count`)])) + .map(x => x[0].count) + .join(""), + ).toBe("11"); + }); + + test("Many transactions at beginning of connection", async () => { + await using sql = new SQL({ ...options, max: 2 }); + const xs = await Promise.all(Array.from({ length: 30 }, () => sql.begin(sql => sql`select 1`))); + return expect(xs.length).toBe(30); + }); + + test("Transactions array", async () => { + await using sql = new SQL(options); + expect( + (await sql.begin(sql => [sql`select 1 as count`, sql`select 1 as count`])).map(x => x[0].count).join(""), + ).toBe("11"); + }); + + test.skipIf(options.transactionPool)("Transaction waits", async () => { + await using sql = new SQL({ ...options, max: 2 }); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; await sql.begin(async sql => { - await sql`insert into ${sql(table_id)} values(1)`; + await sql`insert into ${sql(random_name)} values(1)`; await sql .savepoint(async sql => { - await sql`insert into ${sql(table_id)} values(2)`; + await sql`insert into ${sql(random_name)} values(2)`; throw new Error("please rollback"); }) .catch(() => { /* ignore */ }); - await sql`insert into ${sql(table_id)} values(3)`; + await sql`insert into ${sql(random_name)} values(3)`; }); - expect((await sql`select count(1) from ${sql(table_id)}`)[0].count).toBe(2n); - } finally { - await sql`DROP TABLE IF EXISTS ${sql(table_id)}`; - } - }); - - test("Savepoint returns Result", async () => { - let result; - await using sql = new SQL(options); - await sql.begin(async t => { - result = await t.savepoint(s => s`select 1 as x`); + expect( + ( + await Promise.all([ + sql.begin(async sql => await sql`select 1 as count`), + sql.begin(async sql => await sql`select 1 as count`), + ]) + ) + .map(x => x[0].count) + .join(""), + ).toBe("11"); }); - expect(result[0]?.x).toBe(1); - }); - - test("Transaction requests are executed implicitly", async () => { - await using sql = new SQL(options); - expect( - ( - await sql.begin(sql => [ - sql`select set_config('bun_sql.test', 'testing', true)`, - sql`select current_setting('bun_sql.test') as x`, - ]) - )[1][0].x, - ).toBe("testing"); - }); - - test("Uncaught transaction request errors bubbles to transaction", async () => { - await using sql = new SQL(options); - expect( - await sql - .begin(sql => [sql`select wat`, sql`select current_setting('bun_sql.test') as x, ${1} as a`]) - .catch(e => e.errno || e), - ).toBe("42703"); - }); - - test("Transaction rejects with rethrown error", async () => { - await using sql = new SQL(options); - expect( - await sql - .begin(async sql => { - try { - await sql`select exception`; - } catch (ex) { - throw new Error("WAT"); - } - }) - .catch(e => e.message), - ).toBe("WAT"); - }); - - test("Parallel transactions", async () => { - await using sql = new SQL({ ...options, max: 2 }); - - expect( - (await Promise.all([sql.begin(sql => sql`select 1 as count`), sql.begin(sql => sql`select 1 as count`)])) - .map(x => x[0].count) - .join(""), - ).toBe("11"); - }); - - test("Many transactions at beginning of connection", async () => { - await using sql = new SQL({ ...options, max: 2 }); - const xs = await Promise.all(Array.from({ length: 30 }, () => sql.begin(sql => sql`select 1`))); - return expect(xs.length).toBe(30); - }); - - test("Transactions array", async () => { - await using sql = new SQL(options); - await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; - expect( - (await sql.begin(sql => [sql`select 1 as count`, sql`select 1 as count`])).map(x => x[0].count).join(""), - ).toBe("11"); - }); - - test("Transaction waits", async () => { - await using sql = new SQL({ ...options, max: 2 }); - await sql`CREATE TEMPORARY TABLE IF NOT EXISTS test (a int)`; - await sql.begin(async sql => { - await sql`insert into test values(1)`; - await sql - .savepoint(async sql => { - await sql`insert into test values(2)`; - throw new Error("please rollback"); - }) - .catch(() => { - /* ignore */ - }); - await sql`insert into test values(3)`; - }); - expect( - ( - await Promise.all([ - sql.begin(async sql => await sql`select 1 as count`), - sql.begin(async sql => await sql`select 1 as count`), - ]) - ) - .map(x => x[0].count) - .join(""), - ).toBe("11"); }); }