From b49f6d143e8e2fb5ab448bbc52ffdc78ace87287 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Mon, 11 Nov 2024 14:40:02 -0800 Subject: [PATCH] Postgres client - more progress (#15086) --- src/bun.js/api/postgres.classes.ts | 2 +- src/bun.js/bindings/BunObject.cpp | 10 + src/bun.js/bindings/SQLClient.cpp | 14 +- src/bun.js/bindings/bindings.cpp | 39 +- src/bun.js/bindings/bindings.zig | 19 + src/bun.zig | 23 +- src/js/bun/sql.ts | 148 +- src/output.zig | 41 +- src/sql/postgres.zig | 2297 ++++-------------------- src/sql/postgres/postgres_protocol.zig | 1413 +++++++++++++++ src/sql/postgres/postgres_types.zig | 558 ++++++ test/js/sql/sql-fixture-ref.ts | 21 + test/js/sql/sql.test.ts | 114 +- 13 files changed, 2657 insertions(+), 2042 deletions(-) create mode 100644 src/sql/postgres/postgres_protocol.zig create mode 100644 src/sql/postgres/postgres_types.zig create mode 100644 test/js/sql/sql-fixture-ref.ts diff --git a/src/bun.js/api/postgres.classes.ts b/src/bun.js/api/postgres.classes.ts index ddb27007c7..04097296fc 100644 --- a/src/bun.js/api/postgres.classes.ts +++ b/src/bun.js/api/postgres.classes.ts @@ -59,7 +59,7 @@ export default [ length: 0, }, }, - values: ["pendingValue", "binding"], + values: ["pendingValue", "columns", "binding"], estimatedSize: true, }), ]; diff --git a/src/bun.js/bindings/BunObject.cpp b/src/bun.js/bindings/BunObject.cpp index ec7315754a..b48c9dfe92 100644 --- a/src/bun.js/bindings/BunObject.cpp +++ b/src/bun.js/bindings/BunObject.cpp @@ -278,6 +278,15 @@ static JSValue constructPluginObject(VM& vm, JSObject* bunObject) return pluginFunction; } +static JSValue constructBunSQLObject(VM& vm, JSObject* bunObject) +{ + auto scope = DECLARE_THROW_SCOPE(vm); + auto* globalObject = defaultGlobalObject(bunObject->globalObject()); + JSValue sqlValue = globalObject->internalModuleRegistry()->requireId(globalObject, vm, InternalModuleRegistry::BunSql); + RETURN_IF_EXCEPTION(scope, {}); + return sqlValue.getObject()->get(globalObject, vm.propertyNames->defaultKeyword); +} + extern "C" JSC::EncodedJSValue JSPasswordObject__create(JSGlobalObject*); static JSValue constructPasswordObject(VM& vm, JSObject* bunObject) @@ -630,6 +639,7 @@ JSC_DEFINE_HOST_FUNCTION(functionFileURLToPath, (JSC::JSGlobalObject * globalObj resolveSync BunObject_callback_resolveSync DontDelete|Function 1 revision constructBunRevision ReadOnly|DontDelete|PropertyCallback semver BunObject_getter_wrap_semver ReadOnly|DontDelete|PropertyCallback + sql constructBunSQLObject DontDelete|PropertyCallback serve BunObject_callback_serve DontDelete|Function 1 sha BunObject_callback_sha DontDelete|Function 1 shrink BunObject_callback_shrink DontDelete|Function 1 diff --git a/src/bun.js/bindings/SQLClient.cpp b/src/bun.js/bindings/SQLClient.cpp index 807fffffa3..514af1b664 100644 --- a/src/bun.js/bindings/SQLClient.cpp +++ b/src/bun.js/bindings/SQLClient.cpp @@ -48,6 +48,7 @@ typedef union DataCellValue { int64_t bigint; uint8_t boolean; double date; + double date_with_time_zone; size_t bytea[2]; WTF::StringImpl* json; DataCellArray array; @@ -62,10 +63,11 @@ enum class DataCellTag : uint8_t { Bigint = 4, Boolean = 5, Date = 6, - Bytea = 7, - Json = 8, - Array = 9, - TypedArray = 10, + DateWithTimeZone = 7, + Bytea = 8, + Json = 9, + Array = 10, + TypedArray = 11, }; typedef struct DataCell { @@ -96,9 +98,11 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel case DataCellTag::Boolean: return jsBoolean(cell.value.boolean); break; - case DataCellTag::Date: + case DataCellTag::DateWithTimeZone: + case DataCellTag::Date: { return JSC::DateInstance::create(vm, globalObject->dateStructure(), cell.value.date); break; + } case DataCellTag::Bytea: { Zig::GlobalObject* zigGlobal = jsCast(globalObject); auto* subclassStructure = zigGlobal->JSBufferSubclassStructure(); diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp index 51a7da1e23..22e3439d18 100644 --- a/src/bun.js/bindings/bindings.cpp +++ b/src/bun.js/bindings/bindings.cpp @@ -5628,7 +5628,37 @@ CPP_DECL double JSC__JSValue__getUnixTimestamp(JSC__JSValue timeValue) if (!date) return PNaN; - return date->internalNumber(); + double number = date->internalNumber(); + + return number; +} + +extern "C" JSC::EncodedJSValue JSC__JSValue__getOwnByValue(JSC__JSValue value, JSC__JSGlobalObject* globalObject, JSC__JSValue propertyValue) +{ + auto& vm = globalObject->vm(); + auto scope = DECLARE_THROW_SCOPE(vm); + JSC::JSObject* object = JSValue::decode(value).getObject(); + JSC::JSValue property = JSValue::decode(propertyValue); + uint32_t index; + + PropertySlot slot(object, PropertySlot::InternalMethodType::GetOwnProperty); + if (property.getUInt32(index)) { + if (!object->getOwnPropertySlotByIndex(object, globalObject, index, slot)) + return JSC::JSValue::encode({}); + + RETURN_IF_EXCEPTION(scope, {}); + + return JSC::JSValue::encode(slot.getValue(globalObject, index)); + } else { + auto propertyName = property.toPropertyKey(globalObject); + RETURN_IF_EXCEPTION(scope, {}); + if (!object->getOwnNonIndexPropertySlot(vm, object->structure(), propertyName, slot)) + return JSC::JSValue::encode({}); + + RETURN_IF_EXCEPTION(scope, {}); + + return JSC::JSValue::encode(slot.getValue(globalObject, propertyName)); + } } extern "C" double Bun__parseDate(JSC::JSGlobalObject* globalObject, BunString* str) @@ -5637,6 +5667,13 @@ extern "C" double Bun__parseDate(JSC::JSGlobalObject* globalObject, BunString* s return vm.dateCache.parseDate(globalObject, vm, str->toWTFString()); } +extern "C" EncodedJSValue JSC__JSValue__dateInstanceFromNumber(JSC::JSGlobalObject* globalObject, double unixTimestamp) +{ + auto& vm = globalObject->vm(); + JSC::DateInstance* date = JSC::DateInstance::create(vm, globalObject->dateStructure(), unixTimestamp); + return JSValue::encode(date); +} + extern "C" EncodedJSValue JSC__JSValue__dateInstanceFromNullTerminatedString(JSC::JSGlobalObject* globalObject, const LChar* nullTerminatedChars) { double dateSeconds = WTF::parseDate(std::span(nullTerminatedChars, strlen(reinterpret_cast(nullTerminatedChars)))); diff --git a/src/bun.js/bindings/bindings.zig b/src/bun.js/bindings/bindings.zig index 17bea23083..8eea345d73 100644 --- a/src/bun.js/bindings/bindings.zig +++ b/src/bun.js/bindings/bindings.zig @@ -4327,6 +4327,12 @@ pub const JSValue = enum(JSValueReprInt) { return JSC__JSValue__dateInstanceFromNullTerminatedString(globalObject, str); } + extern fn JSC__JSValue__dateInstanceFromNumber(*JSGlobalObject, f64) JSValue; + pub fn fromDateNumber(globalObject: *JSGlobalObject, value: f64) JSValue { + JSC.markBinding(@src()); + return JSC__JSValue__dateInstanceFromNumber(globalObject, value); + } + extern fn JSBuffer__isBuffer(*JSGlobalObject, JSValue) bool; pub fn isBuffer(value: JSValue, global: *JSGlobalObject) bool { JSC.markBinding(@src()); @@ -5288,6 +5294,13 @@ pub const JSValue = enum(JSValueReprInt) { return if (@intFromEnum(value) != 0) value else return null; } + extern fn JSC__JSValue__getOwnByValue(value: JSValue, globalObject: *JSGlobalObject, propertyValue: JSValue) JSValue; + + pub fn getOwnByValue(this: JSValue, global: *JSGlobalObject, property_value: JSValue) ?JSValue { + const value = JSC__JSValue__getOwnByValue(this, global, property_value); + return if (@intFromEnum(value) != 0) value else return null; + } + pub fn getOwnTruthy(this: JSValue, global: *JSGlobalObject, property_name: anytype) ?JSValue { if (getOwn(this, global, property_name)) |prop| { if (prop == .undefined) return null; @@ -5627,6 +5640,12 @@ pub const JSValue = enum(JSValueReprInt) { }); } + extern fn JSC__JSValue__getUTCTimestamp(globalObject: *JSC.JSGlobalObject, this: JSValue) f64; + /// Calls getTime() - getUTCT + pub fn getUTCTimestamp(this: JSValue, globalObject: *JSC.JSGlobalObject) f64 { + return JSC__JSValue__getUTCTimestamp(globalObject, this); + } + pub const StringFormatter = struct { value: JSC.JSValue, globalObject: *JSC.JSGlobalObject, diff --git a/src/bun.zig b/src/bun.zig index 81803574a8..57162e6ae9 100644 --- a/src/bun.zig +++ b/src/bun.zig @@ -909,6 +909,18 @@ pub fn getRuntimeFeatureFlag(comptime flag: [:0]const u8) bool { }.get(); } +pub fn getenvZAnyCase(key: [:0]const u8) ?[]const u8 { + for (std.os.environ) |lineZ| { + const line = sliceTo(lineZ, 0); + const key_end = strings.indexOfCharUsize(line, '=') orelse line.len; + if (strings.eqlCaseInsensitiveASCII(line[0..key_end], key, true)) { + return line[@min(key_end + 1, line.len)..]; + } + } + + return null; +} + /// This wrapper exists to avoid the call to sliceTo(0) /// Zig's sliceTo(0) is scalar pub fn getenvZ(key: [:0]const u8) ?[]const u8 { @@ -917,16 +929,7 @@ pub fn getenvZ(key: [:0]const u8) ?[]const u8 { } if (comptime Environment.isWindows) { - // Windows UCRT will fill this in for us - for (std.os.environ) |lineZ| { - const line = sliceTo(lineZ, 0); - const key_end = strings.indexOfCharUsize(line, '=') orelse line.len; - if (strings.eqlCaseInsensitiveASCII(line[0..key_end], key, true)) { - return line[@min(key_end + 1, line.len)..]; - } - } - - return null; + return getenvZAnyCase(key); } const ptr = std.c.getenv(key.ptr) orelse return null; diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index 7c1c275b44..adbe607882 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -1,3 +1,9 @@ +const enum QueryStatus { + active = 1 << 1, + cancelled = 1 << 2, + error = 1 << 3, + executed = 1 << 4, +} const cmds = ["", "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "MOVE", "FETCH", "COPY"]; const PublicArray = globalThis.Array; @@ -10,11 +16,6 @@ class SQLResultArray extends PublicArray { count; } -const queryStatus_active = 1 << 1; -const queryStatus_cancelled = 1 << 2; -const queryStatus_error = 1 << 3; -const queryStatus_executed = 1 << 4; - const rawMode_values = 1; const rawMode_objects = 2; @@ -50,51 +51,51 @@ class Query extends PublicPromise { this[_reject] = reject_; this[_handle] = handle; this[_handler] = handler; - this[_queryStatus] = handle ? 0 : queryStatus_cancelled; + this[_queryStatus] = handle ? 0 : QueryStatus.cancelled; } async [_run]() { const { [_handle]: handle, [_handler]: handler, [_queryStatus]: status } = this; - if (status & (queryStatus_executed | queryStatus_cancelled)) { + if (status & (QueryStatus.executed | QueryStatus.error | QueryStatus.cancelled)) { return; } - this[_queryStatus] |= queryStatus_executed; + this[_queryStatus] |= QueryStatus.executed; await 1; return handler(this, handle); } get active() { - return (this[_queryStatus] & queryStatus_active) !== 0; + return (this[_queryStatus] & QueryStatus.active) != 0; } set active(value) { const status = this[_queryStatus]; - if (status & (queryStatus_cancelled | queryStatus_error)) { + if (status & (QueryStatus.cancelled | QueryStatus.error)) { return; } if (value) { - this[_queryStatus] |= queryStatus_active; + this[_queryStatus] |= QueryStatus.active; } else { - this[_queryStatus] &= ~queryStatus_active; + this[_queryStatus] &= ~QueryStatus.active; } } get cancelled() { - return (this[_queryStatus] & queryStatus_cancelled) !== 0; + return (this[_queryStatus] & QueryStatus.cancelled) !== 0; } resolve(x) { - this[_queryStatus] &= ~queryStatus_active; + this[_queryStatus] &= ~QueryStatus.active; this[_handle].done(); return this[_resolve](x); } reject(x) { - this[_queryStatus] &= ~queryStatus_active; - this[_queryStatus] |= queryStatus_error; + this[_queryStatus] &= ~QueryStatus.active; + this[_queryStatus] |= QueryStatus.error; this[_handle].done(); return this[_reject](x); @@ -102,12 +103,12 @@ class Query extends PublicPromise { cancel() { var status = this[_queryStatus]; - if (status & queryStatus_cancelled) { + if (status & QueryStatus.cancelled) { return this; } - this[_queryStatus] |= queryStatus_cancelled; + this[_queryStatus] |= QueryStatus.cancelled; - if (status & queryStatus_executed) { + if (status & QueryStatus.executed) { this[_handle].cancel(); } @@ -188,7 +189,9 @@ function createConnection({ hostname, port, username, password, tls, query, data ); } -function normalizeStrings(strings) { +var hasSQLArrayParameter = false; +function normalizeStrings(strings, values) { + hasSQLArrayParameter = false; if ($isJSArray(strings)) { const count = strings.length; if (count === 0) { @@ -196,9 +199,43 @@ function normalizeStrings(strings) { } var out = strings[0]; + + // For now, only support insert queries with array parameters + // + // insert into users ${sql(users)} + // + if (values.length > 0 && typeof values[0] === "object" && values[0] && values[0] instanceof SQLArrayParameter) { + if (values.length > 1) { + throw new Error("Cannot mix array parameters with other values"); + } + hasSQLArrayParameter = true; + const { columns, value } = values[0]; + const groupCount = value.length; + out += `values `; + + let columnIndex = 1; + let columnCount = columns.length; + let lastColumnIndex = columnCount - 1; + + for (var i = 0; i < groupCount; i++) { + out += i > 0 ? `, (` : `(`; + + for (var j = 0; j < lastColumnIndex; j++) { + out += `$${columnIndex++}, `; + } + + out += `$${columnIndex++})`; + } + + for (var i = 1; i < count; i++) { + out += strings[i]; + } + + return out; + } + for (var i = 1; i < count; i++) { - out += "$" + i; - out += strings[i]; + out += `$${i}${strings[i]}`; } return out; } @@ -206,6 +243,39 @@ function normalizeStrings(strings) { return strings + ""; } +class SQLArrayParameter { + value: any; + columns: string[]; + constructor(value, keys) { + if (keys?.length === 0) { + keys = Object.keys(value[0]); + } + + for (let key of keys) { + if (typeof key === "string") { + const asNumber = Number(key); + if (Number.isNaN(asNumber)) { + continue; + } + key = asNumber; + } + + if (typeof key !== "string") { + if (Number.isSafeInteger(key)) { + if (key >= 0 && key <= 64 * 1024) { + continue; + } + } + + throw new Error(`Invalid key: ${key}`); + } + } + + this.value = value; + this.columns = keys; + } +} + function loadOptions(o) { var hostname, port, username, password, database, tls, url, query, adapter; const env = Bun.env; @@ -318,8 +388,21 @@ function SQL(o) { onConnected(err, undefined); } + function doCreateQuery(strings, values) { + const sqlString = normalizeStrings(strings, values); + let columns; + if (hasSQLArrayParameter) { + hasSQLArrayParameter = false; + const v = values[0]; + columns = v.columns; + values = v.value; + } + + return createQuery(sqlString, values, new SQLResultArray(), columns); + } + function connectedSQL(strings, values) { - return new Query(createQuery(normalizeStrings(strings), values, new SQLResultArray()), connectedHandler); + return new Query(doCreateQuery(strings, values), connectedHandler); } function closedSQL(strings, values) { @@ -327,10 +410,27 @@ function SQL(o) { } function pendingSQL(strings, values) { - return new Query(createQuery(normalizeStrings(strings), values, new SQLResultArray()), pendingConnectionHandler); + return new Query(doCreateQuery(strings, values), pendingConnectionHandler); } function sql(strings, ...values) { + /** + * const users = [ + * { + * name: "Alice", + * age: 25, + * }, + * { + * name: "Bob", + * age: 30, + * }, + * ] + * sql`insert into users ${sql(users)}` + */ + if ($isJSArray(strings) && strings[0] && typeof strings[0] === "object") { + return new SQLArrayParameter(strings, values); + } + if (closed) { return closedSQL(strings, values); } diff --git a/src/output.zig b/src/output.zig index 1812e29efb..0a358bb1bc 100644 --- a/src/output.zig +++ b/src/output.zig @@ -704,13 +704,25 @@ pub noinline fn print(comptime fmt: string, args: anytype) callconv(std.builtin. /// To enable all logs, set the environment variable /// BUN_DEBUG_ALL=1 pub const LogFunction = fn (comptime fmt: string, args: anytype) callconv(bun.callconv_inline) void; + pub fn Scoped(comptime tag: anytype, comptime disabled: bool) type { - const tagname = switch (@TypeOf(tag)) { - @Type(.EnumLiteral) => @tagName(tag), - else => tag, + const tagname = comptime brk: { + const input = switch (@TypeOf(tag)) { + @Type(.EnumLiteral) => @tagName(tag), + else => tag, + }; + var ascii_slice: [input.len]u8 = undefined; + for (input, &ascii_slice) |in, *out| { + out.* = std.ascii.toLower(in); + } + break :brk ascii_slice; }; - if (comptime !Environment.isDebug and !Environment.enable_logs) { + return ScopedLogger(&tagname, disabled); +} + +fn ScopedLogger(comptime tagname: []const u8, comptime disabled: bool) type { + if (comptime !Environment.enable_logs) { return struct { pub inline fn isVisible() bool { return false; @@ -732,12 +744,22 @@ pub fn Scoped(comptime tag: anytype, comptime disabled: bool) type { pub fn isVisible() bool { if (!evaluated_disable) { evaluated_disable = true; - if (bun.getenvZ("BUN_DEBUG_" ++ tagname)) |val| { + if (bun.getenvZAnyCase("BUN_DEBUG_" ++ tagname)) |val| { really_disable = strings.eqlComptime(val, "0"); - } else if (bun.getenvZ("BUN_DEBUG_ALL")) |val| { + } else if (bun.getenvZAnyCase("BUN_DEBUG_ALL")) |val| { really_disable = strings.eqlComptime(val, "0"); - } else if (bun.getenvZ("BUN_DEBUG_QUIET_LOGS")) |val| { + } else if (bun.getenvZAnyCase("BUN_DEBUG_QUIET_LOGS")) |val| { really_disable = really_disable or !strings.eqlComptime(val, "0"); + } else { + for (bun.argv) |arg| { + if (strings.eqlCaseInsensitiveASCII(arg, comptime "--debug-" ++ tagname, true)) { + really_disable = false; + break; + } else if (strings.eqlCaseInsensitiveASCII(arg, comptime "--debug-all", true)) { + really_disable = false; + break; + } + } } } return !really_disable; @@ -803,7 +825,10 @@ pub fn Scoped(comptime tag: anytype, comptime disabled: bool) type { } pub fn scoped(comptime tag: anytype, comptime disabled: bool) LogFunction { - return Scoped(tag, disabled).log; + return Scoped( + tag, + disabled, + ).log; } // Valid "colors": diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index e51b0e169c..330f4528d0 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -3,15 +3,17 @@ const JSC = bun.JSC; const String = bun.String; const uws = bun.uws; const std = @import("std"); -const debug = bun.Output.scoped(.Postgres, false); -const int4 = u32; -const PostgresInt32 = int4; -const short = u16; -const PostgresShort = u16; +pub const debug = bun.Output.scoped(.Postgres, false); +pub const int4 = u32; +pub const PostgresInt32 = int4; +pub const int8 = i64; +pub const PostgresInt64 = int8; +pub const short = u16; +pub const PostgresShort = u16; const Crypto = JSC.API.Bun.Crypto; const JSValue = JSC.JSValue; -const Data = union(enum) { +pub const Data = union(enum) { owned: bun.ByteList, temporary: []const u8, empty: void, @@ -72,1906 +74,73 @@ const Data = union(enum) { }; } }; - -pub const protocol = struct { - pub const ArrayList = struct { - array: *std.ArrayList(u8), - - pub fn offset(this: @This()) usize { - return this.array.items.len; - } - - pub fn write(this: @This(), bytes: []const u8) anyerror!void { - try this.array.appendSlice(bytes); - } - - pub fn pwrite(this: @This(), bytes: []const u8, i: usize) anyerror!void { - @memcpy(this.array.items[i..][0..bytes.len], bytes); - } - - pub const Writer = NewWriter(@This()); - }; - - pub const StackReader = struct { - buffer: []const u8 = "", - offset: *usize, - message_start: *usize, - - pub fn markMessageStart(this: @This()) void { - this.message_start.* = this.offset.*; - } - - pub fn ensureLength(this: @This(), length: usize) bool { - return this.buffer.len >= (this.offset.* + length); - } - - pub fn init(buffer: []const u8, offset: *usize, message_start: *usize) protocol.NewReader(StackReader) { - return .{ - .wrapped = .{ - .buffer = buffer, - .offset = offset, - .message_start = message_start, - }, - }; - } - - pub fn peek(this: StackReader) []const u8 { - return this.buffer[this.offset.*..]; - } - pub fn skip(this: StackReader, count: usize) void { - if (this.offset.* + count > this.buffer.len) { - this.offset.* = this.buffer.len; - return; - } - - this.offset.* += count; - } - pub fn ensureCapacity(this: StackReader, count: usize) bool { - return this.buffer.len >= (this.offset.* + count); - } - pub fn read(this: StackReader, count: usize) anyerror!Data { - const offset = this.offset.*; - if (!this.ensureCapacity(count)) { - return error.ShortRead; - } - - this.skip(count); - return Data{ - .temporary = this.buffer[offset..this.offset.*], - }; - } - pub fn readZ(this: StackReader) anyerror!Data { - const remaining = this.peek(); - if (bun.strings.indexOfChar(remaining, 0)) |zero| { - this.skip(zero + 1); - return Data{ - .temporary = remaining[0..zero], - }; - } - - return error.ShortRead; - } - }; - - pub fn NewWriterWrap( - comptime Context: type, - comptime offsetFn_: (fn (ctx: Context) usize), - comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) anyerror!void), - comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) anyerror!void), - ) type { - return struct { - wrapped: Context, - - const writeFn = writeFunction_; - const pwriteFn = pwriteFunction_; - const offsetFn = offsetFn_; - pub const Ctx = Context; - - pub const WrappedWriter = @This(); - - pub inline fn write(this: @This(), data: []const u8) anyerror!void { - try writeFn(this.wrapped, data); - } - - pub const LengthWriter = struct { - index: usize, - context: WrappedWriter, - - pub fn write(this: LengthWriter) anyerror!void { - try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index); - } - - pub fn writeExcludingSelf(this: LengthWriter) anyerror!void { - try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index); - } - }; - - pub inline fn length(this: @This()) anyerror!LengthWriter { - const i = this.offset(); - try this.int4(0); - return LengthWriter{ - .index = i, - .context = this, - }; - } - - pub inline fn offset(this: @This()) usize { - return offsetFn(this.wrapped); - } - - pub inline fn pwrite(this: @This(), data: []const u8, i: usize) anyerror!void { - try pwriteFn(this.wrapped, data, i); - } - - pub fn int4(this: @This(), value: PostgresInt32) !void { - try this.write(std.mem.asBytes(&@byteSwap(value))); - } - - pub fn sint4(this: @This(), value: i32) !void { - try this.write(std.mem.asBytes(&@byteSwap(value))); - } - - pub fn @"f64"(this: @This(), value: f64) !void { - try this.write(std.mem.asBytes(&@byteSwap(@as(u64, @bitCast(value))))); - } - - pub fn @"f32"(this: @This(), value: f32) !void { - try this.write(std.mem.asBytes(&@byteSwap(@as(u32, @bitCast(value))))); - } - - pub fn short(this: @This(), value: anytype) !void { - try this.write(std.mem.asBytes(&@byteSwap(@as(u16, @intCast(value))))); - } - - pub fn string(this: @This(), value: []const u8) !void { - try this.write(value); - if (value.len == 0 or value[value.len - 1] != 0) - try this.write(&[_]u8{0}); - } - - pub fn bytes(this: @This(), value: []const u8) !void { - try this.write(value); - if (value.len == 0 or value[value.len - 1] != 0) - try this.write(&[_]u8{0}); - } - - pub fn @"bool"(this: @This(), value: bool) !void { - try this.write(if (value) "t" else "f"); - } - - pub fn @"null"(this: @This()) !void { - try this.int4(std.math.maxInt(PostgresInt32)); - } - - pub fn String(this: @This(), value: bun.String) !void { - if (value.isEmpty()) { - try this.write(&[_]u8{0}); - return; - } - - var sliced = value.toUTF8(bun.default_allocator); - defer sliced.deinit(); - const slice = sliced.slice(); - - try this.write(slice); - if (slice.len == 0 or slice[slice.len - 1] != 0) - try this.write(&[_]u8{0}); - } - }; - } - - pub const FieldType = enum(u8) { - /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a localized translation of one of these. Always present. - S = 'S', - - /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). This is identical to the S field except that the contents are never localized. This is present only in messages generated by PostgreSQL versions 9.6 and later. - V = 'V', - - /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. - C = 'C', - - /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). Always present. - M = 'M', - - /// Detail: an optional secondary error message carrying more detail about the problem. Might run to multiple lines. - D = 'D', - - /// Hint: an optional suggestion what to do about the problem. This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. Might run to multiple lines. - H = 'H', - - /// Position: the field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. The first character has index 1, and positions are measured in characters not bytes. - P = 'P', - - /// Internal position: this is defined the same as the P field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. The q field will always appear when this field appears. - p = 'p', - - /// Internal query: the text of a failed internally-generated command. This could be, for example, an SQL query issued by a PL/pgSQL function. - q = 'q', - - /// Where: an indication of the context in which the error occurred. Presently this includes a call stack traceback of active procedural language functions and internally-generated queries. The trace is one entry per line, most recent first. - W = 'W', - - /// Schema name: if the error was associated with a specific database object, the name of the schema containing that object, if any. - s = 's', - - /// Table name: if the error was associated with a specific table, the name of the table. (Refer to the schema name field for the name of the table's schema.) - t = 't', - - /// Column name: if the error was associated with a specific table column, the name of the column. (Refer to the schema and table name fields to identify the table.) - c = 'c', - - /// Data type name: if the error was associated with a specific data type, the name of the data type. (Refer to the schema name field for the name of the data type's schema.) - d = 'd', - - /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. Refer to fields listed above for the associated table or domain. (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) - n = 'n', - - /// File: the file name of the source-code location where the error was reported. - F = 'F', - - /// Line: the line number of the source-code location where the error was reported. - L = 'L', - - /// Routine: the name of the source-code routine reporting the error. - R = 'R', - - _, - }; - - pub const FieldMessage = union(FieldType) { - S: String, - V: String, - C: String, - M: String, - D: String, - H: String, - P: String, - p: String, - q: String, - W: String, - s: String, - t: String, - c: String, - d: String, - n: String, - F: String, - L: String, - R: String, - - pub fn format(this: FieldMessage, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - switch (this) { - inline else => |str| { - try std.fmt.format(writer, "{}", .{str}); - }, - } - } - - pub fn deinit(this: *FieldMessage) void { - switch (this.*) { - inline else => |*message| { - message.deref(); - }, - } - } - - pub fn decodeList(comptime Context: type, reader: NewReader(Context)) !std.ArrayListUnmanaged(FieldMessage) { - var messages = std.ArrayListUnmanaged(FieldMessage){}; - while (true) { - const field_int = try reader.int(u8); - if (field_int == 0) break; - const field: FieldType = @enumFromInt(field_int); - - var message = try reader.readZ(); - defer message.deinit(); - if (message.slice().len == 0) break; - - try messages.append(bun.default_allocator, FieldMessage.init(field, message.slice()) catch continue); - } - - return messages; - } - - pub fn init(tag: FieldType, message: []const u8) !FieldMessage { - return switch (tag) { - .S => FieldMessage{ .S = String.createUTF8(message) }, - .V => FieldMessage{ .V = String.createUTF8(message) }, - .C => FieldMessage{ .C = String.createUTF8(message) }, - .M => FieldMessage{ .M = String.createUTF8(message) }, - .D => FieldMessage{ .D = String.createUTF8(message) }, - .H => FieldMessage{ .H = String.createUTF8(message) }, - .P => FieldMessage{ .P = String.createUTF8(message) }, - .p => FieldMessage{ .p = String.createUTF8(message) }, - .q => FieldMessage{ .q = String.createUTF8(message) }, - .W => FieldMessage{ .W = String.createUTF8(message) }, - .s => FieldMessage{ .s = String.createUTF8(message) }, - .t => FieldMessage{ .t = String.createUTF8(message) }, - .c => FieldMessage{ .c = String.createUTF8(message) }, - .d => FieldMessage{ .d = String.createUTF8(message) }, - .n => FieldMessage{ .n = String.createUTF8(message) }, - .F => FieldMessage{ .F = String.createUTF8(message) }, - .L => FieldMessage{ .L = String.createUTF8(message) }, - .R => FieldMessage{ .R = String.createUTF8(message) }, - else => error.UnknownFieldType, - }; - } - }; - - pub fn NewReaderWrap( - comptime Context: type, - comptime markMessageStartFn_: (fn (ctx: Context) void), - comptime peekFn_: (fn (ctx: Context) []const u8), - comptime skipFn_: (fn (ctx: Context, count: usize) void), - comptime ensureCapacityFn_: (fn (ctx: Context, count: usize) bool), - comptime readFunction_: (fn (ctx: Context, count: usize) anyerror!Data), - comptime readZ_: (fn (ctx: Context) anyerror!Data), - ) type { - return struct { - wrapped: Context, - const readFn = readFunction_; - const readZFn = readZ_; - const ensureCapacityFn = ensureCapacityFn_; - const skipFn = skipFn_; - const peekFn = peekFn_; - const markMessageStartFn = markMessageStartFn_; - - pub const Ctx = Context; - - pub inline fn markMessageStart(this: @This()) void { - markMessageStartFn(this.wrapped); - } - - pub inline fn read(this: @This(), count: usize) anyerror!Data { - return try readFn(this.wrapped, count); - } - - pub inline fn eatMessage(this: @This(), comptime msg_: anytype) anyerror!void { - const msg = msg_[1..]; - try this.ensureCapacity(msg.len); - - var input = try readFn(this.wrapped, msg.len); - defer input.deinit(); - if (bun.strings.eqlComptime(input.slice(), msg)) return; - return error.InvalidMessage; - } - - pub fn skip(this: @This(), count: usize) anyerror!void { - skipFn(this.wrapped, count); - } - - pub fn peek(this: @This()) []const u8 { - return peekFn(this.wrapped); - } - - pub inline fn readZ(this: @This()) anyerror!Data { - return try readZFn(this.wrapped); - } - - pub inline fn ensureCapacity(this: @This(), count: usize) anyerror!void { - if (!ensureCapacityFn(this.wrapped, count)) { - return error.ShortRead; - } - } - - pub fn int(this: @This(), comptime Int: type) !Int { - var data = try this.read(@sizeOf((Int))); - defer data.deinit(); - if (comptime Int == u8) { - return @as(Int, data.slice()[0]); - } - return @byteSwap(@as(Int, @bitCast(data.slice()[0..@sizeOf(Int)].*))); - } - - pub fn peekInt(this: @This(), comptime Int: type) ?Int { - const remain = this.peek(); - if (remain.len < @sizeOf(Int)) { - return null; - } - return @byteSwap(@as(Int, @bitCast(remain[0..@sizeOf(Int)].*))); - } - - pub fn expectInt(this: @This(), comptime Int: type, comptime value: comptime_int) !bool { - const actual = try this.int(Int); - return actual == value; - } - - pub fn int4(this: @This()) !PostgresInt32 { - return this.int(PostgresInt32); - } - - pub fn short(this: @This()) !PostgresShort { - return this.int(PostgresShort); - } - - pub fn length(this: @This()) !PostgresInt32 { - const expected = try this.int(PostgresInt32); - if (expected > -1) { - try this.ensureCapacity(@intCast(expected -| 4)); - } - - return expected; - } - - pub const bytes = read; - - pub fn String(this: @This()) !bun.String { - var result = try this.readZ(); - defer result.deinit(); - return bun.String.fromUTF8(result.slice()); - } - }; - } - - pub fn NewReader(comptime Context: type) type { - return NewReaderWrap(Context, Context.markMessageStart, Context.peek, Context.skip, Context.ensureLength, Context.read, Context.readZ); - } - - pub fn NewWriter(comptime Context: type) type { - return NewWriterWrap(Context, Context.offset, Context.write, Context.pwrite); - } - - fn decoderWrap(comptime Container: type, comptime decodeFn: anytype) type { - return struct { - pub fn decode(this: *Container, context: anytype) anyerror!void { - const Context = @TypeOf(context); - try decodeFn(this, Context, NewReader(Context){ .wrapped = context }); - } - }; - } - - fn writeWrap(comptime Container: type, comptime writeFn: anytype) type { - return struct { - pub fn write(this: *Container, context: anytype) anyerror!void { - const Context = @TypeOf(context); - try writeFn(this, Context, NewWriter(Context){ .wrapped = context }); - } - }; - } - - pub const Authentication = union(enum) { - Ok: void, - ClearTextPassword: struct {}, - MD5Password: struct { - salt: [4]u8, - }, - KerberosV5: struct {}, - SCMCredential: struct {}, - GSS: struct {}, - GSSContinue: struct { - data: Data, - }, - SSPI: struct {}, - SASL: struct {}, - SASLContinue: struct { - data: Data, - r: []const u8, - s: []const u8, - i: []const u8, - - pub fn iterationCount(this: *const @This()) !u32 { - return try std.fmt.parseInt(u32, this.i, 0); - } - }, - SASLFinal: struct { - data: Data, - }, - Unknown: void, - - pub fn deinit(this: *@This()) void { - switch (this.*) { - .MD5Password => {}, - .SASL => {}, - .SASLContinue => { - this.SASLContinue.data.zdeinit(); - }, - .SASLFinal => { - this.SASLFinal.data.zdeinit(); - }, - else => {}, - } - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const message_length = try reader.length(); - - switch (try reader.int4()) { - 0 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ .Ok = {} }; - }, - 2 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .KerberosV5 = .{}, - }; - }, - 3 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .ClearTextPassword = .{}, - }; - }, - 5 => { - if (message_length != 12) return error.InvalidMessageLength; - if (!try reader.expectInt(u32, 5)) { - return error.InvalidMessage; - } - var salt_data = try reader.bytes(4); - defer salt_data.deinit(); - this.* = .{ - .MD5Password = .{ - .salt = salt_data.slice()[0..4].*, - }, - }; - }, - 7 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .GSS = .{}, - }; - }, - - 8 => { - if (message_length < 9) return error.InvalidMessageLength; - const bytes = try reader.read(message_length - 8); - this.* = .{ - .GSSContinue = .{ - .data = bytes, - }, - }; - }, - 9 => { - if (message_length != 8) return error.InvalidMessageLength; - this.* = .{ - .SSPI = .{}, - }; - }, - - 10 => { - if (message_length < 9) return error.InvalidMessageLength; - try reader.skip(message_length - 8); - this.* = .{ - .SASL = .{}, - }; - }, - - 11 => { - if (message_length < 9) return error.InvalidMessageLength; - var bytes = try reader.bytes(message_length - 8); - errdefer { - bytes.deinit(); - } - - var iter = bun.strings.split(bytes.slice(), ","); - var r: ?[]const u8 = null; - var i: ?[]const u8 = null; - var s: ?[]const u8 = null; - - while (iter.next()) |item| { - if (item.len > 2) { - const key = item[0]; - const after_equals = item[2..]; - if (key == 'r') { - r = after_equals; - } else if (key == 's') { - s = after_equals; - } else if (key == 'i') { - i = after_equals; - } - } - } - - if (r == null) { - debug("Missing r", .{}); - } - - if (s == null) { - debug("Missing s", .{}); - } - - if (i == null) { - debug("Missing i", .{}); - } - - this.* = .{ - .SASLContinue = .{ - .data = bytes, - .r = r orelse return error.InvalidMessage, - .s = s orelse return error.InvalidMessage, - .i = i orelse return error.InvalidMessage, - }, - }; - }, - - 12 => { - if (message_length < 9) return error.InvalidMessageLength; - const remaining: usize = message_length - 8; - - const bytes = try reader.read(remaining); - this.* = .{ - .SASLFinal = .{ - .data = bytes, - }, - }; - }, - - else => { - this.* = .{ .Unknown = {} }; - }, - } - } - - pub const decode = decoderWrap(Authentication, decodeInternal).decode; - }; - - pub const ParameterStatus = struct { - name: Data = .{ .empty = {} }, - value: Data = .{ .empty = {} }, - - pub fn deinit(this: *@This()) void { - this.name.deinit(); - this.value.deinit(); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - this.* = .{ - .name = try reader.readZ(), - .value = try reader.readZ(), - }; - } - - pub const decode = decoderWrap(ParameterStatus, decodeInternal).decode; - }; - - pub const BackendKeyData = struct { - process_id: u32 = 0, - secret_key: u32 = 0, - pub const decode = decoderWrap(BackendKeyData, decodeInternal).decode; - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - if (!try reader.expectInt(u32, 12)) { - return error.InvalidBackendKeyData; - } - - this.* = .{ - .process_id = @bitCast(try reader.int4()), - .secret_key = @bitCast(try reader.int4()), - }; - } - }; - - pub const ErrorResponse = struct { - messages: std.ArrayListUnmanaged(FieldMessage) = .{}, - - pub fn format(formatter: ErrorResponse, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - for (formatter.messages.items) |message| { - try std.fmt.format(writer, "{}\n", .{message}); - } - } - - pub fn deinit(this: *ErrorResponse) void { - for (this.messages.items) |*message| { - message.deinit(); - } - this.messages.deinit(bun.default_allocator); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - if (remaining_bytes < 4) return error.InvalidMessageLength; - remaining_bytes -|= 4; - - if (remaining_bytes > 0) { - this.* = .{ - .messages = try FieldMessage.decodeList(Container, reader), - }; - } - } - - pub const decode = decoderWrap(ErrorResponse, decodeInternal).decode; - - pub fn toJS(this: ErrorResponse, globalObject: *JSC.JSGlobalObject) JSValue { - var b = bun.StringBuilder{}; - defer b.deinit(bun.default_allocator); - - for (this.messages.items) |msg| { - b.cap += switch (msg) { - inline else => |m| m.utf8ByteLength(), - } + 1; - } - b.allocate(bun.default_allocator) catch {}; - - for (this.messages.items) |msg| { - var str = switch (msg) { - inline else => |m| m.toUTF8(bun.default_allocator), - }; - defer str.deinit(); - _ = b.append(str.slice()); - _ = b.append("\n"); - } - - return globalObject.createSyntaxErrorInstance("Postgres error occurred\n{s}", .{b.allocatedSlice()[0..b.len]}); - } - }; - - pub const PortalOrPreparedStatement = union(enum) { - portal: []const u8, - prepared_statement: []const u8, - - pub fn slice(this: @This()) []const u8 { - return switch (this) { - .portal => this.portal, - .prepared_statement => this.prepared_statement, - }; - } - - pub fn tag(this: @This()) u8 { - return switch (this) { - .portal => 'P', - .prepared_statement => 'S', - }; - } - }; - - /// Close (F) - /// Byte1('C') - /// - Identifies the message as a Close command. - /// Int32 - /// - Length of message contents in bytes, including self. - /// Byte1 - /// - 'S' to close a prepared statement; or 'P' to close a portal. - /// String - /// - The name of the prepared statement or portal to close (an empty string selects the unnamed prepared statement or portal). - pub const Close = struct { - p: PortalOrPreparedStatement, - - fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const p = this.p; - const count: u32 = @sizeOf((u32)) + 1 + p.slice().len + 1; - const header = [_]u8{ - 'C', - } ++ @byteSwap(count) ++ [_]u8{ - p.tag(), - }; - try writer.write(&header); - try writer.write(p.slice()); - try writer.write(&[_]u8{0}); - } - - pub const write = writeWrap(@This(), writeInternal); - }; - - pub const CloseComplete = [_]u8{'3'} ++ toBytes(Int32(4)); - pub const EmptyQueryResponse = [_]u8{'I'} ++ toBytes(Int32(4)); - pub const Terminate = [_]u8{'X'} ++ toBytes(Int32(4)); - - fn Int32(value: anytype) [4]u8 { - return @bitCast(@byteSwap(@as(int4, @intCast(value)))); - } - - const toBytes = std.mem.toBytes; - - pub const TransactionStatusIndicator = enum(u8) { - /// if idle (not in a transaction block) - I = 'I', - - /// if in a transaction block - T = 'T', - - /// if in a failed transaction block - E = 'E', - - _, - }; - - pub const ReadyForQuery = struct { - status: TransactionStatusIndicator = .I, - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - const status = try reader.int(u8); - this.* = .{ - .status = @enumFromInt(status), - }; - } - - pub const decode = decoderWrap(ReadyForQuery, decodeInternal).decode; - }; - - pub const FormatCode = enum { - text, - binary, - - pub fn from(value: short) !FormatCode { - return switch (value) { - 0 => .text, - 1 => .binary, - else => error.UnknownFormatCode, - }; - } - }; - - pub const null_int4 = 4294967295; - - pub const DataRow = struct { - pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) anyerror!bool) anyerror!void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - const remaining_fields: usize = @intCast(@max(try reader.short(), 0)); - - for (0..remaining_fields) |index| { - const byte_length = try reader.int4(); - switch (byte_length) { - 0 => break, - null_int4 => { - if (!try forEach(context, @intCast(index), null)) break; - }, - else => { - var bytes = try reader.bytes(@intCast(byte_length)); - if (!try forEach(context, @intCast(index), &bytes)) break; - }, - } - } - } - }; - - pub const BindComplete = [_]u8{'2'} ++ toBytes(Int32(4)); - - pub const FieldDescription = struct { - name: Data = .{ .empty = {} }, - table_oid: int4 = 0, - column_index: short = 0, - type_oid: int4 = 0, - - pub fn typeTag(this: @This()) types.Tag { - return @enumFromInt(@as(short, @truncate(this.type_oid))); - } - - pub fn deinit(this: *@This()) void { - this.name.deinit(); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var name = try reader.readZ(); - errdefer { - name.deinit(); - } - // If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. - // Int16 - // If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. - // Int32 - // The object ID of the field's data type. - // Int16 - // The data type size (see pg_type.typlen). Note that negative values denote variable-width types. - // Int32 - // The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. - // Int16 - // The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. - this.* = .{ - .table_oid = try reader.int4(), - .column_index = try reader.short(), - .type_oid = try reader.int4(), - .name = .{ .owned = try name.toOwned() }, - }; - - try reader.skip(2 + 4 + 2); - } - - pub const decode = decoderWrap(FieldDescription, decodeInternal).decode; - }; - - pub const RowDescription = struct { - fields: []const FieldDescription = &[_]FieldDescription{}, - pub fn deinit(this: *@This()) void { - for (this.fields) |*field| { - @constCast(field).deinit(); - } - - bun.default_allocator.free(this.fields); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - const field_count: usize = @intCast(@max(try reader.short(), 0)); - var fields = try bun.default_allocator.alloc( - FieldDescription, - field_count, - ); - var remaining = fields; - errdefer { - for (fields[0 .. field_count - remaining.len]) |*field| { - field.deinit(); - } - - bun.default_allocator.free(fields); - } - while (remaining.len > 0) { - try remaining[0].decodeInternal(Container, reader); - remaining = remaining[1..]; - } - this.* = .{ - .fields = fields, - }; - } - - pub const decode = decoderWrap(RowDescription, decodeInternal).decode; - }; - - pub const ParameterDescription = struct { - parameters: []int4 = &[_]int4{}, - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - const count = try reader.short(); - const parameters = try bun.default_allocator.alloc(int4, @intCast(@max(count, 0))); - - var data = try reader.read(@as(usize, @intCast(@max(count, 0))) * @sizeOf((int4))); - defer data.deinit(); - const input_params: []align(1) const int4 = toInt32Slice(int4, data.slice()); - for (input_params, parameters) |src, *dest| { - dest.* = @byteSwap(src); - } - - this.* = .{ - .parameters = parameters, - }; - } - - pub const decode = decoderWrap(ParameterDescription, decodeInternal).decode; - }; - - // workaround for zig compiler TODO - fn toInt32Slice(comptime Int: type, slice: []const u8) []align(1) const Int { - return @as([*]align(1) const Int, @ptrCast(slice.ptr))[0 .. slice.len / @sizeOf((Int))]; - } - - pub const NotificationResponse = struct { - pid: int4 = 0, - channel: bun.ByteList = .{}, - payload: bun.ByteList = .{}, - - pub fn deinit(this: *@This()) void { - this.channel.deinitWithAllocator(bun.default_allocator); - this.payload.deinitWithAllocator(bun.default_allocator); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - this.* = .{ - .pid = try reader.int4(), - .channel = (try reader.readZ()).toOwned(), - .payload = (try reader.readZ()).toOwned(), - }; - } - - pub const decode = decoderWrap(NotificationResponse, decodeInternal).decode; - }; - - pub const CommandComplete = struct { - command_tag: Data = .{ .empty = {} }, - - pub fn deinit(this: *@This()) void { - this.command_tag.deinit(); - } - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - const tag = try reader.readZ(); - this.* = .{ - .command_tag = tag, - }; - } - - pub const decode = decoderWrap(CommandComplete, decodeInternal).decode; - }; - - pub const Parse = struct { - name: []const u8 = "", - query: []const u8 = "", - params: []const int4 = &.{}, - - pub fn deinit(this: *Parse) void { - _ = this; - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const parameters = this.params; - const count: usize = @sizeOf((u32)) + @sizeOf(u16) + (parameters.len * @sizeOf(u32)) + @max(zCount(this.name), 1) + @max(zCount(this.query), 1); - const header = [_]u8{ - 'P', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(this.name); - try writer.string(this.query); - try writer.short(parameters.len); - for (parameters) |parameter| { - try writer.int4(parameter); - } - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const ParseComplete = [_]u8{'1'} ++ toBytes(Int32(4)); - - pub const PasswordMessage = struct { - password: Data = .{ .empty = {} }, - - pub fn deinit(this: *PasswordMessage) void { - this.password.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const password = this.password.slice(); - const count: usize = @sizeOf((u32)) + password.len + 1; - const header = [_]u8{ - 'p', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(password); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const CopyData = struct { - data: Data = .{ .empty = {} }, - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - const length = try reader.length(); - - const data = try reader.read(@intCast(length -| 5)); - this.* = .{ - .data = data, - }; - } - - pub const decode = decoderWrap(CopyData, decodeInternal).decode; - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const data = this.data.slice(); - const count: u32 = @sizeOf((u32)) + data.len + 1; - const header = [_]u8{ - 'd', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(data); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const CopyDone = [_]u8{'c'} ++ toBytes(Int32(4)); - pub const Sync = [_]u8{'S'} ++ toBytes(Int32(4)); - pub const Flush = [_]u8{'H'} ++ toBytes(Int32(4)); - pub const SSLRequest = toBytes(Int32(8)) ++ toBytes(Int32(80877103)); - pub const NoData = [_]u8{'n'} ++ toBytes(Int32(4)); - - pub const SASLInitialResponse = struct { - mechanism: Data = .{ .empty = {} }, - data: Data = .{ .empty = {} }, - - pub fn deinit(this: *SASLInitialResponse) void { - this.mechanism.deinit(); - this.data.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const mechanism = this.mechanism.slice(); - const data = this.data.slice(); - const count: usize = @sizeOf(u32) + mechanism.len + 1 + data.len + @sizeOf(u32); - const header = [_]u8{ - 'p', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(mechanism); - try writer.int4(@truncate(data.len)); - try writer.write(data); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const SASLResponse = struct { - data: Data = .{ .empty = {} }, - - pub fn deinit(this: *SASLResponse) void { - this.data.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const data = this.data.slice(); - const count: usize = @sizeOf(u32) + data.len; - const header = [_]u8{ - 'p', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.write(data); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const StartupMessage = struct { - user: Data, - database: Data, - options: Data = Data{ .empty = {} }, - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const user = this.user.slice(); - const database = this.database.slice(); - const options = this.options.slice(); - - const count: usize = @sizeOf((int4)) + @sizeOf((int4)) + zFieldCount("user", user) + zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + zFieldCount("", options) + 1; - - const header = toBytes(Int32(@as(u32, @truncate(count)))); - try writer.write(&header); - try writer.int4(196608); - - try writer.string("user"); - if (user.len > 0) - try writer.string(user); - - try writer.string("database"); - - if (database.len == 0) { - // The database to connect to. Defaults to the user name. - try writer.string(user); - } else { - try writer.string(database); - } - - try writer.string("client_encoding"); - try writer.string("UTF8"); - - if (options.len > 0) - try writer.string(options); - - try writer.write(&[_]u8{0}); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - fn zCount(slice: []const u8) usize { - return if (slice.len > 0) slice.len + 1 else 0; - } - - fn zFieldCount(prefix: []const u8, slice: []const u8) usize { - if (slice.len > 0) { - return zCount(prefix) + zCount(slice); - } - - return zCount(prefix); - } - - pub const Execute = struct { - max_rows: int4 = 0, - p: PortalOrPreparedStatement, - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - try writer.write("E"); - const length = try writer.length(); - if (this.p == .portal) - try writer.string(this.p.portal) - else - try writer.write(&[_]u8{0}); - try writer.int4(this.max_rows); - try length.write(); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const Describe = struct { - p: PortalOrPreparedStatement, - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const message = this.p.slice(); - try writer.write(&[_]u8{ - 'D', - }); - const length = try writer.length(); - try writer.write(&[_]u8{ - this.p.tag(), - }); - try writer.string(message); - try length.write(); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const Query = struct { - message: Data = .{ .empty = {} }, - - pub fn deinit(this: *@This()) void { - this.message.deinit(); - } - - pub fn writeInternal( - this: *const @This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const message = this.message.slice(); - const count: u32 = @sizeOf((u32)) + message.len + 1; - const header = [_]u8{ - 'Q', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(message); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const NegotiateProtocolVersion = struct { - version: int4 = 0, - unrecognized_options: std.ArrayListUnmanaged(String) = .{}, - - pub fn decodeInternal( - this: *@This(), - comptime Container: type, - reader: NewReader(Container), - ) !void { - const length = try reader.length(); - bun.assert(length >= 4); - - const version = try reader.int4(); - this.* = .{ - .version = version, - }; - - const unrecognized_options_count: u32 = @intCast(@max(try reader.int4(), 0)); - try this.unrecognized_options.ensureTotalCapacity(bun.default_allocator, unrecognized_options_count); - errdefer { - for (this.unrecognized_options.items) |*option| { - option.deinit(); - } - this.unrecognized_options.deinit(bun.default_allocator); - } - for (0..unrecognized_options_count) |_| { - var option = try reader.readZ(); - if (option.slice().len == 0) break; - defer option.deinit(); - this.unrecognized_options.appendAssumeCapacity( - String.fromUTF8(option), - ); - } - } - }; - - pub const NoticeResponse = struct { - messages: std.ArrayListUnmanaged(FieldMessage) = .{}, - pub fn deinit(this: *NoticeResponse) void { - for (this.messages.items) |*message| { - message.deinit(); - } - this.messages.deinit(bun.default_allocator); - } - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - var remaining_bytes = try reader.length(); - remaining_bytes -|= 4; - - if (remaining_bytes > 0) { - this.* = .{ - .messages = try FieldMessage.decodeList(Container, reader), - }; - } - } - pub const decode = decoderWrap(NoticeResponse, decodeInternal).decode; - }; - - pub const CopyFail = struct { - message: Data = .{ .empty = {} }, - - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = try reader.int4(); - - const message = try reader.readZ(); - this.* = .{ - .message = message, - }; - } - - pub const decode = decoderWrap(CopyFail, decodeInternal).decode; - - pub fn writeInternal( - this: *@This(), - comptime Context: type, - writer: NewWriter(Context), - ) !void { - const message = this.message.slice(); - const count: u32 = @sizeOf((u32)) + message.len + 1; - const header = [_]u8{ - 'f', - } ++ toBytes(Int32(count)); - try writer.write(&header); - try writer.string(message); - } - - pub const write = writeWrap(@This(), writeInternal).write; - }; - - pub const CopyInResponse = struct { - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = reader; - _ = this; - TODO(@This()); - } - - pub const decode = decoderWrap(CopyInResponse, decodeInternal).decode; - }; - - pub const CopyOutResponse = struct { - pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { - _ = reader; - _ = this; - TODO(@This()); - } - - pub const decode = decoderWrap(CopyInResponse, decodeInternal).decode; - }; - - fn TODO(comptime Type: type) !void { - std.debug.panic("TODO: not implemented {s}", .{bun.meta.typeBaseName(@typeName(Type))}); - } -}; - -pub const types = struct { - // select b.typname, b.oid, b.typarray - // from pg_catalog.pg_type a - // left join pg_catalog.pg_type b on b.oid = a.typelem - // where a.typcategory = 'A' - // group by b.oid, b.typarray - // order by b.oid - // ; - // typname | oid | typarray - // ---------------------------------------+-------+---------- - // bool | 16 | 1000 - // bytea | 17 | 1001 - // char | 18 | 1002 - // name | 19 | 1003 - // int8 | 20 | 1016 - // int2 | 21 | 1005 - // int2vector | 22 | 1006 - // int4 | 23 | 1007 - // regproc | 24 | 1008 - // text | 25 | 1009 - // oid | 26 | 1028 - // tid | 27 | 1010 - // xid | 28 | 1011 - // cid | 29 | 1012 - // oidvector | 30 | 1013 - // pg_type | 71 | 210 - // pg_attribute | 75 | 270 - // pg_proc | 81 | 272 - // pg_class | 83 | 273 - // json | 114 | 199 - // xml | 142 | 143 - // point | 600 | 1017 - // lseg | 601 | 1018 - // path | 602 | 1019 - // box | 603 | 1020 - // polygon | 604 | 1027 - // line | 628 | 629 - // cidr | 650 | 651 - // float4 | 700 | 1021 - // float8 | 701 | 1022 - // circle | 718 | 719 - // macaddr8 | 774 | 775 - // money | 790 | 791 - // macaddr | 829 | 1040 - // inet | 869 | 1041 - // aclitem | 1033 | 1034 - // bpchar | 1042 | 1014 - // varchar | 1043 | 1015 - // date | 1082 | 1182 - // time | 1083 | 1183 - // timestamp | 1114 | 1115 - // timestamptz | 1184 | 1185 - // interval | 1186 | 1187 - // pg_database | 1248 | 12052 - // timetz | 1266 | 1270 - // bit | 1560 | 1561 - // varbit | 1562 | 1563 - // numeric | 1700 | 1231 - pub const Tag = enum(short) { - bool = 16, - bytea = 17, - char = 18, - name = 19, - int8 = 20, - int2 = 21, - int2vector = 22, - int4 = 23, - // regproc = 24, - text = 25, - // oid = 26, - // tid = 27, - // xid = 28, - // cid = 29, - // oidvector = 30, - // pg_type = 71, - // pg_attribute = 75, - // pg_proc = 81, - // pg_class = 83, - json = 114, - xml = 142, - point = 600, - lseg = 601, - path = 602, - box = 603, - polygon = 604, - line = 628, - cidr = 650, - float4 = 700, - float8 = 701, - circle = 718, - macaddr8 = 774, - money = 790, - macaddr = 829, - inet = 869, - aclitem = 1033, - bpchar = 1042, - varchar = 1043, - date = 1082, - time = 1083, - timestamp = 1114, - timestamptz = 1184, - interval = 1186, - pg_database = 1248, - timetz = 1266, - bit = 1560, - varbit = 1562, - numeric = 1700, - uuid = 2950, - - bool_array = 1000, - bytea_array = 1001, - char_array = 1002, - name_array = 1003, - int8_array = 1016, - int2_array = 1005, - int2vector_array = 1006, - int4_array = 1007, - // regproc_array = 1008, - text_array = 1009, - oid_array = 1028, - tid_array = 1010, - xid_array = 1011, - cid_array = 1012, - // oidvector_array = 1013, - // pg_type_array = 210, - // pg_attribute_array = 270, - // pg_proc_array = 272, - // pg_class_array = 273, - json_array = 199, - xml_array = 143, - point_array = 1017, - lseg_array = 1018, - path_array = 1019, - box_array = 1020, - polygon_array = 1027, - line_array = 629, - cidr_array = 651, - float4_array = 1021, - float8_array = 1022, - circle_array = 719, - macaddr8_array = 775, - money_array = 791, - macaddr_array = 1040, - inet_array = 1041, - aclitem_array = 1034, - bpchar_array = 1014, - varchar_array = 1015, - date_array = 1182, - time_array = 1183, - timestamp_array = 1115, - timestamptz_array = 1185, - interval_array = 1187, - pg_database_array = 12052, - timetz_array = 1270, - bit_array = 1561, - varbit_array = 1563, - numeric_array = 1231, - _, - - pub fn isBinaryFormatSupported(this: Tag) bool { - return switch (this) { - // TODO: .int2_array, .float8_array, - .int4_array, .float4_array, .int4, .float8, .float4, .bytea, .numeric => true, - - else => false, - }; - } - - pub fn formatCode(this: Tag) short { - if (this.isBinaryFormatSupported()) { - return 1; - } - - return 0; - } - - fn PostgresBinarySingleDimensionArray(comptime T: type) type { - return extern struct { - // struct array_int4 { - // int4_t ndim; /* Number of dimensions */ - // int4_t _ign; /* offset for data, removed by libpq */ - // Oid elemtype; /* type of element in the array */ - - // /* First dimension */ - // int4_t size; /* Number of elements */ - // int4_t index; /* Index of first element */ - // int4_t first_value; /* Beginning of integer data */ - // }; - - ndim: i32, - offset_for_data: i32, - element_type: i32, - - len: i32, - index: i32, - first_value: T, - - pub fn slice(this: *@This()) []T { - if (this.len == 0) return &.{}; - - var head = @as([*]T, @ptrCast(&this.first_value)); - var current = head; - const len: usize = @intCast(this.len); - for (0..len) |i| { - // Skip every other value as it contains the size of the element - current = current[1..]; - - const val = current[0]; - const Int = std.meta.Int(.unsigned, @bitSizeOf(T)); - const swapped = @byteSwap(@as(Int, @bitCast(val))); - - head[i] = @bitCast(swapped); - - current = current[1..]; - } - - return head[0..len]; - } - - pub fn init(bytes: []const u8) *@This() { - const this: *@This() = @alignCast(@ptrCast(@constCast(bytes.ptr))); - this.ndim = @byteSwap(this.ndim); - this.offset_for_data = @byteSwap(this.offset_for_data); - this.element_type = @byteSwap(this.element_type); - this.len = @byteSwap(this.len); - this.index = @byteSwap(this.index); - return this; - } - }; - } - - pub fn toJSTypedArrayType(comptime T: Tag) JSValue.JSType { - return comptime switch (T) { - .int4_array => .Int32Array, - // .int2_array => .Uint2Array, - .float4_array => .Float32Array, - // .float8_array => .Float64Array, - else => @compileError("TODO: not implemented"), - }; - } - - pub fn byteArrayType(comptime T: Tag) type { - return comptime switch (T) { - .int4_array => i32, - // .int2_array => i16, - .float4_array => f32, - // .float8_array => f64, - else => @compileError("TODO: not implemented"), - }; - } - - pub fn unsignedByteArrayType(comptime T: Tag) type { - return comptime switch (T) { - .int4_array => u32, - // .int2_array => u16, - .float4_array => f32, - // .float8_array => f64, - else => @compileError("TODO: not implemented"), - }; - } - - pub fn pgArrayType(comptime T: Tag) type { - return PostgresBinarySingleDimensionArray(byteArrayType(T)); - } - - fn toJSWithType( - tag: Tag, - globalObject: *JSC.JSGlobalObject, - comptime Type: type, - value: Type, - ) anyerror!JSValue { - switch (tag) { - .numeric => { - return numeric.toJS(globalObject, value); - }, - - .float4, .float8 => { - return numeric.toJS(globalObject, value); - }, - - .json => { - return json.toJS(globalObject, value); - }, - - .bool => { - return @"bool".toJS(globalObject, value); - }, - - .timestamp, .timestamptz => { - return date.toJS(globalObject, value); - }, - - .bytea => { - return bytea.toJS(globalObject, value); - }, - - .int8 => { - return JSValue.fromInt64NoTruncate(globalObject, value); - }, - - .int4 => { - return numeric.toJS(globalObject, value); - }, - - else => { - return string.toJS(globalObject, value); - }, - } - } - - pub fn toJS( - tag: Tag, - globalObject: *JSC.JSGlobalObject, - value: anytype, - ) anyerror!JSValue { - return toJSWithType(tag, globalObject, @TypeOf(value), value); - } - - pub fn fromJS(globalObject: *JSC.JSGlobalObject, value: JSValue) anyerror!Tag { - if (value.isEmptyOrUndefinedOrNull()) { - return Tag.numeric; - } - - if (value.isCell()) { - const tag = value.jsType(); - if (tag.isStringLike()) { - return .text; - } - - if (tag == .JSDate) { - return .timestamp; - } - - if (tag.isTypedArray()) { - if (tag == .Int32Array) - return .int4_array; - - return .bytea; - } - - if (tag == .HeapBigInt) { - return .int8; - } - - if (tag.isArrayLike() and value.getLength(globalObject) > 0) { - return Tag.fromJS(globalObject, value.getIndex(globalObject, 0)); - } - - // Ban these types: - if (tag == .NumberObject) { - return error.JSError; - } - - if (tag == .BooleanObject) { - return error.JSError; - } - - // It's something internal - if (!tag.isIndexable()) { - return error.JSError; - } - - // We will JSON.stringify anything else. - if (tag.isObject()) { - return .json; - } - } - - if (value.isInt32()) { - return .int4; - } - - if (value.isNumber()) { - return .float8; - } - - if (value.isBoolean()) { - return .bool; - } - - return .numeric; - } - }; - - pub const string = struct { - pub const to = 25; - pub const from = [_]short{1002}; - - pub fn toJSWithType( - globalThis: *JSC.JSGlobalObject, - comptime Type: type, - value: Type, - ) anyerror!JSValue { - switch (comptime Type) { - [:0]u8, []u8, []const u8, [:0]const u8 => { - var str = String.fromUTF8(value); - defer str.deinit(); - return str.toJS(globalThis); - }, - - bun.String => { - return value.toJS(globalThis); - }, - - *Data => { - var str = String.fromUTF8(value.slice()); - defer str.deinit(); - defer value.deinit(); - return str.toJS(globalThis); - }, - - else => { - @compileError("unsupported type " ++ @typeName(Type)); - }, - } - } - - pub fn toJS( - globalThis: *JSC.JSGlobalObject, - value: anytype, - ) !JSValue { - var str = try toJSWithType(globalThis, @TypeOf(value), value); - defer str.deinit(); - return str.toJS(globalThis); - } - }; - - pub const numeric = struct { - pub const to = 0; - pub const from = [_]short{ 21, 23, 26, 700, 701 }; - - pub fn toJS( - _: *JSC.JSGlobalObject, - value: anytype, - ) anyerror!JSValue { - return JSValue.jsNumber(value); - } - }; - - pub const json = struct { - pub const to = 114; - pub const from = [_]short{ 114, 3802 }; - - pub fn toJS( - globalObject: *JSC.JSGlobalObject, - value: *Data, - ) anyerror!JSValue { - defer value.deinit(); - var str = bun.String.fromUTF8(value.slice()); - defer str.deref(); - const parse_result = JSValue.parse(str.toJS(globalObject), globalObject); - if (parse_result.isAnyError()) { - globalObject.throwValue(parse_result); - return error.JSError; - } - - return parse_result; - } - }; - - pub const @"bool" = struct { - pub const to = 16; - pub const from = [_]short{16}; - - pub fn toJS( - _: *JSC.JSGlobalObject, - value: bool, - ) anyerror!JSValue { - return JSValue.jsBoolean(value); - } - }; - - pub const date = struct { - pub const to = 1184; - pub const from = [_]short{ 1082, 1114, 1184 }; - - pub fn toJS( - globalObject: *JSC.JSGlobalObject, - value: *Data, - ) anyerror!JSValue { - defer value.deinit(); - return JSValue.fromDateString(globalObject, value.sliceZ().ptr); - } - }; - - pub const bytea = struct { - pub const to = 17; - pub const from = [_]short{17}; - - pub fn toJS( - globalObject: *JSC.JSGlobalObject, - value: *Data, - ) anyerror!JSValue { - defer value.deinit(); - - // var slice = value.slice()[@min(1, value.len)..]; - // _ = slice; - return JSValue.createBuffer(globalObject, value.slice(), null); - } - }; -}; +pub const protocol = @import("./postgres/postgres_protocol.zig"); +pub const types = @import("./postgres/postgres_types.zig"); const Socket = uws.AnySocket; const PreparedStatementsMap = std.HashMapUnmanaged(u64, *PostgresSQLStatement, bun.IdentityContext(u64), 80); +const SocketMonitor = struct { + const DebugSocketMonitorWriter = struct { + var file: std.fs.File = undefined; + var enabled = false; + var check = std.once(load); + pub fn write(data: []const u8) void { + file.writeAll(data) catch {}; + } + + fn load() void { + if (bun.getenvZAnyCase("BUN_POSTGRES_SOCKET_MONITOR")) |monitor| { + enabled = true; + file = std.fs.cwd().createFile(monitor, .{ .truncate = true }) catch { + enabled = false; + return; + }; + debug("writing to {s}", .{monitor}); + } + } + }; + + const DebugSocketMonitorReader = struct { + var file: std.fs.File = undefined; + var enabled = false; + var check = std.once(load); + + fn load() void { + if (bun.getenvZAnyCase("BUN_POSTGRES_SOCKET_MONITOR_READER")) |monitor| { + enabled = true; + file = std.fs.cwd().createFile(monitor, .{ .truncate = true }) catch { + enabled = false; + return; + }; + debug("duplicating reads to {s}", .{monitor}); + } + } + + pub fn write(data: []const u8) void { + file.writeAll(data) catch {}; + } + }; + + pub fn write(data: []const u8) void { + if (comptime bun.Environment.isDebug) { + DebugSocketMonitorWriter.check.call(); + if (DebugSocketMonitorWriter.enabled) { + DebugSocketMonitorWriter.write(data); + } + } + } + + pub fn read(data: []const u8) void { + if (comptime bun.Environment.isDebug) { + DebugSocketMonitorReader.check.call(); + if (DebugSocketMonitorReader.enabled) { + DebugSocketMonitorReader.write(data); + } + } + } +}; + pub const PostgresSQLContext = struct { tcp: ?*uws.SocketContext = null, @@ -2248,9 +417,10 @@ pub const PostgresSQLQuery = struct { } pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(JSC.conv) JSValue { - const arguments = callframe.arguments(3).slice(); + const arguments = callframe.arguments(4).slice(); const query = arguments[0]; const values = arguments[1]; + const columns = arguments[3]; if (!query.isString()) { globalThis.throw("query must be a string", .{}); @@ -2284,6 +454,9 @@ pub const PostgresSQLQuery = struct { PostgresSQLQuery.bindingSetCached(this_value, globalThis, values); PostgresSQLQuery.pendingValueSetCached(this_value, globalThis, pending_value); + if (columns != .undefined) { + PostgresSQLQuery.columnsSetCached(this_value, globalThis, columns); + } ptr.pending_value.set(globalThis, pending_value); return this_value; @@ -2318,9 +491,11 @@ pub const PostgresSQLQuery = struct { const binding_value = PostgresSQLQuery.bindingGetCached(callframe.this()) orelse .zero; var query_str = this.query.toUTF8(bun.default_allocator); defer query_str.deinit(); + const columns_value = PostgresSQLQuery.columnsGetCached(callframe.this()) orelse .undefined; - var signature = Signature.generate(globalObject, query_str.slice(), binding_value) catch |err| { - globalObject.throwError(err, "failed to generate signature"); + var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value) catch |err| { + if (!globalObject.hasException()) + globalObject.throwError(err, "failed to generate signature"); return .zero; }; @@ -2346,8 +521,9 @@ pub const PostgresSQLQuery = struct { } else { this.binary = this.statement.?.fields.len > 0; - PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, PostgresSQLConnection.Writer, writer) catch |err| { - globalObject.throwError(err, "failed to bind and execute query"); + PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { + if (!globalObject.hasException()) + globalObject.throwError(err, "failed to bind and execute query"); return .zero; }; @@ -2360,19 +536,22 @@ pub const PostgresSQLQuery = struct { // If it does not have params, we can write and execute immediately in one go if (!has_params) { PostgresRequest.prepareAndQueryWithSignature(globalObject, query_str.slice(), binding_value, PostgresSQLConnection.Writer, writer, &signature) catch |err| { - globalObject.throwError(err, "failed to prepare and query"); + if (!globalObject.hasException()) + globalObject.throwError(err, "failed to prepare and query"); signature.deinit(); return .zero; }; did_write = true; } else { PostgresRequest.writeQuery(query_str.slice(), signature.name, signature.fields, PostgresSQLConnection.Writer, writer) catch |err| { - globalObject.throwError(err, "failed to write query"); + if (!globalObject.hasException()) + globalObject.throwError(err, "failed to write query"); signature.deinit(); return .zero; }; writer.write(&protocol.Sync) catch |err| { - globalObject.throwError(err, "failed to flush"); + if (!globalObject.hasException()) + globalObject.throwError(err, "failed to flush"); signature.deinit(); return .zero; }; @@ -2421,6 +600,8 @@ pub const PostgresRequest = struct { cursor_name: bun.String, globalObject: *JSC.JSGlobalObject, values_array: JSValue, + columns_value: JSValue, + parameter_fields: []const int4, result_fields: []const protocol.FieldDescription, comptime Context: type, writer: protocol.NewWriter(Context), @@ -2431,7 +612,7 @@ pub const PostgresRequest = struct { try writer.String(cursor_name); try writer.string(name); - var iter = JSC.JSArrayIterator.init(values_array, globalObject); + const len: u32 = @truncate(parameter_fields.len); // The number of parameter format codes that follow (denoted C // below). This can be zero to indicate that there are no @@ -2439,10 +620,32 @@ pub const PostgresRequest = struct { // (text); or one, in which case the specified format code is // applied to all parameters; or it can equal the actual number // of parameters. - try writer.short(iter.len); + try writer.short(len); - while (iter.next()) |value| { - const tag = try types.Tag.fromJS(globalObject, value); + var iter = QueryBindingIterator.init(values_array, columns_value, globalObject); + for (0..len) |i| { + const tag: types.Tag = @enumFromInt(@as(short, @intCast(parameter_fields[i]))); + + const force_text = tag.isBinaryFormatSupported() and brk: { + iter.to(@truncate(i)); + if (iter.next()) |value| { + break :brk value.isString(); + } + if (iter.anyFailed()) { + return error.InvalidQueryBinding; + } + break :brk false; + }; + + if (force_text) { + // If they pass a value as a string, let's avoid attempting to + // convert it to the binary representation. This minimizes the room + // for mistakes on our end, such as stripping the timezone + // differently than what Postgres does when given a timestamp with + // timezone. + try writer.short(0); + continue; + } try writer.short( tag.formatCode(), @@ -2451,14 +654,14 @@ pub const PostgresRequest = struct { // The number of parameter values that follow (possibly zero). This // must match the number of parameters needed by the query. - try writer.short(iter.len); + try writer.short(len); - iter = JSC.JSArrayIterator.init(values_array, globalObject); - - debug("Bind: {} ({d} args)", .{ bun.fmt.quote(name), iter.len }); - - while (iter.next()) |value| { - if (value.isUndefinedOrNull()) { + debug("Bind: {} ({d} args)", .{ bun.fmt.quote(name), len }); + iter.to(0); + var i: usize = 0; + while (iter.next()) |value| : (i += 1) { + const tag: types.Tag = @enumFromInt(@as(short, @intCast(parameter_fields[i]))); + if (value.isEmptyOrUndefinedOrNull()) { debug(" -> NULL", .{}); // As a special case, -1 indicates a // NULL parameter value. No value bytes follow in the NULL case. @@ -2466,10 +669,14 @@ pub const PostgresRequest = struct { continue; } - const tag = try types.Tag.fromJS(globalObject, value); - debug(" -> {s}", .{@tagName(tag)}); - switch (tag) { + switch ( + // If they pass a value as a string, let's avoid attempting to + // convert it to the binary representation. This minimizes the room + // for mistakes on our end, such as stripping the timezone + // differently than what Postgres does when given a timestamp with + // timezone. + if (tag.isBinaryFormatSupported() and value.isString()) .text else tag) { .json => { var str = bun.String.empty; defer str.deref(); @@ -2482,14 +689,12 @@ pub const PostgresRequest = struct { }, .bool => { const l = try writer.length(); - try writer.bool(value.toBoolean()); + try writer.write(&[1]u8{@intFromBool(value.toBoolean())}); try l.writeExcludingSelf(); }, - .time, .timestamp, .timestamptz => { - var buf = std.mem.zeroes([28]u8); - const str = value.toISOString(globalObject, &buf); + .timestamp, .timestamptz => { const l = try writer.length(); - try writer.write(str); + try writer.int8(types.date.fromJS(globalObject, value)); try l.writeExcludingSelf(); }, .bytea => { @@ -2518,6 +723,7 @@ pub const PostgresRequest = struct { try writer.f64(@bitCast(value.coerceToDouble(globalObject))); try l.writeExcludingSelf(); }, + else => { const str = String.fromJSRef(value, globalObject); defer str.deref(); @@ -2589,7 +795,7 @@ pub const PostgresRequest = struct { signature: *Signature, ) !void { try writeQuery(query, signature.name, signature.fields, Context, writer); - try writeBind(signature.name, bun.String.empty, globalObject, array_value, &.{}, Context, writer); + try writeBind(signature.name, bun.String.empty, globalObject, array_value, .zero, &.{}, &.{}, Context, writer); var exec = protocol.Execute{ .p = .{ .prepared_statement = signature.name, @@ -2601,33 +807,15 @@ pub const PostgresRequest = struct { try writer.write(&protocol.Sync); } - pub fn prepareAndQuery( - globalObject: *JSC.JSGlobalObject, - query: bun.String, - array_value: JSValue, - comptime Context: type, - writer: protocol.NewWriter(Context), - ) !Signature { - var query_ = query.toUTF8(bun.default_allocator); - defer query_.deinit(); - var signature = try Signature.generate(globalObject, query_.slice(), array_value); - errdefer { - signature.deinit(); - } - - try prepareAndQueryWithSignature(globalObject, query_.slice(), array_value, Context, writer, &signature); - - return signature; - } - pub fn bindAndExecute( globalObject: *JSC.JSGlobalObject, statement: *PostgresSQLStatement, array_value: JSValue, + columns_value: JSValue, comptime Context: type, writer: protocol.NewWriter(Context), ) !void { - try writeBind(statement.signature.name, bun.String.empty, globalObject, array_value, statement.fields, Context, writer); + try writeBind(statement.signature.name, bun.String.empty, globalObject, array_value, columns_value, statement.parameters, statement.fields, Context, writer); var exec = protocol.Execute{ .p = .{ .prepared_statement = statement.signature.name, @@ -2867,6 +1055,7 @@ pub const PostgresSQLConnection = struct { if (chunk.len == 0) return; const wrote = this.socket.write(chunk, false); if (wrote > 0) { + SocketMonitor.write(chunk[0..@intCast(wrote)]); this.write_buffer.consume(@intCast(wrote)); } } @@ -2909,24 +1098,43 @@ pub const PostgresSQLConnection = struct { this.fail("Failed to write startup message", err); }; + const event_loop = this.globalObject.bunVM().eventLoop(); + event_loop.enter(); + defer event_loop.exit(); this.flushData(); } pub fn onTimeout(this: *PostgresSQLConnection) void { - var vm = this.globalObject.bunVM(); - defer vm.drainMicrotasks(); + _ = this; // autofix debug("onTimeout", .{}); } pub fn onDrain(this: *PostgresSQLConnection) void { - var vm = this.globalObject.bunVM(); - defer vm.drainMicrotasks(); + const event_loop = this.globalObject.bunVM().eventLoop(); + event_loop.enter(); + defer event_loop.exit(); this.flushData(); } pub fn onData(this: *PostgresSQLConnection, data: []const u8) void { - var vm = this.globalObject.bunVM(); - defer vm.drainMicrotasks(); + this.ref(); + const vm = this.globalObject.bunVM(); + defer { + if (this.status == .connected and this.requests.readableLength() == 0 and this.write_buffer.remaining().len == 0) { + // Don't keep the process alive when there's nothing to do. + this.poll_ref.unref(vm); + } else if (this.status == .connected) { + // Keep the process alive if there's something to do. + this.poll_ref.ref(vm); + } + + this.deref(); + } + + const event_loop = vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + SocketMonitor.read(data); if (this.read_buffer.remaining().len == 0) { var consumed: usize = 0; var offset: usize = 0; @@ -3321,10 +1529,11 @@ pub const PostgresSQLConnection = struct { int8 = 4, bool = 5, date = 6, - bytea = 7, - json = 8, - array = 9, - typed_array = 10, + date_with_time_zone = 7, + bytea = 8, + json = 9, + array = 10, + typed_array = 11, }; pub const Value = extern union { @@ -3335,6 +1544,7 @@ pub const PostgresSQLConnection = struct { int8: i64, bool: u8, date: f64, + date_with_time_zone: f64, bytea: [2]usize, json: bun.WTF.StringImpl, array: Array, @@ -3478,12 +1688,24 @@ pub const PostgresSQLConnection = struct { return DataCell{ .tag = .json, .value = .{ .json = String.createUTF8(bytes).value.WTFStringImpl }, .free_value = 1 }; }, .bool => { - return DataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 't') } }; + if (binary) { + return DataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 1) } }; + } else { + return DataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 't') } }; + } }, - .time, .timestamp, .timestamptz => { - var str = bun.String.init(bytes); - defer str.deref(); - return DataCell{ .tag = .date, .value = .{ .date = str.parseDate(globalObject) } }; + .timestamp, .timestamptz => |tag| { + if (binary and bytes.len == 8) { + switch (tag) { + .timestamptz => return DataCell{ .tag = .date_with_time_zone, .value = .{ .date_with_time_zone = types.date.fromBinary(bytes) } }, + .timestamp => return DataCell{ .tag = .date, .value = .{ .date = types.date.fromBinary(bytes) } }, + else => unreachable, + } + } else { + var str = bun.String.init(bytes); + defer str.deref(); + return DataCell{ .tag = .date, .value = .{ .date = str.parseDate(globalObject) } }; + } }, .bytea => { if (binary) { @@ -3656,7 +1878,8 @@ pub const PostgresSQLConnection = struct { .prepared => { if (req.status == .pending and stmt.status == .prepared) { const binding_value = PostgresSQLQuery.bindingGetCached(req.thisValue) orelse .zero; - PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { + const columns_value = PostgresSQLQuery.columnsGetCached(req.thisValue) orelse .zero; + PostgresRequest.bindAndExecute(this.globalObject, stmt, binding_value, columns_value, PostgresSQLConnection.Writer, this.writer()) catch |err| { req.onWriteFail(err, this.globalObject); req.deref(); this.requests.discard(1); @@ -4129,6 +2352,124 @@ pub const PostgresSQLStatement = struct { } }; +const QueryBindingIterator = union(enum) { + array: JSC.JSArrayIterator, + objects: ObjectIterator, + + pub fn init(array: JSValue, columns: JSValue, globalObject: *JSC.JSGlobalObject) QueryBindingIterator { + if (columns.isEmptyOrUndefinedOrNull()) { + return .{ .array = JSC.JSArrayIterator.init(array, globalObject) }; + } + + return .{ + .objects = .{ + .array = array, + .columns = columns, + .globalObject = globalObject, + .columns_count = columns.getLength(globalObject), + .array_length = array.getLength(globalObject), + }, + }; + } + + pub const ObjectIterator = struct { + array: JSValue, + columns: JSValue = .zero, + globalObject: *JSC.JSGlobalObject, + cell_i: usize = 0, + row_i: usize = 0, + current_row: JSC.JSValue = .zero, + columns_count: usize = 0, + array_length: usize = 0, + any_failed: bool = false, + + pub fn next(this: *ObjectIterator) ?JSC.JSValue { + if (this.row_i >= this.array_length) { + return null; + } + + const cell_i = this.cell_i; + this.cell_i += 1; + const row_i = this.row_i; + + const globalObject = this.globalObject; + + if (this.current_row == .zero) { + this.current_row = JSC.JSObject.getIndex(this.array, globalObject, @intCast(row_i)); + if (this.current_row.isEmptyOrUndefinedOrNull()) { + if (!globalObject.hasException()) + globalObject.throw("Expected a row to be returned at index {d}", .{row_i}); + this.any_failed = true; + return null; + } + } + + defer { + if (this.cell_i >= this.columns_count) { + this.cell_i = 0; + this.current_row = .zero; + this.row_i += 1; + } + } + + const property = JSC.JSObject.getIndex(this.columns, globalObject, @intCast(cell_i)); + if (property == .zero or property == .undefined) { + if (!globalObject.hasException()) + globalObject.throw("Expected a column at index {d} in row {d}", .{ cell_i, row_i }); + this.any_failed = true; + return null; + } + + const value = this.current_row.getOwnByValue(globalObject, property); + if (value == .zero or value == .undefined) { + if (!globalObject.hasException()) + globalObject.throw("Expected a value at index {d} in row {d}", .{ cell_i, row_i }); + this.any_failed = true; + return null; + } + return value; + } + }; + + pub fn next(this: *QueryBindingIterator) ?JSC.JSValue { + return switch (this.*) { + .array => |*iter| iter.next(), + .objects => |*iter| iter.next(), + }; + } + + pub fn anyFailed(this: *const QueryBindingIterator) bool { + return switch (this.*) { + .array => false, + .objects => |*iter| iter.any_failed, + }; + } + + pub fn to(this: *QueryBindingIterator, index: u32) void { + switch (this.*) { + .array => |*iter| iter.i = index, + .objects => |*iter| { + iter.cell_i = index % iter.columns_count; + iter.row_i = index / iter.columns_count; + iter.current_row = .zero; + }, + } + } + + pub fn reset(this: *QueryBindingIterator) void { + switch (this.*) { + .array => |*iter| { + iter.i = 0; + }, + .objects => |*iter| { + iter.cell_i = 0; + iter.row_i = 0; + iter.current_row = .zero; + }, + } + } +}; + const Signature = struct { fields: []const int4, name: []const u8, @@ -4147,7 +2488,7 @@ const Signature = struct { return hasher.final(); } - pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue) !Signature { + pub fn generate(globalObject: *JSC.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue) !Signature { var fields = std.ArrayList(int4).init(bun.default_allocator); var name = try std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); @@ -4158,17 +2499,17 @@ const Signature = struct { name.deinit(); } - var iter = JSC.JSArrayIterator.init(array_value, globalObject); + var iter = QueryBindingIterator.init(array_value, columns, globalObject); while (iter.next()) |value| { - if (value.isUndefinedOrNull()) { + if (value.isEmptyOrUndefinedOrNull()) { + // Allow postgres to decide the type try fields.append(0); try name.appendSlice(".null"); continue; } const tag = try types.Tag.fromJS(globalObject, value); - try fields.append(@intFromEnum(tag)); switch (tag) { .int8 => try name.appendSlice(".int8"), @@ -4182,10 +2523,24 @@ const Signature = struct { .bool => try name.appendSlice(".bool"), .timestamp => try name.appendSlice(".timestamp"), .timestamptz => try name.appendSlice(".timestamptz"), - .time => try name.appendSlice(".time"), .bytea => try name.appendSlice(".bytea"), else => try name.appendSlice(".string"), } + + switch (tag) { + .bool, .int4, .int8, .float8, .int2, .numeric, .float4, .bytea => { + // We decide the type + try fields.append(@intFromEnum(tag)); + }, + else => { + // Allow postgres to decide the type + try fields.append(0); + }, + } + } + + if (iter.anyFailed()) { + return error.InvalidQueryBinding; } return Signature{ diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig new file mode 100644 index 0000000000..4aee1791f9 --- /dev/null +++ b/src/sql/postgres/postgres_protocol.zig @@ -0,0 +1,1413 @@ +const std = @import("std"); +const bun = @import("root").bun; +const postgres = bun.JSC.Postgres; +const Data = postgres.Data; +const protocol = @This(); +const PostgresInt32 = postgres.PostgresInt32; +const PostgresShort = postgres.PostgresShort; +const String = bun.String; +const debug = postgres.debug; +const Crypto = JSC.API.Bun.Crypto; +const JSValue = JSC.JSValue; +const JSC = bun.JSC; +const short = postgres.short; +const int4 = postgres.int4; +const int8 = postgres.int8; +const PostgresInt64 = postgres.PostgresInt64; +const types = postgres.types; + +pub const ArrayList = struct { + array: *std.ArrayList(u8), + + pub fn offset(this: @This()) usize { + return this.array.items.len; + } + + pub fn write(this: @This(), bytes: []const u8) anyerror!void { + try this.array.appendSlice(bytes); + } + + pub fn pwrite(this: @This(), bytes: []const u8, i: usize) anyerror!void { + @memcpy(this.array.items[i..][0..bytes.len], bytes); + } + + pub const Writer = NewWriter(@This()); +}; + +pub const StackReader = struct { + buffer: []const u8 = "", + offset: *usize, + message_start: *usize, + + pub fn markMessageStart(this: @This()) void { + this.message_start.* = this.offset.*; + } + + pub fn ensureLength(this: @This(), length: usize) bool { + return this.buffer.len >= (this.offset.* + length); + } + + pub fn init(buffer: []const u8, offset: *usize, message_start: *usize) protocol.NewReader(StackReader) { + return .{ + .wrapped = .{ + .buffer = buffer, + .offset = offset, + .message_start = message_start, + }, + }; + } + + pub fn peek(this: StackReader) []const u8 { + return this.buffer[this.offset.*..]; + } + pub fn skip(this: StackReader, count: usize) void { + if (this.offset.* + count > this.buffer.len) { + this.offset.* = this.buffer.len; + return; + } + + this.offset.* += count; + } + pub fn ensureCapacity(this: StackReader, count: usize) bool { + return this.buffer.len >= (this.offset.* + count); + } + pub fn read(this: StackReader, count: usize) anyerror!Data { + const offset = this.offset.*; + if (!this.ensureCapacity(count)) { + return error.ShortRead; + } + + this.skip(count); + return Data{ + .temporary = this.buffer[offset..this.offset.*], + }; + } + pub fn readZ(this: StackReader) anyerror!Data { + const remaining = this.peek(); + if (bun.strings.indexOfChar(remaining, 0)) |zero| { + this.skip(zero + 1); + return Data{ + .temporary = remaining[0..zero], + }; + } + + return error.ShortRead; + } +}; + +pub fn NewWriterWrap( + comptime Context: type, + comptime offsetFn_: (fn (ctx: Context) usize), + comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) anyerror!void), + comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) anyerror!void), +) type { + return struct { + wrapped: Context, + + const writeFn = writeFunction_; + const pwriteFn = pwriteFunction_; + const offsetFn = offsetFn_; + pub const Ctx = Context; + + pub const WrappedWriter = @This(); + + pub inline fn write(this: @This(), data: []const u8) anyerror!void { + try writeFn(this.wrapped, data); + } + + pub const LengthWriter = struct { + index: usize, + context: WrappedWriter, + + pub fn write(this: LengthWriter) anyerror!void { + try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index); + } + + pub fn writeExcludingSelf(this: LengthWriter) anyerror!void { + try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index); + } + }; + + pub inline fn length(this: @This()) anyerror!LengthWriter { + const i = this.offset(); + try this.int4(0); + return LengthWriter{ + .index = i, + .context = this, + }; + } + + pub inline fn offset(this: @This()) usize { + return offsetFn(this.wrapped); + } + + pub inline fn pwrite(this: @This(), data: []const u8, i: usize) anyerror!void { + try pwriteFn(this.wrapped, data, i); + } + + pub fn int4(this: @This(), value: PostgresInt32) !void { + try this.write(std.mem.asBytes(&@byteSwap(value))); + } + + pub fn int8(this: @This(), value: PostgresInt64) !void { + try this.write(std.mem.asBytes(&@byteSwap(value))); + } + + pub fn sint4(this: @This(), value: i32) !void { + try this.write(std.mem.asBytes(&@byteSwap(value))); + } + + pub fn @"f64"(this: @This(), value: f64) !void { + try this.write(std.mem.asBytes(&@byteSwap(@as(u64, @bitCast(value))))); + } + + pub fn @"f32"(this: @This(), value: f32) !void { + try this.write(std.mem.asBytes(&@byteSwap(@as(u32, @bitCast(value))))); + } + + pub fn short(this: @This(), value: anytype) !void { + try this.write(std.mem.asBytes(&@byteSwap(@as(u16, @intCast(value))))); + } + + pub fn string(this: @This(), value: []const u8) !void { + try this.write(value); + if (value.len == 0 or value[value.len - 1] != 0) + try this.write(&[_]u8{0}); + } + + pub fn bytes(this: @This(), value: []const u8) !void { + try this.write(value); + if (value.len == 0 or value[value.len - 1] != 0) + try this.write(&[_]u8{0}); + } + + pub fn @"bool"(this: @This(), value: bool) !void { + try this.write(if (value) "t" else "f"); + } + + pub fn @"null"(this: @This()) !void { + try this.int4(std.math.maxInt(PostgresInt32)); + } + + pub fn String(this: @This(), value: bun.String) !void { + if (value.isEmpty()) { + try this.write(&[_]u8{0}); + return; + } + + var sliced = value.toUTF8(bun.default_allocator); + defer sliced.deinit(); + const slice = sliced.slice(); + + try this.write(slice); + if (slice.len == 0 or slice[slice.len - 1] != 0) + try this.write(&[_]u8{0}); + } + }; +} + +pub const FieldType = enum(u8) { + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a localized translation of one of these. Always present. + S = 'S', + + /// Severity: the field contents are ERROR, FATAL, or PANIC (in an error message), or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message). This is identical to the S field except that the contents are never localized. This is present only in messages generated by PostgreSQL versions 9.6 and later. + V = 'V', + + /// Code: the SQLSTATE code for the error (see Appendix A). Not localizable. Always present. + C = 'C', + + /// Message: the primary human-readable error message. This should be accurate but terse (typically one line). Always present. + M = 'M', + + /// Detail: an optional secondary error message carrying more detail about the problem. Might run to multiple lines. + D = 'D', + + /// Hint: an optional suggestion what to do about the problem. This is intended to differ from Detail in that it offers advice (potentially inappropriate) rather than hard facts. Might run to multiple lines. + H = 'H', + + /// Position: the field value is a decimal ASCII integer, indicating an error cursor position as an index into the original query string. The first character has index 1, and positions are measured in characters not bytes. + P = 'P', + + /// Internal position: this is defined the same as the P field, but it is used when the cursor position refers to an internally generated command rather than the one submitted by the client. The q field will always appear when this field appears. + p = 'p', + + /// Internal query: the text of a failed internally-generated command. This could be, for example, an SQL query issued by a PL/pgSQL function. + q = 'q', + + /// Where: an indication of the context in which the error occurred. Presently this includes a call stack traceback of active procedural language functions and internally-generated queries. The trace is one entry per line, most recent first. + W = 'W', + + /// Schema name: if the error was associated with a specific database object, the name of the schema containing that object, if any. + s = 's', + + /// Table name: if the error was associated with a specific table, the name of the table. (Refer to the schema name field for the name of the table's schema.) + t = 't', + + /// Column name: if the error was associated with a specific table column, the name of the column. (Refer to the schema and table name fields to identify the table.) + c = 'c', + + /// Data type name: if the error was associated with a specific data type, the name of the data type. (Refer to the schema name field for the name of the data type's schema.) + d = 'd', + + /// Constraint name: if the error was associated with a specific constraint, the name of the constraint. Refer to fields listed above for the associated table or domain. (For this purpose, indexes are treated as constraints, even if they weren't created with constraint syntax.) + n = 'n', + + /// File: the file name of the source-code location where the error was reported. + F = 'F', + + /// Line: the line number of the source-code location where the error was reported. + L = 'L', + + /// Routine: the name of the source-code routine reporting the error. + R = 'R', + + _, +}; + +pub const FieldMessage = union(FieldType) { + S: String, + V: String, + C: String, + M: String, + D: String, + H: String, + P: String, + p: String, + q: String, + W: String, + s: String, + t: String, + c: String, + d: String, + n: String, + F: String, + L: String, + R: String, + + pub fn format(this: FieldMessage, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + switch (this) { + inline else => |str| { + try std.fmt.format(writer, "{}", .{str}); + }, + } + } + + pub fn deinit(this: *FieldMessage) void { + switch (this.*) { + inline else => |*message| { + message.deref(); + }, + } + } + + pub fn decodeList(comptime Context: type, reader: NewReader(Context)) !std.ArrayListUnmanaged(FieldMessage) { + var messages = std.ArrayListUnmanaged(FieldMessage){}; + while (true) { + const field_int = try reader.int(u8); + if (field_int == 0) break; + const field: FieldType = @enumFromInt(field_int); + + var message = try reader.readZ(); + defer message.deinit(); + if (message.slice().len == 0) break; + + try messages.append(bun.default_allocator, FieldMessage.init(field, message.slice()) catch continue); + } + + return messages; + } + + pub fn init(tag: FieldType, message: []const u8) !FieldMessage { + return switch (tag) { + .S => FieldMessage{ .S = String.createUTF8(message) }, + .V => FieldMessage{ .V = String.createUTF8(message) }, + .C => FieldMessage{ .C = String.createUTF8(message) }, + .M => FieldMessage{ .M = String.createUTF8(message) }, + .D => FieldMessage{ .D = String.createUTF8(message) }, + .H => FieldMessage{ .H = String.createUTF8(message) }, + .P => FieldMessage{ .P = String.createUTF8(message) }, + .p => FieldMessage{ .p = String.createUTF8(message) }, + .q => FieldMessage{ .q = String.createUTF8(message) }, + .W => FieldMessage{ .W = String.createUTF8(message) }, + .s => FieldMessage{ .s = String.createUTF8(message) }, + .t => FieldMessage{ .t = String.createUTF8(message) }, + .c => FieldMessage{ .c = String.createUTF8(message) }, + .d => FieldMessage{ .d = String.createUTF8(message) }, + .n => FieldMessage{ .n = String.createUTF8(message) }, + .F => FieldMessage{ .F = String.createUTF8(message) }, + .L => FieldMessage{ .L = String.createUTF8(message) }, + .R => FieldMessage{ .R = String.createUTF8(message) }, + else => error.UnknownFieldType, + }; + } +}; + +pub fn NewReaderWrap( + comptime Context: type, + comptime markMessageStartFn_: (fn (ctx: Context) void), + comptime peekFn_: (fn (ctx: Context) []const u8), + comptime skipFn_: (fn (ctx: Context, count: usize) void), + comptime ensureCapacityFn_: (fn (ctx: Context, count: usize) bool), + comptime readFunction_: (fn (ctx: Context, count: usize) anyerror!Data), + comptime readZ_: (fn (ctx: Context) anyerror!Data), +) type { + return struct { + wrapped: Context, + const readFn = readFunction_; + const readZFn = readZ_; + const ensureCapacityFn = ensureCapacityFn_; + const skipFn = skipFn_; + const peekFn = peekFn_; + const markMessageStartFn = markMessageStartFn_; + + pub const Ctx = Context; + + pub inline fn markMessageStart(this: @This()) void { + markMessageStartFn(this.wrapped); + } + + pub inline fn read(this: @This(), count: usize) anyerror!Data { + return try readFn(this.wrapped, count); + } + + pub inline fn eatMessage(this: @This(), comptime msg_: anytype) anyerror!void { + const msg = msg_[1..]; + try this.ensureCapacity(msg.len); + + var input = try readFn(this.wrapped, msg.len); + defer input.deinit(); + if (bun.strings.eqlComptime(input.slice(), msg)) return; + return error.InvalidMessage; + } + + pub fn skip(this: @This(), count: usize) anyerror!void { + skipFn(this.wrapped, count); + } + + pub fn peek(this: @This()) []const u8 { + return peekFn(this.wrapped); + } + + pub inline fn readZ(this: @This()) anyerror!Data { + return try readZFn(this.wrapped); + } + + pub inline fn ensureCapacity(this: @This(), count: usize) anyerror!void { + if (!ensureCapacityFn(this.wrapped, count)) { + return error.ShortRead; + } + } + + pub fn int(this: @This(), comptime Int: type) !Int { + var data = try this.read(@sizeOf((Int))); + defer data.deinit(); + if (comptime Int == u8) { + return @as(Int, data.slice()[0]); + } + return @byteSwap(@as(Int, @bitCast(data.slice()[0..@sizeOf(Int)].*))); + } + + pub fn peekInt(this: @This(), comptime Int: type) ?Int { + const remain = this.peek(); + if (remain.len < @sizeOf(Int)) { + return null; + } + return @byteSwap(@as(Int, @bitCast(remain[0..@sizeOf(Int)].*))); + } + + pub fn expectInt(this: @This(), comptime Int: type, comptime value: comptime_int) !bool { + const actual = try this.int(Int); + return actual == value; + } + + pub fn int4(this: @This()) !PostgresInt32 { + return this.int(PostgresInt32); + } + + pub fn short(this: @This()) !PostgresShort { + return this.int(PostgresShort); + } + + pub fn length(this: @This()) !PostgresInt32 { + const expected = try this.int(PostgresInt32); + if (expected > -1) { + try this.ensureCapacity(@intCast(expected -| 4)); + } + + return expected; + } + + pub const bytes = read; + + pub fn String(this: @This()) !bun.String { + var result = try this.readZ(); + defer result.deinit(); + return bun.String.fromUTF8(result.slice()); + } + }; +} + +pub fn NewReader(comptime Context: type) type { + return NewReaderWrap(Context, Context.markMessageStart, Context.peek, Context.skip, Context.ensureLength, Context.read, Context.readZ); +} + +pub fn NewWriter(comptime Context: type) type { + return NewWriterWrap(Context, Context.offset, Context.write, Context.pwrite); +} + +fn decoderWrap(comptime Container: type, comptime decodeFn: anytype) type { + return struct { + pub fn decode(this: *Container, context: anytype) anyerror!void { + const Context = @TypeOf(context); + try decodeFn(this, Context, NewReader(Context){ .wrapped = context }); + } + }; +} + +fn writeWrap(comptime Container: type, comptime writeFn: anytype) type { + return struct { + pub fn write(this: *Container, context: anytype) anyerror!void { + const Context = @TypeOf(context); + try writeFn(this, Context, NewWriter(Context){ .wrapped = context }); + } + }; +} + +pub const Authentication = union(enum) { + Ok: void, + ClearTextPassword: struct {}, + MD5Password: struct { + salt: [4]u8, + }, + KerberosV5: struct {}, + SCMCredential: struct {}, + GSS: struct {}, + GSSContinue: struct { + data: Data, + }, + SSPI: struct {}, + SASL: struct {}, + SASLContinue: struct { + data: Data, + r: []const u8, + s: []const u8, + i: []const u8, + + pub fn iterationCount(this: *const @This()) !u32 { + return try std.fmt.parseInt(u32, this.i, 0); + } + }, + SASLFinal: struct { + data: Data, + }, + Unknown: void, + + pub fn deinit(this: *@This()) void { + switch (this.*) { + .MD5Password => {}, + .SASL => {}, + .SASLContinue => { + this.SASLContinue.data.zdeinit(); + }, + .SASLFinal => { + this.SASLFinal.data.zdeinit(); + }, + else => {}, + } + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const message_length = try reader.length(); + + switch (try reader.int4()) { + 0 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ .Ok = {} }; + }, + 2 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .KerberosV5 = .{}, + }; + }, + 3 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .ClearTextPassword = .{}, + }; + }, + 5 => { + if (message_length != 12) return error.InvalidMessageLength; + if (!try reader.expectInt(u32, 5)) { + return error.InvalidMessage; + } + var salt_data = try reader.bytes(4); + defer salt_data.deinit(); + this.* = .{ + .MD5Password = .{ + .salt = salt_data.slice()[0..4].*, + }, + }; + }, + 7 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .GSS = .{}, + }; + }, + + 8 => { + if (message_length < 9) return error.InvalidMessageLength; + const bytes = try reader.read(message_length - 8); + this.* = .{ + .GSSContinue = .{ + .data = bytes, + }, + }; + }, + 9 => { + if (message_length != 8) return error.InvalidMessageLength; + this.* = .{ + .SSPI = .{}, + }; + }, + + 10 => { + if (message_length < 9) return error.InvalidMessageLength; + try reader.skip(message_length - 8); + this.* = .{ + .SASL = .{}, + }; + }, + + 11 => { + if (message_length < 9) return error.InvalidMessageLength; + var bytes = try reader.bytes(message_length - 8); + errdefer { + bytes.deinit(); + } + + var iter = bun.strings.split(bytes.slice(), ","); + var r: ?[]const u8 = null; + var i: ?[]const u8 = null; + var s: ?[]const u8 = null; + + while (iter.next()) |item| { + if (item.len > 2) { + const key = item[0]; + const after_equals = item[2..]; + if (key == 'r') { + r = after_equals; + } else if (key == 's') { + s = after_equals; + } else if (key == 'i') { + i = after_equals; + } + } + } + + if (r == null) { + debug("Missing r", .{}); + } + + if (s == null) { + debug("Missing s", .{}); + } + + if (i == null) { + debug("Missing i", .{}); + } + + this.* = .{ + .SASLContinue = .{ + .data = bytes, + .r = r orelse return error.InvalidMessage, + .s = s orelse return error.InvalidMessage, + .i = i orelse return error.InvalidMessage, + }, + }; + }, + + 12 => { + if (message_length < 9) return error.InvalidMessageLength; + const remaining: usize = message_length - 8; + + const bytes = try reader.read(remaining); + this.* = .{ + .SASLFinal = .{ + .data = bytes, + }, + }; + }, + + else => { + this.* = .{ .Unknown = {} }; + }, + } + } + + pub const decode = decoderWrap(Authentication, decodeInternal).decode; +}; + +pub const ParameterStatus = struct { + name: Data = .{ .empty = {} }, + value: Data = .{ .empty = {} }, + + pub fn deinit(this: *@This()) void { + this.name.deinit(); + this.value.deinit(); + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + bun.assert(length >= 4); + + this.* = .{ + .name = try reader.readZ(), + .value = try reader.readZ(), + }; + } + + pub const decode = decoderWrap(ParameterStatus, decodeInternal).decode; +}; + +pub const BackendKeyData = struct { + process_id: u32 = 0, + secret_key: u32 = 0, + pub const decode = decoderWrap(BackendKeyData, decodeInternal).decode; + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + if (!try reader.expectInt(u32, 12)) { + return error.InvalidBackendKeyData; + } + + this.* = .{ + .process_id = @bitCast(try reader.int4()), + .secret_key = @bitCast(try reader.int4()), + }; + } +}; + +pub const ErrorResponse = struct { + messages: std.ArrayListUnmanaged(FieldMessage) = .{}, + + pub fn format(formatter: ErrorResponse, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + for (formatter.messages.items) |message| { + try std.fmt.format(writer, "{}\n", .{message}); + } + } + + pub fn deinit(this: *ErrorResponse) void { + for (this.messages.items) |*message| { + message.deinit(); + } + this.messages.deinit(bun.default_allocator); + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + var remaining_bytes = try reader.length(); + if (remaining_bytes < 4) return error.InvalidMessageLength; + remaining_bytes -|= 4; + + if (remaining_bytes > 0) { + this.* = .{ + .messages = try FieldMessage.decodeList(Container, reader), + }; + } + } + + pub const decode = decoderWrap(ErrorResponse, decodeInternal).decode; + + pub fn toJS(this: ErrorResponse, globalObject: *JSC.JSGlobalObject) JSValue { + var b = bun.StringBuilder{}; + defer b.deinit(bun.default_allocator); + + for (this.messages.items) |msg| { + b.cap += switch (msg) { + inline else => |m| m.utf8ByteLength(), + } + 1; + } + b.allocate(bun.default_allocator) catch {}; + + for (this.messages.items) |msg| { + var str = switch (msg) { + inline else => |m| m.toUTF8(bun.default_allocator), + }; + defer str.deinit(); + _ = b.append(str.slice()); + _ = b.append("\n"); + } + + return globalObject.createSyntaxErrorInstance("Postgres error occurred\n{s}", .{b.allocatedSlice()[0..b.len]}); + } +}; + +pub const PortalOrPreparedStatement = union(enum) { + portal: []const u8, + prepared_statement: []const u8, + + pub fn slice(this: @This()) []const u8 { + return switch (this) { + .portal => this.portal, + .prepared_statement => this.prepared_statement, + }; + } + + pub fn tag(this: @This()) u8 { + return switch (this) { + .portal => 'P', + .prepared_statement => 'S', + }; + } +}; + +/// Close (F) +/// Byte1('C') +/// - Identifies the message as a Close command. +/// Int32 +/// - Length of message contents in bytes, including self. +/// Byte1 +/// - 'S' to close a prepared statement; or 'P' to close a portal. +/// String +/// - The name of the prepared statement or portal to close (an empty string selects the unnamed prepared statement or portal). +pub const Close = struct { + p: PortalOrPreparedStatement, + + fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const p = this.p; + const count: u32 = @sizeOf((u32)) + 1 + p.slice().len + 1; + const header = [_]u8{ + 'C', + } ++ @byteSwap(count) ++ [_]u8{ + p.tag(), + }; + try writer.write(&header); + try writer.write(p.slice()); + try writer.write(&[_]u8{0}); + } + + pub const write = writeWrap(@This(), writeInternal); +}; + +pub const CloseComplete = [_]u8{'3'} ++ toBytes(Int32(4)); +pub const EmptyQueryResponse = [_]u8{'I'} ++ toBytes(Int32(4)); +pub const Terminate = [_]u8{'X'} ++ toBytes(Int32(4)); + +fn Int32(value: anytype) [4]u8 { + return @bitCast(@byteSwap(@as(int4, @intCast(value)))); +} + +const toBytes = std.mem.toBytes; + +pub const TransactionStatusIndicator = enum(u8) { + /// if idle (not in a transaction block) + I = 'I', + + /// if in a transaction block + T = 'T', + + /// if in a failed transaction block + E = 'E', + + _, +}; + +pub const ReadyForQuery = struct { + status: TransactionStatusIndicator = .I, + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + bun.assert(length >= 4); + + const status = try reader.int(u8); + this.* = .{ + .status = @enumFromInt(status), + }; + } + + pub const decode = decoderWrap(ReadyForQuery, decodeInternal).decode; +}; + +pub const FormatCode = enum { + text, + binary, + + pub fn from(value: short) !FormatCode { + return switch (value) { + 0 => .text, + 1 => .binary, + else => error.UnknownFormatCode, + }; + } +}; + +pub const null_int4 = 4294967295; + +pub const DataRow = struct { + pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(ContextType), comptime forEach: fn (@TypeOf(context), index: u32, bytes: ?*Data) anyerror!bool) anyerror!void { + var remaining_bytes = try reader.length(); + remaining_bytes -|= 4; + + const remaining_fields: usize = @intCast(@max(try reader.short(), 0)); + + for (0..remaining_fields) |index| { + const byte_length = try reader.int4(); + switch (byte_length) { + 0 => break, + null_int4 => { + if (!try forEach(context, @intCast(index), null)) break; + }, + else => { + var bytes = try reader.bytes(@intCast(byte_length)); + if (!try forEach(context, @intCast(index), &bytes)) break; + }, + } + } + } +}; + +pub const BindComplete = [_]u8{'2'} ++ toBytes(Int32(4)); + +pub const FieldDescription = struct { + name: Data = .{ .empty = {} }, + table_oid: int4 = 0, + column_index: short = 0, + type_oid: int4 = 0, + + pub fn typeTag(this: @This()) types.Tag { + return @enumFromInt(@as(short, @truncate(this.type_oid))); + } + + pub fn deinit(this: *@This()) void { + this.name.deinit(); + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + var name = try reader.readZ(); + errdefer { + name.deinit(); + } + // If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero. + // Int16 + // If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero. + // Int32 + // The object ID of the field's data type. + // Int16 + // The data type size (see pg_type.typlen). Note that negative values denote variable-width types. + // Int32 + // The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific. + // Int16 + // The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero. + this.* = .{ + .table_oid = try reader.int4(), + .column_index = try reader.short(), + .type_oid = try reader.int4(), + .name = .{ .owned = try name.toOwned() }, + }; + + try reader.skip(2 + 4 + 2); + } + + pub const decode = decoderWrap(FieldDescription, decodeInternal).decode; +}; + +pub const RowDescription = struct { + fields: []const FieldDescription = &[_]FieldDescription{}, + pub fn deinit(this: *@This()) void { + for (this.fields) |*field| { + @constCast(field).deinit(); + } + + bun.default_allocator.free(this.fields); + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + var remaining_bytes = try reader.length(); + remaining_bytes -|= 4; + + const field_count: usize = @intCast(@max(try reader.short(), 0)); + var fields = try bun.default_allocator.alloc( + FieldDescription, + field_count, + ); + var remaining = fields; + errdefer { + for (fields[0 .. field_count - remaining.len]) |*field| { + field.deinit(); + } + + bun.default_allocator.free(fields); + } + while (remaining.len > 0) { + try remaining[0].decodeInternal(Container, reader); + remaining = remaining[1..]; + } + this.* = .{ + .fields = fields, + }; + } + + pub const decode = decoderWrap(RowDescription, decodeInternal).decode; +}; + +pub const ParameterDescription = struct { + parameters: []int4 = &[_]int4{}, + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + var remaining_bytes = try reader.length(); + remaining_bytes -|= 4; + + const count = try reader.short(); + const parameters = try bun.default_allocator.alloc(int4, @intCast(@max(count, 0))); + + var data = try reader.read(@as(usize, @intCast(@max(count, 0))) * @sizeOf((int4))); + defer data.deinit(); + const input_params: []align(1) const int4 = toInt32Slice(int4, data.slice()); + for (input_params, parameters) |src, *dest| { + dest.* = @byteSwap(src); + } + + this.* = .{ + .parameters = parameters, + }; + } + + pub const decode = decoderWrap(ParameterDescription, decodeInternal).decode; +}; + +// workaround for zig compiler TODO +fn toInt32Slice(comptime Int: type, slice: []const u8) []align(1) const Int { + return @as([*]align(1) const Int, @ptrCast(slice.ptr))[0 .. slice.len / @sizeOf((Int))]; +} + +pub const NotificationResponse = struct { + pid: int4 = 0, + channel: bun.ByteList = .{}, + payload: bun.ByteList = .{}, + + pub fn deinit(this: *@This()) void { + this.channel.deinitWithAllocator(bun.default_allocator); + this.payload.deinitWithAllocator(bun.default_allocator); + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + bun.assert(length >= 4); + + this.* = .{ + .pid = try reader.int4(), + .channel = (try reader.readZ()).toOwned(), + .payload = (try reader.readZ()).toOwned(), + }; + } + + pub const decode = decoderWrap(NotificationResponse, decodeInternal).decode; +}; + +pub const CommandComplete = struct { + command_tag: Data = .{ .empty = {} }, + + pub fn deinit(this: *@This()) void { + this.command_tag.deinit(); + } + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + bun.assert(length >= 4); + + const tag = try reader.readZ(); + this.* = .{ + .command_tag = tag, + }; + } + + pub const decode = decoderWrap(CommandComplete, decodeInternal).decode; +}; + +pub const Parse = struct { + name: []const u8 = "", + query: []const u8 = "", + params: []const int4 = &.{}, + + pub fn deinit(this: *Parse) void { + _ = this; + } + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const parameters = this.params; + const count: usize = @sizeOf((u32)) + @sizeOf(u16) + (parameters.len * @sizeOf(u32)) + @max(zCount(this.name), 1) + @max(zCount(this.query), 1); + const header = [_]u8{ + 'P', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(this.name); + try writer.string(this.query); + try writer.short(parameters.len); + for (parameters) |parameter| { + try writer.int4(parameter); + } + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const ParseComplete = [_]u8{'1'} ++ toBytes(Int32(4)); + +pub const PasswordMessage = struct { + password: Data = .{ .empty = {} }, + + pub fn deinit(this: *PasswordMessage) void { + this.password.deinit(); + } + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const password = this.password.slice(); + const count: usize = @sizeOf((u32)) + password.len + 1; + const header = [_]u8{ + 'p', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(password); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const CopyData = struct { + data: Data = .{ .empty = {} }, + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + const length = try reader.length(); + + const data = try reader.read(@intCast(length -| 5)); + this.* = .{ + .data = data, + }; + } + + pub const decode = decoderWrap(CopyData, decodeInternal).decode; + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const data = this.data.slice(); + const count: u32 = @sizeOf((u32)) + data.len + 1; + const header = [_]u8{ + 'd', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(data); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const CopyDone = [_]u8{'c'} ++ toBytes(Int32(4)); +pub const Sync = [_]u8{'S'} ++ toBytes(Int32(4)); +pub const Flush = [_]u8{'H'} ++ toBytes(Int32(4)); +pub const SSLRequest = toBytes(Int32(8)) ++ toBytes(Int32(80877103)); +pub const NoData = [_]u8{'n'} ++ toBytes(Int32(4)); + +pub const SASLInitialResponse = struct { + mechanism: Data = .{ .empty = {} }, + data: Data = .{ .empty = {} }, + + pub fn deinit(this: *SASLInitialResponse) void { + this.mechanism.deinit(); + this.data.deinit(); + } + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const mechanism = this.mechanism.slice(); + const data = this.data.slice(); + const count: usize = @sizeOf(u32) + mechanism.len + 1 + data.len + @sizeOf(u32); + const header = [_]u8{ + 'p', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(mechanism); + try writer.int4(@truncate(data.len)); + try writer.write(data); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const SASLResponse = struct { + data: Data = .{ .empty = {} }, + + pub fn deinit(this: *SASLResponse) void { + this.data.deinit(); + } + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const data = this.data.slice(); + const count: usize = @sizeOf(u32) + data.len; + const header = [_]u8{ + 'p', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.write(data); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const StartupMessage = struct { + user: Data, + database: Data, + options: Data = Data{ .empty = {} }, + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const user = this.user.slice(); + const database = this.database.slice(); + const options = this.options.slice(); + + const count: usize = @sizeOf((int4)) + @sizeOf((int4)) + zFieldCount("user", user) + zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + zFieldCount("", options) + 1; + + const header = toBytes(Int32(@as(u32, @truncate(count)))); + try writer.write(&header); + try writer.int4(196608); + + try writer.string("user"); + if (user.len > 0) + try writer.string(user); + + try writer.string("database"); + + if (database.len == 0) { + // The database to connect to. Defaults to the user name. + try writer.string(user); + } else { + try writer.string(database); + } + + try writer.string("client_encoding"); + try writer.string("UTF8"); + + if (options.len > 0) + try writer.string(options); + + try writer.write(&[_]u8{0}); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +fn zCount(slice: []const u8) usize { + return if (slice.len > 0) slice.len + 1 else 0; +} + +fn zFieldCount(prefix: []const u8, slice: []const u8) usize { + if (slice.len > 0) { + return zCount(prefix) + zCount(slice); + } + + return zCount(prefix); +} + +pub const Execute = struct { + max_rows: int4 = 0, + p: PortalOrPreparedStatement, + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + try writer.write("E"); + const length = try writer.length(); + if (this.p == .portal) + try writer.string(this.p.portal) + else + try writer.write(&[_]u8{0}); + try writer.int4(this.max_rows); + try length.write(); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const Describe = struct { + p: PortalOrPreparedStatement, + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const message = this.p.slice(); + try writer.write(&[_]u8{ + 'D', + }); + const length = try writer.length(); + try writer.write(&[_]u8{ + this.p.tag(), + }); + try writer.string(message); + try length.write(); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const Query = struct { + message: Data = .{ .empty = {} }, + + pub fn deinit(this: *@This()) void { + this.message.deinit(); + } + + pub fn writeInternal( + this: *const @This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const message = this.message.slice(); + const count: u32 = @sizeOf((u32)) + message.len + 1; + const header = [_]u8{ + 'Q', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(message); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const NegotiateProtocolVersion = struct { + version: int4 = 0, + unrecognized_options: std.ArrayListUnmanaged(String) = .{}, + + pub fn decodeInternal( + this: *@This(), + comptime Container: type, + reader: NewReader(Container), + ) !void { + const length = try reader.length(); + bun.assert(length >= 4); + + const version = try reader.int4(); + this.* = .{ + .version = version, + }; + + const unrecognized_options_count: u32 = @intCast(@max(try reader.int4(), 0)); + try this.unrecognized_options.ensureTotalCapacity(bun.default_allocator, unrecognized_options_count); + errdefer { + for (this.unrecognized_options.items) |*option| { + option.deinit(); + } + this.unrecognized_options.deinit(bun.default_allocator); + } + for (0..unrecognized_options_count) |_| { + var option = try reader.readZ(); + if (option.slice().len == 0) break; + defer option.deinit(); + this.unrecognized_options.appendAssumeCapacity( + String.fromUTF8(option), + ); + } + } +}; + +pub const NoticeResponse = struct { + messages: std.ArrayListUnmanaged(FieldMessage) = .{}, + pub fn deinit(this: *NoticeResponse) void { + for (this.messages.items) |*message| { + message.deinit(); + } + this.messages.deinit(bun.default_allocator); + } + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + var remaining_bytes = try reader.length(); + remaining_bytes -|= 4; + + if (remaining_bytes > 0) { + this.* = .{ + .messages = try FieldMessage.decodeList(Container, reader), + }; + } + } + pub const decode = decoderWrap(NoticeResponse, decodeInternal).decode; +}; + +pub const CopyFail = struct { + message: Data = .{ .empty = {} }, + + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + _ = try reader.int4(); + + const message = try reader.readZ(); + this.* = .{ + .message = message, + }; + } + + pub const decode = decoderWrap(CopyFail, decodeInternal).decode; + + pub fn writeInternal( + this: *@This(), + comptime Context: type, + writer: NewWriter(Context), + ) !void { + const message = this.message.slice(); + const count: u32 = @sizeOf((u32)) + message.len + 1; + const header = [_]u8{ + 'f', + } ++ toBytes(Int32(count)); + try writer.write(&header); + try writer.string(message); + } + + pub const write = writeWrap(@This(), writeInternal).write; +}; + +pub const CopyInResponse = struct { + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + _ = reader; + _ = this; + TODO(@This()); + } + + pub const decode = decoderWrap(CopyInResponse, decodeInternal).decode; +}; + +pub const CopyOutResponse = struct { + pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReader(Container)) !void { + _ = reader; + _ = this; + TODO(@This()); + } + + pub const decode = decoderWrap(CopyInResponse, decodeInternal).decode; +}; + +fn TODO(comptime Type: type) !void { + bun.Output.panic("TODO: not implemented {s}", .{bun.meta.typeBaseName(@typeName(Type))}); +} diff --git a/src/sql/postgres/postgres_types.zig b/src/sql/postgres/postgres_types.zig new file mode 100644 index 0000000000..9ace396880 --- /dev/null +++ b/src/sql/postgres/postgres_types.zig @@ -0,0 +1,558 @@ +const std = @import("std"); +const bun = @import("root").bun; +const postgres = bun.JSC.Postgres; +const Data = postgres.Data; +const protocol = @This(); +const PostgresInt32 = postgres.PostgresInt32; +const PostgresShort = postgres.PostgresShort; +const String = bun.String; +const debug = postgres.debug; +const Crypto = JSC.API.Bun.Crypto; +const JSValue = JSC.JSValue; +const JSC = bun.JSC; +const short = postgres.short; +const int4 = postgres.int4; + +// select b.typname, b.oid, b.typarray +// from pg_catalog.pg_type a +// left join pg_catalog.pg_type b on b.oid = a.typelem +// where a.typcategory = 'A' +// group by b.oid, b.typarray +// order by b.oid +// ; +// typname | oid | typarray +// ---------------------------------------+-------+---------- +// bool | 16 | 1000 +// bytea | 17 | 1001 +// char | 18 | 1002 +// name | 19 | 1003 +// int8 | 20 | 1016 +// int2 | 21 | 1005 +// int2vector | 22 | 1006 +// int4 | 23 | 1007 +// regproc | 24 | 1008 +// text | 25 | 1009 +// oid | 26 | 1028 +// tid | 27 | 1010 +// xid | 28 | 1011 +// cid | 29 | 1012 +// oidvector | 30 | 1013 +// pg_type | 71 | 210 +// pg_attribute | 75 | 270 +// pg_proc | 81 | 272 +// pg_class | 83 | 273 +// json | 114 | 199 +// xml | 142 | 143 +// point | 600 | 1017 +// lseg | 601 | 1018 +// path | 602 | 1019 +// box | 603 | 1020 +// polygon | 604 | 1027 +// line | 628 | 629 +// cidr | 650 | 651 +// float4 | 700 | 1021 +// float8 | 701 | 1022 +// circle | 718 | 719 +// macaddr8 | 774 | 775 +// money | 790 | 791 +// macaddr | 829 | 1040 +// inet | 869 | 1041 +// aclitem | 1033 | 1034 +// bpchar | 1042 | 1014 +// varchar | 1043 | 1015 +// date | 1082 | 1182 +// time | 1083 | 1183 +// timestamp | 1114 | 1115 +// timestamptz | 1184 | 1185 +// interval | 1186 | 1187 +// pg_database | 1248 | 12052 +// timetz | 1266 | 1270 +// bit | 1560 | 1561 +// varbit | 1562 | 1563 +// numeric | 1700 | 1231 +pub const Tag = enum(short) { + bool = 16, + bytea = 17, + char = 18, + name = 19, + int8 = 20, + int2 = 21, + int2vector = 22, + int4 = 23, + // regproc = 24, + text = 25, + // oid = 26, + // tid = 27, + // xid = 28, + // cid = 29, + // oidvector = 30, + // pg_type = 71, + // pg_attribute = 75, + // pg_proc = 81, + // pg_class = 83, + json = 114, + xml = 142, + point = 600, + lseg = 601, + path = 602, + box = 603, + polygon = 604, + line = 628, + cidr = 650, + float4 = 700, + float8 = 701, + circle = 718, + macaddr8 = 774, + money = 790, + macaddr = 829, + inet = 869, + aclitem = 1033, + bpchar = 1042, + varchar = 1043, + date = 1082, + time = 1083, + timestamp = 1114, + timestamptz = 1184, + interval = 1186, + pg_database = 1248, + timetz = 1266, + bit = 1560, + varbit = 1562, + numeric = 1700, + uuid = 2950, + + bool_array = 1000, + bytea_array = 1001, + char_array = 1002, + name_array = 1003, + int8_array = 1016, + int2_array = 1005, + int2vector_array = 1006, + int4_array = 1007, + // regproc_array = 1008, + text_array = 1009, + oid_array = 1028, + tid_array = 1010, + xid_array = 1011, + cid_array = 1012, + // oidvector_array = 1013, + // pg_type_array = 210, + // pg_attribute_array = 270, + // pg_proc_array = 272, + // pg_class_array = 273, + json_array = 199, + xml_array = 143, + point_array = 1017, + lseg_array = 1018, + path_array = 1019, + box_array = 1020, + polygon_array = 1027, + line_array = 629, + cidr_array = 651, + float4_array = 1021, + float8_array = 1022, + circle_array = 719, + macaddr8_array = 775, + money_array = 791, + macaddr_array = 1040, + inet_array = 1041, + aclitem_array = 1034, + bpchar_array = 1014, + varchar_array = 1015, + date_array = 1182, + time_array = 1183, + timestamp_array = 1115, + timestamptz_array = 1185, + interval_array = 1187, + pg_database_array = 12052, + timetz_array = 1270, + bit_array = 1561, + varbit_array = 1563, + numeric_array = 1231, + _, + + pub fn isBinaryFormatSupported(this: Tag) bool { + return switch (this) { + // TODO: .int2_array, .float8_array, + .bool, .timestamp, .timestamptz, .time, .int4_array, .float4_array, .int4, .float8, .float4, .bytea, .numeric => true, + + else => false, + }; + } + + pub fn formatCode(this: Tag) short { + if (this.isBinaryFormatSupported()) { + return 1; + } + + return 0; + } + + fn PostgresBinarySingleDimensionArray(comptime T: type) type { + return extern struct { + // struct array_int4 { + // int4_t ndim; /* Number of dimensions */ + // int4_t _ign; /* offset for data, removed by libpq */ + // Oid elemtype; /* type of element in the array */ + + // /* First dimension */ + // int4_t size; /* Number of elements */ + // int4_t index; /* Index of first element */ + // int4_t first_value; /* Beginning of integer data */ + // }; + + ndim: i32, + offset_for_data: i32, + element_type: i32, + + len: i32, + index: i32, + first_value: T, + + pub fn slice(this: *@This()) []T { + if (this.len == 0) return &.{}; + + var head = @as([*]T, @ptrCast(&this.first_value)); + var current = head; + const len: usize = @intCast(this.len); + for (0..len) |i| { + // Skip every other value as it contains the size of the element + current = current[1..]; + + const val = current[0]; + const Int = std.meta.Int(.unsigned, @bitSizeOf(T)); + const swapped = @byteSwap(@as(Int, @bitCast(val))); + + head[i] = @bitCast(swapped); + + current = current[1..]; + } + + return head[0..len]; + } + + pub fn init(bytes: []const u8) *@This() { + const this: *@This() = @alignCast(@ptrCast(@constCast(bytes.ptr))); + this.ndim = @byteSwap(this.ndim); + this.offset_for_data = @byteSwap(this.offset_for_data); + this.element_type = @byteSwap(this.element_type); + this.len = @byteSwap(this.len); + this.index = @byteSwap(this.index); + return this; + } + }; + } + + pub fn toJSTypedArrayType(comptime T: Tag) JSValue.JSType { + return comptime switch (T) { + .int4_array => .Int32Array, + // .int2_array => .Uint2Array, + .float4_array => .Float32Array, + // .float8_array => .Float64Array, + else => @compileError("TODO: not implemented"), + }; + } + + pub fn byteArrayType(comptime T: Tag) type { + return comptime switch (T) { + .int4_array => i32, + // .int2_array => i16, + .float4_array => f32, + // .float8_array => f64, + else => @compileError("TODO: not implemented"), + }; + } + + pub fn unsignedByteArrayType(comptime T: Tag) type { + return comptime switch (T) { + .int4_array => u32, + // .int2_array => u16, + .float4_array => f32, + // .float8_array => f64, + else => @compileError("TODO: not implemented"), + }; + } + + pub fn pgArrayType(comptime T: Tag) type { + return PostgresBinarySingleDimensionArray(byteArrayType(T)); + } + + fn toJSWithType( + tag: Tag, + globalObject: *JSC.JSGlobalObject, + comptime Type: type, + value: Type, + ) anyerror!JSValue { + switch (tag) { + .numeric => { + return numeric.toJS(globalObject, value); + }, + + .float4, .float8 => { + return numeric.toJS(globalObject, value); + }, + + .json => { + return json.toJS(globalObject, value); + }, + + .bool => { + return @"bool".toJS(globalObject, value); + }, + + .timestamp, .timestamptz => { + return date.toJS(globalObject, value); + }, + + .bytea => { + return bytea.toJS(globalObject, value); + }, + + .int8 => { + return JSValue.fromInt64NoTruncate(globalObject, value); + }, + + .int4 => { + return numeric.toJS(globalObject, value); + }, + + else => { + return string.toJS(globalObject, value); + }, + } + } + + pub fn toJS( + tag: Tag, + globalObject: *JSC.JSGlobalObject, + value: anytype, + ) anyerror!JSValue { + return toJSWithType(tag, globalObject, @TypeOf(value), value); + } + + pub fn fromJS(globalObject: *JSC.JSGlobalObject, value: JSValue) anyerror!Tag { + if (value.isEmptyOrUndefinedOrNull()) { + return Tag.numeric; + } + + if (value.isCell()) { + const tag = value.jsType(); + if (tag.isStringLike()) { + return .text; + } + + if (tag == .JSDate) { + return .timestamptz; + } + + if (tag.isTypedArray()) { + if (tag == .Int32Array) + return .int4_array; + + return .bytea; + } + + if (tag == .HeapBigInt) { + return .int8; + } + + if (tag.isArrayLike() and value.getLength(globalObject) > 0) { + return Tag.fromJS(globalObject, value.getIndex(globalObject, 0)); + } + + // Ban these types: + if (tag == .NumberObject) { + return error.JSError; + } + + if (tag == .BooleanObject) { + return error.JSError; + } + + // It's something internal + if (!tag.isIndexable()) { + return error.JSError; + } + + // We will JSON.stringify anything else. + if (tag.isObject()) { + return .json; + } + } + + if (value.isInt32()) { + return .int4; + } + + if (value.isAnyInt()) { + const int = value.toInt64(); + if (int >= std.math.minInt(u32) and int <= std.math.maxInt(u32)) { + return .int4; + } + + return .int8; + } + + if (value.isNumber()) { + return .float8; + } + + if (value.isBoolean()) { + return .bool; + } + + return .numeric; + } +}; + +pub const string = struct { + pub const to = 25; + pub const from = [_]short{1002}; + + pub fn toJSWithType( + globalThis: *JSC.JSGlobalObject, + comptime Type: type, + value: Type, + ) anyerror!JSValue { + switch (comptime Type) { + [:0]u8, []u8, []const u8, [:0]const u8 => { + var str = String.fromUTF8(value); + defer str.deinit(); + return str.toJS(globalThis); + }, + + bun.String => { + return value.toJS(globalThis); + }, + + *Data => { + var str = String.fromUTF8(value.slice()); + defer str.deinit(); + defer value.deinit(); + return str.toJS(globalThis); + }, + + else => { + @compileError("unsupported type " ++ @typeName(Type)); + }, + } + } + + pub fn toJS( + globalThis: *JSC.JSGlobalObject, + value: anytype, + ) !JSValue { + var str = try toJSWithType(globalThis, @TypeOf(value), value); + defer str.deinit(); + return str.toJS(globalThis); + } +}; + +pub const numeric = struct { + pub const to = 0; + pub const from = [_]short{ 21, 23, 26, 700, 701 }; + + pub fn toJS( + _: *JSC.JSGlobalObject, + value: anytype, + ) anyerror!JSValue { + return JSValue.jsNumber(value); + } +}; + +pub const json = struct { + pub const to = 114; + pub const from = [_]short{ 114, 3802 }; + + pub fn toJS( + globalObject: *JSC.JSGlobalObject, + value: *Data, + ) anyerror!JSValue { + defer value.deinit(); + var str = bun.String.fromUTF8(value.slice()); + defer str.deref(); + const parse_result = JSValue.parse(str.toJS(globalObject), globalObject); + if (parse_result.isAnyError()) { + globalObject.throwValue(parse_result); + return error.JSError; + } + + return parse_result; + } +}; + +pub const @"bool" = struct { + pub const to = 16; + pub const from = [_]short{16}; + + pub fn toJS( + _: *JSC.JSGlobalObject, + value: bool, + ) anyerror!JSValue { + return JSValue.jsBoolean(value); + } +}; + +pub const date = struct { + pub const to = 1184; + pub const from = [_]short{ 1082, 1114, 1184 }; + + // Postgres stores timestamp and timestampz as microseconds since 2000-01-01 + // This is a signed 64-bit integer. + const POSTGRES_EPOCH_DATE = 946684800000; + + pub fn fromBinary(bytes: []const u8) f64 { + const microseconds = std.mem.readInt(i64, bytes[0..8], .big); + const double_microseconds: f64 = @floatFromInt(microseconds); + return (double_microseconds / std.time.us_per_ms) + POSTGRES_EPOCH_DATE; + } + + pub fn fromJS(globalObject: *JSC.JSGlobalObject, value: JSValue) i64 { + const double_value = if (value.isDate()) + value.getUnixTimestamp() + else if (value.isNumber()) + value.asNumber() + else if (value.isString()) brk: { + var str = value.toBunString(globalObject); + defer str.deref(); + break :brk str.parseDate(globalObject); + } else return 0; + + const unix_timestamp: i64 = @intFromFloat(double_value); + return (unix_timestamp - POSTGRES_EPOCH_DATE) * std.time.us_per_ms; + } + + pub fn toJS( + globalObject: *JSC.JSGlobalObject, + value: anytype, + ) JSValue { + switch (@TypeOf(value)) { + i64 => { + // Convert from Postgres timestamp (μs since 2000-01-01) to Unix timestamp (ms) + const ms = @divFloor(value, std.time.us_per_ms) + POSTGRES_EPOCH_DATE; + return JSValue.fromDateNumber(globalObject, @floatFromInt(ms)); + }, + *Data => { + defer value.deinit(); + return JSValue.fromDateString(globalObject, value.sliceZ().ptr); + }, + else => @compileError("unsupported type " ++ @typeName(@TypeOf(value))), + } + } +}; + +pub const bytea = struct { + pub const to = 17; + pub const from = [_]short{17}; + + pub fn toJS( + globalObject: *JSC.JSGlobalObject, + value: *Data, + ) anyerror!JSValue { + defer value.deinit(); + + // var slice = value.slice()[@min(1, value.len)..]; + // _ = slice; + return JSValue.createBuffer(globalObject, value.slice(), null); + } +}; diff --git a/test/js/sql/sql-fixture-ref.ts b/test/js/sql/sql-fixture-ref.ts new file mode 100644 index 0000000000..af8f52dafc --- /dev/null +++ b/test/js/sql/sql-fixture-ref.ts @@ -0,0 +1,21 @@ +// This test passes by printing +// 1 +// 2 +// and exiting with code 0. +import { sql } from "bun"; +process.exitCode = 1; + +async function first() { + const result = await sql`select 1 as x`; + console.log(result[0].x); +} + +async function yo() { + const result2 = await sql`select 2 as x`; + console.log(result2[0].x); + process.exitCode = 0; +} +first(); +Bun.gc(true); +yo(); +Bun.gc(true); diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index cef6943439..8c0089c760 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -1,6 +1,8 @@ import { postgres, sql } from "bun:sql"; import { expect, test } from "bun:test"; -import { isCI } from "harness"; +import { $ } from "bun"; +import { bunExe, isCI, withoutAggressiveGC } from "harness"; +import path from "path"; if (!isCI) { require("./bootstrap.js"); @@ -100,12 +102,13 @@ if (!isCI) { expect((await sql`select ${null} as x`)[0].x).toBeNull(); }); - test("Unsigned Integer", async () => { + test.todo("Unsigned Integer", async () => { expect((await sql`select ${0x7fffffff + 2} as x`)[0].x).toBe(0x7fffffff + 2); }); test("Signed Integer", async () => { expect((await sql`select ${-1} as x`)[0].x).toBe(-1); + expect((await sql`select ${1} as x`)[0].x).toBe(1); }); test("Double", async () => { @@ -120,12 +123,18 @@ if (!isCI) { test("Boolean true", async () => expect((await sql`select ${true} as x`)[0].x).toBe(true)); - test("Date", async () => { + test("Date (timestamp)", async () => { const now = new Date(); const then = (await sql`select ${now}::timestamp as x`)[0].x; expect(then).toEqual(now); }); + test("Date (timestamptz)", async () => { + const now = new Date(); + const then = (await sql`select ${now}::timestamptz as x`)[0].x; + expect(then).toEqual(now); + }); + // t("Json", async () => { // const x = (await sql`select ${sql.json({ a: "hello", b: 42 })} as x`)[0].x; // return ["hello,42", [x.a, x.b].join()]; @@ -142,6 +151,23 @@ if (!isCI) { expect([x.a, x.b].join(",")).toBe("hello,42"); }); + test("bulk insert nested sql()", async () => { + await sql`create table users (name text, age int)`; + const users = [ + { name: "Alice", age: 25 }, + { name: "Bob", age: 30 }, + ]; + try { + const result = await sql`insert into users ${sql(users)} RETURNING *`; + expect(result).toEqual([ + { name: "Alice", age: 25 }, + { name: "Bob", age: 30 }, + ]); + } finally { + await sql`drop table users`; + } + }); + // t("Empty array", async () => [true, Array.isArray((await sql`select ${sql.array([], 1009)} as x`)[0].x)]); test("string arg with ::int -> Array", async () => @@ -991,16 +1017,46 @@ if (!isCI) { // }`.catch(e => e.code)), await sql`drop table test`] // }) - test("let postgres do implicit cast of unknown types", async () => { + test("timestamp with time zone is consistent", async () => { await sql`create table test (x timestamp with time zone)`; try { - const [{ x }] = await sql`insert into test values (${new Date().toISOString()}) returning *`; + const date = new Date(); + const [{ x }] = await sql`insert into test values (${date}) returning *`; expect(x instanceof Date).toBe(true); + expect(x.toISOString()).toBe(date.toISOString()); } finally { await sql`drop table test`; } }); + test("timestamp is consistent", async () => { + await sql`create table test2 (x timestamp)`; + try { + const date = new Date(); + const [{ x }] = await sql`insert into test2 values (${date}) returning *`; + expect(x instanceof Date).toBe(true); + expect(x.toISOString()).toBe(date.toISOString()); + } finally { + await sql`drop table test2`; + } + }); + + test( + "let postgres do implicit cast of unknown types", + async () => { + await sql`create table test3 (x timestamp with time zone)`; + try { + const date = new Date("2024-01-01T00:00:00Z"); + const [{ x }] = await sql`insert into test3 values (${date.toISOString()}) returning *`; + expect(x instanceof Date).toBe(true); + expect(x.toISOString()).toBe(date.toISOString()); + } finally { + await sql`drop table test3`; + } + }, + { timeout: 1000000 }, + ); + // t('only allows one statement', async() => // ['42601', await sql`select 1; select 2`.catch(e => e.code)] // ) @@ -1580,9 +1636,17 @@ if (!isCI) { // return [1, (await sql`select 1 as x`)[0].x] // }) - // t('Big result', async() => { - // return [100000, (await sql`select * from generate_series(1, 100000)`).count] - // }) + test("Big result", async () => { + const result = await sql`select * from generate_series(1, 100000)`; + expect(result.count).toBe(100000); + let i = 1; + + for (const row of result) { + if (row.generate_series !== i++) { + throw new Error(`Row out of order at index ${i - 1}`); + } + } + }); // t('Debug', async() => { // let result @@ -1601,15 +1665,14 @@ if (!isCI) { // typeof (await sql`select 9223372036854777 as x`)[0].x // ]) - // t('int is returned as Number', async() => [ - // 'number', - // typeof (await sql`select 123 as x`)[0].x - // ]) + test("int is returned as Number", async () => { + expect((await sql`select 123 as x`)[0].x).toBe(123); + }); - // t('numeric is returned as string', async() => [ - // 'string', - // typeof (await sql`select 1.2 as x`)[0].x - // ]) + test("numeric is returned as string", async () => { + const result = (await sql`select 1.2 as x`)[0].x; + expect(result).toBe("1.2"); + }); // t('Async stack trace', async() => { // const sql = postgres({ ...options, debug: false }) @@ -1733,9 +1796,9 @@ if (!isCI) { // [true, (await sql`bad keyword`.catch(e => e)) instanceof sql.PostgresError] // ) - // t('Result has columns spec', async() => - // ['x', (await sql`select 1 as x`).columns[0].name] - // ) + test.todo("Result has columns spec", async () => { + expect((await sql`select 1 as x`).columns[0].name).toBe("x"); + }); // t('forEach has result as second argument', async() => { // let x @@ -1921,9 +1984,9 @@ if (!isCI) { // ] // }) - // t('Array returns rows as arrays of columns', async() => { - // return [(await sql`select 1`.values())[0][0], 1] - // }) + test("Array returns rows as arrays of columns", async () => { + return [(await sql`select 1`.values())[0][0], 1]; + }); // t('Copy read', async() => { // const result = [] @@ -2586,4 +2649,11 @@ if (!isCI) { // xs.map(x => x.x).join('') // ] // }) + + test("keeps process alive when it should", async () => { + const file = path.posix.join(__dirname, "sql-fixture-ref.ts"); + const result = await $`DATABASE_URL=${process.env.DATABASE_URL} ${bunExe()} ${file}`; + expect(result.exitCode).toBe(0); + expect(result.stdout.toString().split("\n")).toEqual(["1", "2", ""]); + }); }