From ecbf103bf5356963bf3ed48f51e3cc59a03fdbe4 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Thu, 21 Aug 2025 15:28:15 -0700 Subject: [PATCH] feat(MYSQL) Bun.SQL mysql support (#21968) ### What does this PR do? Add MySQL support, Refactor will be in a followup PR ### How did you verify your code works? A lot of tests --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: cirospaciari <6379399+cirospaciari@users.noreply.github.com> --- cmake/sources/JavaScriptSources.txt | 2 +- cmake/sources/ZigGeneratedClassesSources.txt | 2 +- cmake/sources/ZigSources.txt | 54 +- packages/bun-types/sql.d.ts | 13 +- src/bun.js/api.zig | 1 + src/bun.js/api/Timer/EventLoopTimer.zig | 10 + src/bun.js/api/postgres.classes.ts | 85 - src/bun.js/api/sql.classes.ts | 94 + src/bun.js/bindings/ErrorCode.ts | 4 + src/bun.js/bindings/JSGlobalObject.zig | 38 + src/bun.js/bindings/JSValue.zig | 7 + src/bun.js/bindings/SQLClient.cpp | 6 +- src/bun.js/bindings/bindings.cpp | 45 + .../bindings/generated_classes_list.zig | 2 + src/bun.js/rare_data.zig | 1 + src/fmt.zig | 1 - src/js/bun/sql.ts | 91 +- src/js/internal/sql/errors.ts | 22 +- src/js/internal/sql/mysql.ts | 1181 ++++++++++ src/js/internal/sql/postgres.ts | 61 +- src/js/internal/sql/query.ts | 13 +- src/js/internal/sql/shared.ts | 127 +- src/js/internal/sql/sqlite.ts | 37 +- src/js/internal/sql/utils.ts | 26 - src/js/private.d.ts | 4 +- src/sql/mysql.zig | 28 + src/sql/mysql/AuthMethod.zig | 37 + src/sql/mysql/Capabilities.zig | 205 ++ src/sql/mysql/ConnectionState.zig | 9 + src/sql/mysql/MySQLConnection.zig | 1949 +++++++++++++++++ src/sql/mysql/MySQLContext.zig | 22 + src/sql/mysql/MySQLQuery.zig | 545 +++++ src/sql/mysql/MySQLRequest.zig | 31 + src/sql/mysql/MySQLStatement.zig | 178 ++ src/sql/mysql/MySQLTypes.zig | 877 ++++++++ src/sql/mysql/SSLMode.zig | 7 + src/sql/mysql/StatusFlags.zig | 66 + src/sql/mysql/TLSStatus.zig | 11 + src/sql/mysql/protocol/AnyMySQLError.zig | 90 + 
src/sql/mysql/protocol/Auth.zig | 208 ++ src/sql/mysql/protocol/AuthSwitchRequest.zig | 42 + src/sql/mysql/protocol/AuthSwitchResponse.zig | 18 + src/sql/mysql/protocol/CharacterSet.zig | 236 ++ src/sql/mysql/protocol/ColumnDefinition41.zig | 97 + src/sql/mysql/protocol/CommandType.zig | 34 + src/sql/mysql/protocol/DecodeBinaryValue.zig | 153 ++ src/sql/mysql/protocol/EOFPacket.zig | 21 + src/sql/mysql/protocol/EncodeInt.zig | 73 + src/sql/mysql/protocol/ErrorPacket.zig | 82 + .../mysql/protocol/HandshakeResponse41.zig | 108 + src/sql/mysql/protocol/HandshakeV10.zig | 82 + src/sql/mysql/protocol/LocalInfileRequest.zig | 22 + src/sql/mysql/protocol/NewReader.zig | 136 ++ src/sql/mysql/protocol/NewWriter.zig | 132 ++ src/sql/mysql/protocol/OKPacket.zig | 49 + src/sql/mysql/protocol/PacketHeader.zig | 25 + src/sql/mysql/protocol/PacketType.zig | 14 + src/sql/mysql/protocol/PreparedStatement.zig | 115 + src/sql/mysql/protocol/Query.zig | 70 + src/sql/mysql/protocol/ResultSet.zig | 247 +++ src/sql/mysql/protocol/ResultSetHeader.zig | 12 + src/sql/mysql/protocol/Signature.zig | 86 + src/sql/mysql/protocol/StackReader.zig | 78 + .../mysql/protocol/StmtPrepareOKPacket.zig | 26 + src/sql/postgres/AnyPostgresError.zig | 54 +- src/sql/postgres/DataCell.zig | 1898 ++++++++-------- src/sql/postgres/PostgresProtocol.zig | 2 +- src/sql/postgres/PostgresRequest.zig | 2 +- src/sql/postgres/PostgresSQLConnection.zig | 35 +- src/sql/postgres/PostgresSQLQuery.zig | 12 +- .../postgres/PostgresSQLQueryResultMode.zig | 5 - src/sql/postgres/PostgresSQLStatement.zig | 4 +- src/sql/postgres/Signature.zig | 2 +- src/sql/postgres/SocketMonitor.zig | 5 + src/sql/postgres/protocol/Authentication.zig | 2 +- src/sql/postgres/protocol/CommandComplete.zig | 2 +- src/sql/postgres/protocol/CopyData.zig | 2 +- src/sql/postgres/protocol/CopyFail.zig | 2 +- src/sql/postgres/protocol/DataRow.zig | 4 +- .../postgres/protocol/FieldDescription.zig | 2 +- src/sql/postgres/protocol/NewReader.zig | 2 +- 
src/sql/postgres/protocol/ParameterStatus.zig | 2 +- src/sql/postgres/protocol/PasswordMessage.zig | 2 +- .../postgres/protocol/SASLInitialResponse.zig | 2 +- src/sql/postgres/protocol/SASLResponse.zig | 2 +- src/sql/postgres/protocol/StackReader.zig | 2 +- src/sql/postgres/protocol/StartupMessage.zig | 2 +- src/sql/postgres/types/PostgresString.zig | 2 +- src/sql/postgres/types/bytea.zig | 2 +- src/sql/postgres/types/date.zig | 2 +- src/sql/postgres/types/json.zig | 2 +- .../CachedStructure.zig} | 0 .../protocol => shared}/ColumnIdentifier.zig | 2 +- .../{postgres => shared}/ConnectionFlags.zig | 0 src/sql/{postgres => shared}/Data.zig | 35 +- .../{postgres => shared}/ObjectIterator.zig | 0 .../QueryBindingIterator.zig | 0 src/sql/shared/SQLDataCell.zig | 161 ++ src/sql/shared/SQLQueryResultMode.zig | 5 + test/integration/bun-types/fixture/sql.ts | 2 +- test/internal/ban-limits.json | 6 +- test/js/sql/sql-mysql.helpers.test.ts | 124 ++ test/js/sql/sql-mysql.test.ts | 805 +++++++ test/js/sql/sql-mysql.transactions.test.ts | 183 ++ test/js/sql/sql.test.ts | 8 +- test/js/sql/sqlite-sql.test.ts | 24 +- test/js/sql/sqlite-url-parsing.test.ts | 13 +- 107 files changed, 10184 insertions(+), 1387 deletions(-) delete mode 100644 src/bun.js/api/postgres.classes.ts create mode 100644 src/bun.js/api/sql.classes.ts create mode 100644 src/js/internal/sql/mysql.ts delete mode 100644 src/js/internal/sql/utils.ts create mode 100644 src/sql/mysql.zig create mode 100644 src/sql/mysql/AuthMethod.zig create mode 100644 src/sql/mysql/Capabilities.zig create mode 100644 src/sql/mysql/ConnectionState.zig create mode 100644 src/sql/mysql/MySQLConnection.zig create mode 100644 src/sql/mysql/MySQLContext.zig create mode 100644 src/sql/mysql/MySQLQuery.zig create mode 100644 src/sql/mysql/MySQLRequest.zig create mode 100644 src/sql/mysql/MySQLStatement.zig create mode 100644 src/sql/mysql/MySQLTypes.zig create mode 100644 src/sql/mysql/SSLMode.zig create mode 100644 
src/sql/mysql/StatusFlags.zig create mode 100644 src/sql/mysql/TLSStatus.zig create mode 100644 src/sql/mysql/protocol/AnyMySQLError.zig create mode 100644 src/sql/mysql/protocol/Auth.zig create mode 100644 src/sql/mysql/protocol/AuthSwitchRequest.zig create mode 100644 src/sql/mysql/protocol/AuthSwitchResponse.zig create mode 100644 src/sql/mysql/protocol/CharacterSet.zig create mode 100644 src/sql/mysql/protocol/ColumnDefinition41.zig create mode 100644 src/sql/mysql/protocol/CommandType.zig create mode 100644 src/sql/mysql/protocol/DecodeBinaryValue.zig create mode 100644 src/sql/mysql/protocol/EOFPacket.zig create mode 100644 src/sql/mysql/protocol/EncodeInt.zig create mode 100644 src/sql/mysql/protocol/ErrorPacket.zig create mode 100644 src/sql/mysql/protocol/HandshakeResponse41.zig create mode 100644 src/sql/mysql/protocol/HandshakeV10.zig create mode 100644 src/sql/mysql/protocol/LocalInfileRequest.zig create mode 100644 src/sql/mysql/protocol/NewReader.zig create mode 100644 src/sql/mysql/protocol/NewWriter.zig create mode 100644 src/sql/mysql/protocol/OKPacket.zig create mode 100644 src/sql/mysql/protocol/PacketHeader.zig create mode 100644 src/sql/mysql/protocol/PacketType.zig create mode 100644 src/sql/mysql/protocol/PreparedStatement.zig create mode 100644 src/sql/mysql/protocol/Query.zig create mode 100644 src/sql/mysql/protocol/ResultSet.zig create mode 100644 src/sql/mysql/protocol/ResultSetHeader.zig create mode 100644 src/sql/mysql/protocol/Signature.zig create mode 100644 src/sql/mysql/protocol/StackReader.zig create mode 100644 src/sql/mysql/protocol/StmtPrepareOKPacket.zig delete mode 100644 src/sql/postgres/PostgresSQLQueryResultMode.zig rename src/sql/{postgres/PostgresCachedStructure.zig => shared/CachedStructure.zig} (100%) rename src/sql/{postgres/protocol => shared}/ColumnIdentifier.zig (95%) rename src/sql/{postgres => shared}/ConnectionFlags.zig (100%) rename src/sql/{postgres => shared}/Data.zig (52%) rename src/sql/{postgres => 
shared}/ObjectIterator.zig (100%) rename src/sql/{postgres => shared}/QueryBindingIterator.zig (100%) create mode 100644 src/sql/shared/SQLDataCell.zig create mode 100644 src/sql/shared/SQLQueryResultMode.zig create mode 100644 test/js/sql/sql-mysql.helpers.test.ts create mode 100644 test/js/sql/sql-mysql.test.ts create mode 100644 test/js/sql/sql-mysql.transactions.test.ts diff --git a/cmake/sources/JavaScriptSources.txt b/cmake/sources/JavaScriptSources.txt index 1ae3a19d0e..4202470fab 100644 --- a/cmake/sources/JavaScriptSources.txt +++ b/cmake/sources/JavaScriptSources.txt @@ -66,11 +66,11 @@ src/js/internal/primordials.js src/js/internal/promisify.ts src/js/internal/shared.ts src/js/internal/sql/errors.ts +src/js/internal/sql/mysql.ts src/js/internal/sql/postgres.ts src/js/internal/sql/query.ts src/js/internal/sql/shared.ts src/js/internal/sql/sqlite.ts -src/js/internal/sql/utils.ts src/js/internal/stream.promises.ts src/js/internal/stream.ts src/js/internal/streams/add-abort-signal.ts diff --git a/cmake/sources/ZigGeneratedClassesSources.txt b/cmake/sources/ZigGeneratedClassesSources.txt index 116f1cc26d..3bb2bdf968 100644 --- a/cmake/sources/ZigGeneratedClassesSources.txt +++ b/cmake/sources/ZigGeneratedClassesSources.txt @@ -6,7 +6,6 @@ src/bun.js/api/Glob.classes.ts src/bun.js/api/h2.classes.ts src/bun.js/api/html_rewriter.classes.ts src/bun.js/api/JSBundler.classes.ts -src/bun.js/api/postgres.classes.ts src/bun.js/api/ResumableSink.classes.ts src/bun.js/api/S3Client.classes.ts src/bun.js/api/S3Stat.classes.ts @@ -15,6 +14,7 @@ src/bun.js/api/Shell.classes.ts src/bun.js/api/ShellArgs.classes.ts src/bun.js/api/sockets.classes.ts src/bun.js/api/sourcemap.classes.ts +src/bun.js/api/sql.classes.ts src/bun.js/api/streams.classes.ts src/bun.js/api/valkey.classes.ts src/bun.js/api/zlib.classes.ts diff --git a/cmake/sources/ZigSources.txt b/cmake/sources/ZigSources.txt index f112c64494..e106f04854 100644 --- a/cmake/sources/ZigSources.txt +++ 
b/cmake/sources/ZigSources.txt @@ -884,30 +884,63 @@ src/sourcemap/JSSourceMap.zig src/sourcemap/LineOffsetTable.zig src/sourcemap/sourcemap.zig src/sourcemap/VLQ.zig +src/sql/mysql.zig +src/sql/mysql/AuthMethod.zig +src/sql/mysql/Capabilities.zig +src/sql/mysql/ConnectionState.zig +src/sql/mysql/MySQLConnection.zig +src/sql/mysql/MySQLContext.zig +src/sql/mysql/MySQLQuery.zig +src/sql/mysql/MySQLRequest.zig +src/sql/mysql/MySQLStatement.zig +src/sql/mysql/MySQLTypes.zig +src/sql/mysql/protocol/AnyMySQLError.zig +src/sql/mysql/protocol/Auth.zig +src/sql/mysql/protocol/AuthSwitchRequest.zig +src/sql/mysql/protocol/AuthSwitchResponse.zig +src/sql/mysql/protocol/CharacterSet.zig +src/sql/mysql/protocol/ColumnDefinition41.zig +src/sql/mysql/protocol/CommandType.zig +src/sql/mysql/protocol/DecodeBinaryValue.zig +src/sql/mysql/protocol/EncodeInt.zig +src/sql/mysql/protocol/EOFPacket.zig +src/sql/mysql/protocol/ErrorPacket.zig +src/sql/mysql/protocol/HandshakeResponse41.zig +src/sql/mysql/protocol/HandshakeV10.zig +src/sql/mysql/protocol/LocalInfileRequest.zig +src/sql/mysql/protocol/NewReader.zig +src/sql/mysql/protocol/NewWriter.zig +src/sql/mysql/protocol/OKPacket.zig +src/sql/mysql/protocol/PacketHeader.zig +src/sql/mysql/protocol/PacketType.zig +src/sql/mysql/protocol/PreparedStatement.zig +src/sql/mysql/protocol/Query.zig +src/sql/mysql/protocol/ResultSet.zig +src/sql/mysql/protocol/ResultSetHeader.zig +src/sql/mysql/protocol/Signature.zig +src/sql/mysql/protocol/StackReader.zig +src/sql/mysql/protocol/StmtPrepareOKPacket.zig +src/sql/mysql/SSLMode.zig +src/sql/mysql/StatusFlags.zig +src/sql/mysql/TLSStatus.zig src/sql/postgres.zig src/sql/postgres/AnyPostgresError.zig src/sql/postgres/AuthenticationState.zig src/sql/postgres/CommandTag.zig -src/sql/postgres/ConnectionFlags.zig -src/sql/postgres/Data.zig src/sql/postgres/DataCell.zig src/sql/postgres/DebugSocketMonitorReader.zig src/sql/postgres/DebugSocketMonitorWriter.zig -src/sql/postgres/ObjectIterator.zig 
-src/sql/postgres/PostgresCachedStructure.zig src/sql/postgres/PostgresProtocol.zig src/sql/postgres/PostgresRequest.zig src/sql/postgres/PostgresSQLConnection.zig src/sql/postgres/PostgresSQLContext.zig src/sql/postgres/PostgresSQLQuery.zig -src/sql/postgres/PostgresSQLQueryResultMode.zig src/sql/postgres/PostgresSQLStatement.zig src/sql/postgres/PostgresTypes.zig src/sql/postgres/protocol/ArrayList.zig src/sql/postgres/protocol/Authentication.zig src/sql/postgres/protocol/BackendKeyData.zig src/sql/postgres/protocol/Close.zig -src/sql/postgres/protocol/ColumnIdentifier.zig src/sql/postgres/protocol/CommandComplete.zig src/sql/postgres/protocol/CopyData.zig src/sql/postgres/protocol/CopyFail.zig @@ -940,7 +973,6 @@ src/sql/postgres/protocol/StartupMessage.zig src/sql/postgres/protocol/TransactionStatusIndicator.zig src/sql/postgres/protocol/WriteWrap.zig src/sql/postgres/protocol/zHelpers.zig -src/sql/postgres/QueryBindingIterator.zig src/sql/postgres/SASL.zig src/sql/postgres/Signature.zig src/sql/postgres/SocketMonitor.zig @@ -955,6 +987,14 @@ src/sql/postgres/types/json.zig src/sql/postgres/types/numeric.zig src/sql/postgres/types/PostgresString.zig src/sql/postgres/types/Tag.zig +src/sql/shared/CachedStructure.zig +src/sql/shared/ColumnIdentifier.zig +src/sql/shared/ConnectionFlags.zig +src/sql/shared/Data.zig +src/sql/shared/ObjectIterator.zig +src/sql/shared/QueryBindingIterator.zig +src/sql/shared/SQLDataCell.zig +src/sql/shared/SQLQueryResultMode.zig src/StandaloneModuleGraph.zig src/StaticHashMap.zig src/string.zig diff --git a/packages/bun-types/sql.d.ts b/packages/bun-types/sql.d.ts index a85278b8c5..b074e9d2a4 100644 --- a/packages/bun-types/sql.d.ts +++ b/packages/bun-types/sql.d.ts @@ -82,6 +82,13 @@ declare module "bun" { ); } + class MySQLError extends SQLError { + public readonly code: string; + public readonly errno: number | undefined; + public readonly sqlState: string | undefined; + constructor(message: string, options: { code: string; errno: 
number | undefined; sqlState: string | undefined }); + } + class SQLiteError extends SQLError { public readonly code: string; public readonly errno: number; @@ -128,7 +135,7 @@ declare module "bun" { onclose?: ((err: Error | null) => void) | undefined; } - interface PostgresOptions { + interface PostgresOrMySQLOptions { /** * Connection URL (can be string or URL object) */ @@ -196,7 +203,7 @@ declare module "bun" { * Database adapter/driver to use * @default "postgres" */ - adapter?: "postgres"; + adapter?: "postgres" | "mysql" | "mariadb"; /** * Maximum time in seconds to wait for connection to become available @@ -332,7 +339,7 @@ declare module "bun" { * }; * ``` */ - type Options = SQLiteOptions | PostgresOptions; + type Options = SQLiteOptions | PostgresOrMySQLOptions; /** * Represents a SQL query that can be executed, with additional control diff --git a/src/bun.js/api.zig b/src/bun.js/api.zig index fd1f1b66e6..caf67ce1ec 100644 --- a/src/bun.js/api.zig +++ b/src/bun.js/api.zig @@ -43,6 +43,7 @@ pub const MatchedRoute = @import("./api/filesystem_router.zig").MatchedRoute; pub const NativeBrotli = @import("./node/zlib/NativeBrotli.zig"); pub const NativeZlib = @import("./node/zlib/NativeZlib.zig"); pub const Postgres = @import("../sql/postgres.zig"); +pub const MySQL = @import("../sql/mysql.zig"); pub const ResolveMessage = @import("./ResolveMessage.zig").ResolveMessage; pub const Shell = @import("../shell/shell.zig"); pub const UDPSocket = @import("./api/bun/udp_socket.zig").UDPSocket; diff --git a/src/bun.js/api/Timer/EventLoopTimer.zig b/src/bun.js/api/Timer/EventLoopTimer.zig index f50270ef5c..e4fb58ab22 100644 --- a/src/bun.js/api/Timer/EventLoopTimer.zig +++ b/src/bun.js/api/Timer/EventLoopTimer.zig @@ -59,6 +59,8 @@ pub const Tag = if (Environment.isWindows) enum { WTFTimer, PostgresSQLConnectionTimeout, PostgresSQLConnectionMaxLifetime, + MySQLConnectionTimeout, + MySQLConnectionMaxLifetime, ValkeyConnectionTimeout, ValkeyConnectionReconnect, 
SubprocessTimeout, @@ -80,6 +82,8 @@ pub const Tag = if (Environment.isWindows) enum { .WTFTimer => WTFTimer, .PostgresSQLConnectionTimeout => jsc.Postgres.PostgresSQLConnection, .PostgresSQLConnectionMaxLifetime => jsc.Postgres.PostgresSQLConnection, + .MySQLConnectionTimeout => jsc.MySQL.MySQLConnection, + .MySQLConnectionMaxLifetime => jsc.MySQL.MySQLConnection, .SubprocessTimeout => jsc.Subprocess, .ValkeyConnectionReconnect => jsc.API.Valkey, .ValkeyConnectionTimeout => jsc.API.Valkey, @@ -101,6 +105,8 @@ pub const Tag = if (Environment.isWindows) enum { DNSResolver, PostgresSQLConnectionTimeout, PostgresSQLConnectionMaxLifetime, + MySQLConnectionTimeout, + MySQLConnectionMaxLifetime, ValkeyConnectionTimeout, ValkeyConnectionReconnect, SubprocessTimeout, @@ -121,6 +127,8 @@ pub const Tag = if (Environment.isWindows) enum { .DNSResolver => DNSResolver, .PostgresSQLConnectionTimeout => jsc.Postgres.PostgresSQLConnection, .PostgresSQLConnectionMaxLifetime => jsc.Postgres.PostgresSQLConnection, + .MySQLConnectionTimeout => jsc.MySQL.MySQLConnection, + .MySQLConnectionMaxLifetime => jsc.MySQL.MySQLConnection, .ValkeyConnectionTimeout => jsc.API.Valkey, .ValkeyConnectionReconnect => jsc.API.Valkey, .SubprocessTimeout => jsc.Subprocess, @@ -189,6 +197,8 @@ pub fn fire(self: *Self, now: *const timespec, vm: *VirtualMachine) Arm { switch (self.tag) { .PostgresSQLConnectionTimeout => return @as(*api.Postgres.PostgresSQLConnection, @alignCast(@fieldParentPtr("timer", self))).onConnectionTimeout(), .PostgresSQLConnectionMaxLifetime => return @as(*api.Postgres.PostgresSQLConnection, @alignCast(@fieldParentPtr("max_lifetime_timer", self))).onMaxLifetimeTimeout(), + .MySQLConnectionTimeout => return @as(*api.MySQL.MySQLConnection, @alignCast(@fieldParentPtr("timer", self))).onConnectionTimeout(), + .MySQLConnectionMaxLifetime => return @as(*api.MySQL.MySQLConnection, @alignCast(@fieldParentPtr("max_lifetime_timer", self))).onMaxLifetimeTimeout(), .ValkeyConnectionTimeout => 
return @as(*api.Valkey, @alignCast(@fieldParentPtr("timer", self))).onConnectionTimeout(), .ValkeyConnectionReconnect => return @as(*api.Valkey, @alignCast(@fieldParentPtr("reconnect_timer", self))).onReconnectTimer(), .DevServerMemoryVisualizerTick => return bun.bake.DevServer.emitMemoryVisualizerMessageTimer(self, now), diff --git a/src/bun.js/api/postgres.classes.ts b/src/bun.js/api/postgres.classes.ts deleted file mode 100644 index a210706462..0000000000 --- a/src/bun.js/api/postgres.classes.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { define } from "../../codegen/class-definitions"; - -export default [ - define({ - name: "PostgresSQLConnection", - construct: true, - finalize: true, - configurable: false, - hasPendingActivity: true, - klass: { - // escapeString: { - // fn: "escapeString", - // }, - // escapeIdentifier: { - // fn: "escapeIdentifier", - // }, - }, - JSType: "0b11101110", - proto: { - close: { - fn: "doClose", - }, - connected: { - getter: "getConnected", - }, - ref: { - fn: "doRef", - }, - unref: { - fn: "doUnref", - }, - flush: { - fn: "doFlush", - }, - queries: { - getter: "getQueries", - this: true, - }, - onconnect: { - getter: "getOnConnect", - setter: "setOnConnect", - this: true, - }, - onclose: { - getter: "getOnClose", - setter: "setOnClose", - this: true, - }, - }, - values: ["onconnect", "onclose", "queries"], - }), - define({ - name: "PostgresSQLQuery", - construct: true, - finalize: true, - configurable: false, - - JSType: "0b11101110", - klass: {}, - proto: { - run: { - fn: "doRun", - length: 2, - }, - cancel: { - fn: "doCancel", - length: 0, - }, - done: { - fn: "doDone", - length: 0, - }, - setMode: { - fn: "setMode", - length: 1, - }, - setPendingValue: { - fn: "setPendingValue", - length: 1, - }, - }, - values: ["pendingValue", "target", "columns", "binding"], - estimatedSize: true, - }), -]; diff --git a/src/bun.js/api/sql.classes.ts b/src/bun.js/api/sql.classes.ts new file mode 100644 index 0000000000..db29a3dc1f --- /dev/null 
+++ b/src/bun.js/api/sql.classes.ts @@ -0,0 +1,94 @@ +import { define } from "../../codegen/class-definitions"; + +const types = ["PostgresSQL", "MySQL"]; +const classes = []; +for (const type of types) { + classes.push( + define({ + name: `${type}Connection`, + construct: true, + finalize: true, + configurable: false, + hasPendingActivity: true, + klass: { + // escapeString: { + // fn: "escapeString", + // }, + // escapeIdentifier: { + // fn: "escapeIdentifier", + // }, + }, + JSType: "0b11101110", + proto: { + close: { + fn: "doClose", + }, + connected: { + getter: "getConnected", + }, + ref: { + fn: "doRef", + }, + unref: { + fn: "doUnref", + }, + flush: { + fn: "doFlush", + }, + queries: { + getter: "getQueries", + this: true, + }, + onconnect: { + getter: "getOnConnect", + setter: "setOnConnect", + this: true, + }, + onclose: { + getter: "getOnClose", + setter: "setOnClose", + this: true, + }, + }, + values: ["onconnect", "onclose", "queries"], + }), + ); + + classes.push( + define({ + name: `${type}Query`, + construct: true, + finalize: true, + configurable: false, + + JSType: "0b11101110", + klass: {}, + proto: { + run: { + fn: "doRun", + length: 2, + }, + cancel: { + fn: "doCancel", + length: 0, + }, + done: { + fn: "doDone", + length: 0, + }, + setMode: { + fn: "setMode", + length: 1, + }, + setPendingValue: { + fn: "setPendingValue", + length: 1, + }, + }, + values: ["pendingValue", "target", "columns", "binding"], + estimatedSize: true, + }), + ); +} + +export default classes; diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index fcdf9ef6c2..d8a12b99e3 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -204,6 +204,10 @@ const errors: ErrorCodeMapping = [ ["ERR_POSTGRES_UNSUPPORTED_BYTEA_FORMAT", TypeError, "PostgresError"], ["ERR_POSTGRES_UNSUPPORTED_INTEGER_SIZE", TypeError, "PostgresError"], ["ERR_POSTGRES_UNSUPPORTED_NUMERIC_FORMAT", TypeError, "PostgresError"], + 
["ERR_MYSQL_CONNECTION_CLOSED", Error, "MySQLError"], + ["ERR_MYSQL_CONNECTION_TIMEOUT", Error, "MySQLError"], + ["ERR_MYSQL_IDLE_TIMEOUT", Error, "MySQLError"], + ["ERR_MYSQL_LIFETIME_TIMEOUT", Error, "MySQLError"], ["ERR_UNHANDLED_REJECTION", Error, "UnhandledPromiseRejection"], ["ERR_REQUIRE_ASYNC_MODULE", Error], ["ERR_S3_INVALID_ENDPOINT", Error], diff --git a/src/bun.js/bindings/JSGlobalObject.zig b/src/bun.js/bindings/JSGlobalObject.zig index 7cbada9a94..64b40096f0 100644 --- a/src/bun.js/bindings/JSGlobalObject.zig +++ b/src/bun.js/bindings/JSGlobalObject.zig @@ -21,6 +21,10 @@ pub const JSGlobalObject = opaque { JSGlobalObject__throwOutOfMemoryError(this); return .zero; } + pub fn gregorianDateTimeToMS(this: *jsc.JSGlobalObject, year: i32, month: i32, day: i32, hour: i32, minute: i32, second: i32, millisecond: i32) bun.JSError!f64 { + jsc.markBinding(@src()); + return bun.cpp.Bun__gregorianDateTimeToMS(this, year, month, day, hour, minute, second, millisecond); + } pub fn throwTODO(this: *JSGlobalObject, msg: []const u8) bun.JSError { const err = this.createErrorInstance("{s}", .{msg}); @@ -667,6 +671,40 @@ pub const JSGlobalObject = opaque { always_allow_zero: bool = false, }; + pub fn validateBigIntRange(this: *JSGlobalObject, value: JSValue, comptime T: type, default: T, comptime range: IntegerRange) bun.JSError!T { + if (value.isUndefined() or value == .zero) { + return 0; + } + + const TypeInfo = @typeInfo(T); + if (TypeInfo != .int) { + @compileError("T must be an integer type"); + } + const signed = TypeInfo.int.signedness == .signed; + + const min_t = comptime @max(range.min, std.math.minInt(T)); + const max_t = comptime @min(range.max, std.math.maxInt(T)); + if (value.isBigInt()) { + if (signed) { + if (value.isBigIntInInt64Range(min_t, max_t)) { + return value.toInt64(); + } + } else { + if (value.isBigIntInUInt64Range(min_t, max_t)) { + return value.toUInt64NoTruncate(); + } + } + return this.ERR(.OUT_OF_RANGE, "The value is out of range. 
It must be >= {d} and <= {d}.", .{ min_t, max_t }).throw(); + } + + return try this.validateIntegerRange(value, T, default, .{ + .min = comptime @max(min_t, jsc.MIN_SAFE_INTEGER), + .max = comptime @min(max_t, jsc.MAX_SAFE_INTEGER), + .field_name = range.field_name, + .always_allow_zero = range.always_allow_zero, + }); + } + pub fn validateIntegerRange(this: *JSGlobalObject, value: JSValue, comptime T: type, default: T, comptime range: IntegerRange) bun.JSError!T { if (value.isUndefined() or value == .zero) { return default; diff --git a/src/bun.js/bindings/JSValue.zig b/src/bun.js/bindings/JSValue.zig index 9d4dec28d0..2136786844 100644 --- a/src/bun.js/bindings/JSValue.zig +++ b/src/bun.js/bindings/JSValue.zig @@ -33,6 +33,13 @@ pub const JSValue = enum(i64) { return @as(JSValue, @enumFromInt(@as(i64, @bitCast(@intFromPtr(ptr))))); } + pub fn isBigIntInUInt64Range(this: JSValue, min: u64, max: u64) bool { + return bun.cpp.JSC__isBigIntInUInt64Range(this, min, max); + } + + pub fn isBigIntInInt64Range(this: JSValue, min: i64, max: i64) bool { + return bun.cpp.JSC__isBigIntInInt64Range(this, min, max); + } pub fn coerceToInt32(this: JSValue, globalThis: *jsc.JSGlobalObject) bun.JSError!i32 { return bun.cpp.JSC__JSValue__coerceToInt32(this, globalThis); } diff --git a/src/bun.js/bindings/SQLClient.cpp b/src/bun.js/bindings/SQLClient.cpp index af1eab7776..012bd68a77 100644 --- a/src/bun.js/bindings/SQLClient.cpp +++ b/src/bun.js/bindings/SQLClient.cpp @@ -64,6 +64,7 @@ typedef union DataCellValue { double number; int32_t integer; int64_t bigint; + uint64_t unsigned_bigint; uint8_t boolean; double date; double date_with_time_zone; @@ -90,6 +91,7 @@ enum class DataCellTag : uint8_t { TypedArray = 11, Raw = 12, UnsignedInteger = 13, + UnsignedBigint = 14, }; enum class BunResultMode : uint8_t { @@ -161,6 +163,9 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel case DataCellTag::Bigint: return JSC::JSBigInt::createFrom(globalObject, 
cell.value.bigint); break; + case DataCellTag::UnsignedBigint: + return JSC::JSBigInt::createFrom(globalObject, cell.value.unsigned_bigint); + break; case DataCellTag::Boolean: return jsBoolean(cell.value.boolean); break; @@ -317,7 +322,6 @@ static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, uint32_t co ASSERT(!cell.isIndexedColumn()); ASSERT(cell.isNamedColumn()); if (names.has_value()) { - auto name = names.value()[i]; object->putDirect(vm, Identifier::fromString(vm, name.name.toWTFString()), value); diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp index 9205857e71..7cd1a672a5 100644 --- a/src/bun.js/bindings/bindings.cpp +++ b/src/bun.js/bindings/bindings.cpp @@ -73,6 +73,8 @@ #include "wtf/text/StringImpl.h" #include "wtf/text/StringView.h" #include "wtf/text/WTFString.h" +#include "wtf/GregorianDateTime.h" + #include "JavaScriptCore/FunctionPrototype.h" #include "JSFetchHeaders.h" #include "FetchHeaders.h" @@ -5889,6 +5891,36 @@ extern "C" void JSC__JSValue__forEachPropertyNonIndexed(JSC::EncodedJSValue JSVa JSC__JSValue__forEachPropertyImpl(JSValue0, globalObject, arg2, iter); } +extern "C" [[ZIG_EXPORT(nothrow)]] bool JSC__isBigIntInUInt64Range(JSC::EncodedJSValue value, uint64_t max, uint64_t min) +{ + JSValue jsValue = JSValue::decode(value); + if (!jsValue.isHeapBigInt()) + return false; + + JSC::JSBigInt* bigInt = jsValue.asHeapBigInt(); + auto result = bigInt->compare(bigInt, min); + if (result == JSBigInt::ComparisonResult::GreaterThan || result == JSBigInt::ComparisonResult::Equal) { + return true; + } + result = bigInt->compare(bigInt, max); + return result == JSBigInt::ComparisonResult::LessThan || result == JSBigInt::ComparisonResult::Equal; +} + +extern "C" [[ZIG_EXPORT(nothrow)]] bool JSC__isBigIntInInt64Range(JSC::EncodedJSValue value, int64_t max, int64_t min) +{ + JSValue jsValue = JSValue::decode(value); + if (!jsValue.isHeapBigInt()) + return false; + + JSC::JSBigInt* bigInt = 
jsValue.asHeapBigInt(); + auto result = bigInt->compare(bigInt, min); + if (result == JSBigInt::ComparisonResult::GreaterThan || result == JSBigInt::ComparisonResult::Equal) { + return true; + } + result = bigInt->compare(bigInt, max); + return result == JSBigInt::ComparisonResult::LessThan || result == JSBigInt::ComparisonResult::Equal; +} + [[ZIG_EXPORT(check_slow)]] void JSC__JSValue__forEachPropertyOrdered(JSC::EncodedJSValue JSValue0, JSC::JSGlobalObject* globalObject, void* arg2, void (*iter)([[ZIG_NONNULL]] JSC::JSGlobalObject* arg0, void* ctx, [[ZIG_NONNULL]] ZigString* arg2, JSC::EncodedJSValue JSValue3, bool isSymbol, bool isPrivateSymbol)) { JSC::JSValue value = JSC::JSValue::decode(JSValue0); @@ -6208,6 +6240,19 @@ extern "C" [[ZIG_EXPORT(check_slow)]] double Bun__parseDate(JSC::JSGlobalObject* return vm.dateCache.parseDate(globalObject, vm, str->toWTFString()); } +extern "C" [[ZIG_EXPORT(check_slow)]] double Bun__gregorianDateTimeToMS(JSC::JSGlobalObject* globalObject, int year, int month, int day, int hour, int minute, int second, int millisecond) +{ + auto& vm = JSC::getVM(globalObject); + WTF::GregorianDateTime dateTime; + dateTime.setYear(year); + dateTime.setMonth(month - 1); + dateTime.setMonthDay(day); + dateTime.setHour(hour); + dateTime.setMinute(minute); + dateTime.setSecond(second); + return vm.dateCache.gregorianDateTimeToMS(dateTime, millisecond, WTF::TimeType::LocalTime); +} + extern "C" EncodedJSValue JSC__JSValue__dateInstanceFromNumber(JSC::JSGlobalObject* globalObject, double unixTimestamp) { auto& vm = JSC::getVM(globalObject); diff --git a/src/bun.js/bindings/generated_classes_list.zig b/src/bun.js/bindings/generated_classes_list.zig index d5fd4778bc..e47b2877dd 100644 --- a/src/bun.js/bindings/generated_classes_list.zig +++ b/src/bun.js/bindings/generated_classes_list.zig @@ -69,7 +69,9 @@ pub const Classes = struct { pub const BlobInternalReadableStreamSource = webcore.ByteBlobLoader.Source; pub const 
BytesInternalReadableStreamSource = webcore.ByteStream.Source; pub const PostgresSQLConnection = api.Postgres.PostgresSQLConnection; + pub const MySQLConnection = api.MySQL.MySQLConnection; pub const PostgresSQLQuery = api.Postgres.PostgresSQLQuery; + pub const MySQLQuery = api.MySQL.MySQLQuery; pub const TextEncoderStreamEncoder = webcore.TextEncoderStreamEncoder; pub const NativeZlib = api.NativeZlib; pub const NativeBrotli = api.NativeBrotli; diff --git a/src/bun.js/rare_data.zig b/src/bun.js/rare_data.zig index aa22c96454..261c77e7d2 100644 --- a/src/bun.js/rare_data.zig +++ b/src/bun.js/rare_data.zig @@ -7,6 +7,7 @@ stderr_store: ?*Blob.Store = null, stdin_store: ?*Blob.Store = null, stdout_store: ?*Blob.Store = null, +mysql_context: bun.api.MySQL.MySQLContext = .{}, postgresql_context: bun.api.Postgres.PostgresSQLContext = .{}, entropy_cache: ?*EntropyCache = null, diff --git a/src/fmt.zig b/src/fmt.zig index a699bfb000..a2938ab668 100644 --- a/src/fmt.zig +++ b/src/fmt.zig @@ -1836,7 +1836,6 @@ fn OutOfRangeFormatter(comptime T: type) type { } else if (T == bun.String) { return BunStringOutOfRangeFormatter; } - return IntOutOfRangeFormatter; } diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index ffc317bad1..ffd108424c 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -1,13 +1,15 @@ +import type { MySQLAdapter } from "internal/sql/mysql"; import type { PostgresAdapter } from "internal/sql/postgres"; import type { BaseQueryHandle, Query } from "internal/sql/query"; import type { SQLHelper } from "internal/sql/shared"; const { Query, SQLQueryFlags } = require("internal/sql/query"); const { PostgresAdapter } = require("internal/sql/postgres"); +const { MySQLAdapter } = require("internal/sql/mysql"); const { SQLiteAdapter } = require("internal/sql/sqlite"); const { SQLHelper, parseOptions } = require("internal/sql/shared"); -const { connectionClosedError } = require("internal/sql/utils"); -const { SQLError, PostgresError, SQLiteError } = 
require("internal/sql/errors"); + +const { SQLError, PostgresError, SQLiteError, MySQLError } = require("internal/sql/errors"); const defineProperties = Object.defineProperties; @@ -29,6 +31,8 @@ function adapterFromOptions(options: Bun.SQL.__internal.DefinedOptions) { switch (options.adapter) { case "postgres": return new PostgresAdapter(options); + case "mysql": + return new MySQLAdapter(options); case "sqlite": return new SQLiteAdapter(options); default: @@ -41,7 +45,6 @@ const SQL: typeof Bun.SQL = function SQL( definitelyOptionsButMaybeEmpty: Bun.SQL.Options = {}, ): Bun.SQL { const connectionInfo = parseOptions(stringOrUrlOrOptions, definitelyOptionsButMaybeEmpty); - const pool = adapterFromOptions(connectionInfo); function onQueryDisconnected(this: Query, err: Error) { @@ -54,11 +57,7 @@ const SQL: typeof Bun.SQL = function SQL( // query is cancelled when waiting for a connection from the pool if (query.cancelled) { - return query.reject( - new PostgresError("Query cancelled", { - code: "ERR_POSTGRES_QUERY_CANCELLED", - }), - ); + return query.reject(pool.queryCancelledError()); } } @@ -76,11 +75,7 @@ const SQL: typeof Bun.SQL = function SQL( // query is cancelled when waiting for a connection from the pool if (query.cancelled) { pool.release(connectionHandle); // release the connection back to the pool - return query.reject( - new PostgresError("Query cancelled", { - code: "ERR_POSTGRES_QUERY_CANCELLED", - }), - ); + return query.reject(pool.queryCancelledError()); } if (connectionHandle.bindQuery) { @@ -106,11 +101,7 @@ const SQL: typeof Bun.SQL = function SQL( // query is cancelled if (!handle || query.cancelled) { - return query.reject( - new PostgresError("Query cancelled", { - code: "ERR_POSTGRES_QUERY_CANCELLED", - }), - ); + return query.reject(pool.queryCancelledError()); } pool.connect(onQueryConnected.bind(query, handle)); @@ -163,11 +154,7 @@ const SQL: typeof Bun.SQL = function SQL( // query is cancelled if (query.cancelled) { 
transactionQueries.delete(query); - return query.reject( - new PostgresError("Query cancelled", { - code: "ERR_POSTGRES_QUERY_CANCELLED", - }), - ); + return query.reject(pool.queryCancelledError()); } query.finally(onTransactionQueryDisconnected.bind(transactionQueries, query)); @@ -275,7 +262,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.reject(connectionClosedError()); + return Promise.reject(pool.connectionClosedError()); } if ($isArray(strings)) { // detect if is tagged template @@ -303,7 +290,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.connect = () => { if (state.connectionState & ReservedConnectionState.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } return Promise.resolve(reserved_sql); }; @@ -334,7 +321,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.beginDistributed = (name: string, fn: TransactionCallback) => { // begin is allowed the difference is that we need to make sure to use the same connection and never release it if (state.connectionState & ReservedConnectionState.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } let callback = fn; @@ -358,7 +345,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } let callback = fn; let options: string | undefined = options_or_fn as unknown as string; @@ -381,7 +368,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.flush = () => { if (state.connectionState & ReservedConnectionState.closed) { - throw connectionClosedError(); + throw this.connectionClosedError(); } // Use 
pooled connection's flush if available, otherwise use adapter's flush if (pooledConnection.flush) { @@ -441,7 +428,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } // just release the connection back to the pool state.connectionState |= ReservedConnectionState.closed; @@ -564,7 +551,7 @@ const SQL: typeof Bun.SQL = function SQL( function run_internal_transaction_sql(string) { if (state.connectionState & ReservedConnectionState.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } return unsafeQueryFromTransaction(string, [], pooledConnection, state.queries); } @@ -576,7 +563,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } if ($isArray(strings)) { // detect if is tagged template @@ -605,7 +592,7 @@ const SQL: typeof Bun.SQL = function SQL( transaction_sql.connect = () => { if (state.connectionState & ReservedConnectionState.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } return Promise.resolve(transaction_sql); @@ -629,29 +616,23 @@ const SQL: typeof Bun.SQL = function SQL( // begin is not allowed on a transaction we need to use savepoint() instead transaction_sql.begin = function () { if (distributed) { - throw new PostgresError("cannot call begin inside a distributed transaction", { - code: "ERR_POSTGRES_INVALID_TRANSACTION_STATE", - }); + throw pool.invalidTransactionStateError("cannot call begin inside a distributed transaction"); } - throw new PostgresError("cannot call begin 
inside a transaction use savepoint() instead", { - code: "POSTGRES_INVALID_TRANSACTION_STATE", - }); + throw pool.invalidTransactionStateError("cannot call begin inside a transaction use savepoint() instead"); }; transaction_sql.beginDistributed = function () { if (distributed) { - throw new PostgresError("cannot call beginDistributed inside a distributed transaction", { - code: "ERR_POSTGRES_INVALID_TRANSACTION_STATE", - }); + throw pool.invalidTransactionStateError("cannot call beginDistributed inside a distributed transaction"); } - throw new PostgresError("cannot call beginDistributed inside a transaction use savepoint() instead", { - code: "POSTGRES_INVALID_TRANSACTION_STATE", - }); + throw pool.invalidTransactionStateError( + "cannot call beginDistributed inside a transaction use savepoint() instead", + ); }; transaction_sql.flush = function () { if (state.connectionState & ReservedConnectionState.closed) { - throw connectionClosedError(); + throw pool.connectionClosedError(); } // Use pooled connection's flush if available, otherwise use adapter's flush if (pooledConnection.flush) { @@ -740,9 +721,7 @@ const SQL: typeof Bun.SQL = function SQL( } if (distributed) { transaction_sql.savepoint = async (_fn: TransactionCallback, _name?: string): Promise => { - throw new PostgresError("cannot call savepoint inside a distributed transaction", { - code: "ERR_POSTGRES_INVALID_TRANSACTION_STATE", - }); + throw pool.invalidTransactionStateError("cannot call savepoint inside a distributed transaction"); }; } else { transaction_sql.savepoint = async (fn: TransactionCallback, name?: string): Promise => { @@ -752,7 +731,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - throw connectionClosedError(); + throw this.connectionClosedError(); } if ($isCallable(name)) { @@ -837,7 +816,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.reserve = () => { if 
(pool.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } // Check if adapter supports reserved connections @@ -852,7 +831,7 @@ const SQL: typeof Bun.SQL = function SQL( }; sql.rollbackDistributed = async function (name: string) { if (pool.closed) { - throw connectionClosedError(); + throw this.connectionClosedError(); } if (!pool.getRollbackDistributedSQL) { @@ -865,7 +844,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.commitDistributed = async function (name: string) { if (pool.closed) { - throw connectionClosedError(); + throw this.connectionClosedError(); } if (!pool.getCommitDistributedSQL) { @@ -878,7 +857,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.beginDistributed = (name: string, fn: TransactionCallback) => { if (pool.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } let callback = fn; @@ -897,7 +876,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.begin = (options_or_fn: string | TransactionCallback, fn?: TransactionCallback) => { if (pool.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } let callback = fn; let options: string | undefined = options_or_fn as unknown as string; @@ -917,7 +896,7 @@ const SQL: typeof Bun.SQL = function SQL( }; sql.connect = () => { if (pool.closed) { - return Promise.reject(connectionClosedError()); + return Promise.reject(this.connectionClosedError()); } if (pool.isConnected()) { @@ -1045,6 +1024,7 @@ defineProperties(defaultSQLObject, { SQL.SQLError = SQLError; SQL.PostgresError = PostgresError; SQL.SQLiteError = SQLiteError; +SQL.MySQLError = MySQLError; // // Helper functions for native code to create error instances // // These are internal functions used by Zig/C++ code @@ -1082,5 +1062,6 @@ export default { postgres: SQL, SQLError, PostgresError, + MySQLError, SQLiteError, }; diff --git 
a/src/js/internal/sql/errors.ts b/src/js/internal/sql/errors.ts index a2f5d5a98a..408090085b 100644 --- a/src/js/internal/sql/errors.ts +++ b/src/js/internal/sql/errors.ts @@ -92,4 +92,24 @@ class SQLiteError extends SQLError implements Bun.SQL.SQLiteError { } } -export default { PostgresError, SQLError, SQLiteError }; +export interface MySQLErrorOptions { + code: string; + errno: number | undefined; + sqlState: string | undefined; +} + +class MySQLError extends SQLError implements Bun.SQL.MySQLError { + public readonly code: string; + public readonly errno: number | undefined; + public readonly sqlState: string | undefined; + + constructor(message: string, options: MySQLErrorOptions) { + super(message); + + this.name = "MySQLError"; + this.code = options.code; + this.errno = options.errno; + this.sqlState = options.sqlState; + } +} +export default { PostgresError, SQLError, SQLiteError, MySQLError }; diff --git a/src/js/internal/sql/mysql.ts b/src/js/internal/sql/mysql.ts new file mode 100644 index 0000000000..4d121f84b9 --- /dev/null +++ b/src/js/internal/sql/mysql.ts @@ -0,0 +1,1181 @@ +import type { MySQLErrorOptions } from "internal/sql/errors"; +import type { Query } from "./query"; +import type { DatabaseAdapter, SQLHelper, SQLResultArray, SSLMode } from "./shared"; +const { SQLHelper, SSLMode, SQLResultArray } = require("internal/sql/shared"); +const { + Query, + SQLQueryFlags, + symbols: { _strings, _values, _flags, _results, _handle }, +} = require("internal/sql/query"); +const { MySQLError } = require("internal/sql/errors"); + +const { + createConnection: createMySQLConnection, + createQuery: createMySQLQuery, + init: initMySQL, +} = $zig("mysql.zig", "createBinding") as MySQLDotZig; + +function wrapError(error: Error | MySQLErrorOptions) { + if (Error.isError(error)) { + return error; + } + return new MySQLError(error.message, error); +} +initMySQL( + function onResolveMySQLQuery(query, result, commandTag, count, queries, is_last) { + /// simple queries 
+ if (query[_flags] & SQLQueryFlags.simple) { + $assert(result instanceof SQLResultArray, "Invalid result array"); + // prepare for next query + query[_handle].setPendingValue(new SQLResultArray()); + + result.count = count || 0; + const last_result = query[_results]; + + if (!last_result) { + query[_results] = result; + } else { + if (last_result instanceof SQLResultArray) { + // multiple results + query[_results] = [last_result, result]; + } else { + // 3 or more results + last_result.push(result); + } + } + if (is_last) { + if (queries) { + const queriesIndex = queries.indexOf(query); + if (queriesIndex !== -1) { + queries.splice(queriesIndex, 1); + } + } + try { + query.resolve(query[_results]); + } catch {} + } + return; + } + /// prepared statements + $assert(result instanceof SQLResultArray, "Invalid result array"); + + result.count = count || 0; + if (queries) { + const queriesIndex = queries.indexOf(query); + if (queriesIndex !== -1) { + queries.splice(queriesIndex, 1); + } + } + try { + query.resolve(result); + } catch {} + }, + + function onRejectMySQLQuery(query: Query, reject: Error | MySQLErrorOptions, queries: Query[]) { + reject = wrapError(reject); + if (queries) { + const queriesIndex = queries.indexOf(query); + if (queriesIndex !== -1) { + queries.splice(queriesIndex, 1); + } + } + + try { + query.reject(reject as Error); + } catch {} + }, +); + +export interface MySQLDotZig { + init: ( + onResolveQuery: ( + query: Query, + result: SQLResultArray, + commandTag: string, + count: number, + queries: any, + is_last: boolean, + ) => void, + onRejectQuery: (query: Query, err: Error, queries) => void, + ) => void; + createConnection: ( + hostname: string | undefined, + port: number, + username: string, + password: string, + database: string, + sslmode: SSLMode, + tls: Bun.TLSOptions | boolean | null, // boolean true => empty TLSOptions object `{}`, boolean false or null => nothing + query: string, + path: string, + onConnected: (err: Error | null,
connection: $ZigGeneratedClasses.MySQLConnection) => void, + onDisconnected: (err: Error | null, connection: $ZigGeneratedClasses.MySQLConnection) => void, + idleTimeout: number, + connectionTimeout: number, + maxLifetime: number, + useUnnamedPreparedStatements: boolean, + ) => $ZigGeneratedClasses.MySQLConnection; + createQuery: ( + sql: string, + values: unknown[], + pendingValue: SQLResultArray, + columns: string[] | undefined, + bigint: boolean, + simple: boolean, + ) => $ZigGeneratedClasses.MySQLSQLQuery; +} + +const enum SQLCommand { + insert = 0, + update = 1, + updateSet = 2, + where = 3, + whereIn = 4, + none = -1, +} +export type { SQLCommand }; + +function commandToString(command: SQLCommand): string { + switch (command) { + case SQLCommand.insert: + return "INSERT"; + case SQLCommand.updateSet: + case SQLCommand.update: + return "UPDATE"; + case SQLCommand.whereIn: + case SQLCommand.where: + return "WHERE"; + default: + return ""; + } +} + +function detectCommand(query: string): SQLCommand { + const text = query.toLowerCase().trim(); + const text_len = text.length; + + let token = ""; + let command = SQLCommand.none; + let quoted = false; + for (let i = 0; i < text_len; i++) { + const char = text[i]; + switch (char) { + case " ": // Space + case "\n": // Line feed + case "\t": // Tab character + case "\r": // Carriage return + case "\f": // Form feed + case "\v": { + switch (token) { + case "insert": { + if (command === SQLCommand.none) { + return SQLCommand.insert; + } + return command; + } + case "update": { + if (command === SQLCommand.none) { + command = SQLCommand.update; + token = ""; + continue; // try to find SET + } + return command; + } + case "where": { + command = SQLCommand.where; + token = ""; + continue; // try to find IN + } + case "set": { + if (command === SQLCommand.update) { + command = SQLCommand.updateSet; + token = ""; + continue; // try to find WHERE + } + return command; + } + case "in": { + if (command === SQLCommand.where) { + 
return SQLCommand.whereIn; + } + return command; + } + default: { + token = ""; + continue; + } + } + } + default: { + // skip quoted commands + if (char === '"') { + quoted = !quoted; + continue; + } + if (!quoted) { + token += char; + } + } + } + } + if (token) { + switch (command) { + case SQLCommand.none: { + switch (token) { + case "insert": + return SQLCommand.insert; + case "update": + return SQLCommand.update; + case "where": + return SQLCommand.where; + default: + return SQLCommand.none; + } + } + case SQLCommand.update: { + if (token === "set") { + return SQLCommand.updateSet; + } + return SQLCommand.update; + } + case SQLCommand.where: { + if (token === "in") { + return SQLCommand.whereIn; + } + return SQLCommand.where; + } + } + } + + return command; +} + +const enum PooledConnectionState { + pending = 0, + connected = 1, + closed = 2, +} + +const enum PooledConnectionFlags { + /// canBeConnected is used to indicate that at least one time we were able to connect to the database + canBeConnected = 1 << 0, + /// reserved is used to indicate that the connection is currently reserved + reserved = 1 << 1, + /// preReserved is used to indicate that the connection will be reserved in the future when queryCount drops to 0 + preReserved = 1 << 2, +} + +function onQueryFinish(this: PooledMySQLConnection, onClose: (err: Error) => void) { + this.queries.delete(onClose); + this.adapter.release(this); +} + +class PooledMySQLConnection { + private static async createConnection( + options: Bun.SQL.__internal.DefinedMySQLOptions, + onConnected: (err: Error | null, connection: $ZigGeneratedClasses.MySQLSQLConnection) => void, + onClose: (err: Error | null) => void, + ): Promise<$ZigGeneratedClasses.MySQLSQLConnection | null> { + const { + hostname, + port, + username, + tls, + query, + database, + sslMode, + idleTimeout = 0, + connectionTimeout = 30 * 1000, + maxLifetime = 0, + prepare = true, + + // @ts-expect-error path is currently removed from the types + path, + } = 
options; + + let password: Bun.MaybePromise | string | undefined | (() => Bun.MaybePromise) = options.password; + + try { + if (typeof password === "function") { + password = password(); + + if (password && $isPromise(password)) { + password = await password; + } + } + + return createMySQLConnection( + hostname, + Number(port), + username || "", + password || "", + database || "", + // > The default value for sslmode is prefer. As is shown in the table, this + // makes no sense from a security point of view, and it only promises + // performance overhead if possible. It is only provided as the default for + // backward compatibility, and is not recommended in secure deployments. + sslMode || SSLMode.disable, + tls || null, + query || "", + path || "", + onConnected, + onClose, + idleTimeout, + connectionTimeout, + maxLifetime, + !prepare, + ); + } catch (e) { + onClose(e as Error); + return null; + } + } + + adapter: MySQLAdapter; + connection: $ZigGeneratedClasses.MySQLSQLConnection | null = null; + state: PooledConnectionState = PooledConnectionState.pending; + storedError: Error | null = null; + queries: Set<(err: Error) => void> = new Set(); + onFinish: ((err: Error | null) => void) | null = null; + connectionInfo: Bun.SQL.__internal.DefinedMySQLOptions; + flags: number = 0; + /// queryCount is used to indicate the number of queries using the connection, if a connection is reserved or if its a transaction queryCount will be 1 independently of the number of queries + queryCount: number = 0; + + #onConnected(err, _) { + if (err) { + err = wrapError(err); + } + const connectionInfo = this.connectionInfo; + if (connectionInfo?.onconnect) { + connectionInfo.onconnect(err); + } + this.storedError = err; + if (!err) { + this.flags |= PooledConnectionFlags.canBeConnected; + } + this.state = err ? 
PooledConnectionState.closed : PooledConnectionState.connected; + const onFinish = this.onFinish; + if (onFinish) { + this.queryCount = 0; + this.flags &= ~PooledConnectionFlags.reserved; + this.flags &= ~PooledConnectionFlags.preReserved; + + // pool is closed, lets finish the connection + // report the connection error if any, otherwise close gracefully + if (err) { + onFinish(err); + } else { + this.connection?.close(); + } + return; + } + this.adapter.release(this, true); + } + + #onClose(err) { + if (err) { + err = wrapError(err); + } + const connectionInfo = this.connectionInfo; + if (connectionInfo?.onclose) { + connectionInfo.onclose(err); + } + this.state = PooledConnectionState.closed; + this.connection = null; + this.storedError = err; + + // remove from ready connections if its there + this.adapter.readyConnections.delete(this); + const queries = new Set(this.queries); + this.queries.clear(); + this.queryCount = 0; + this.flags &= ~PooledConnectionFlags.reserved; + + // notify all queries that the connection is closed + for (const onClose of queries) { + onClose(err); + } + const onFinish = this.onFinish; + if (onFinish) { + onFinish(err); + } + + this.adapter.release(this, true); + } + + constructor(connectionInfo: Bun.SQL.__internal.DefinedMySQLOptions, adapter: MySQLAdapter) { + this.state = PooledConnectionState.pending; + this.adapter = adapter; + this.connectionInfo = connectionInfo; + this.#startConnection(); + } + + async #startConnection() { + this.connection = await PooledMySQLConnection.createConnection( + this.connectionInfo, + this.#onConnected.bind(this), + this.#onClose.bind(this), + ); + } + + onClose(onClose: (err: Error) => void) { + this.queries.add(onClose); + } + + bindQuery(query: Query, onClose: (err: Error) => void) { + this.queries.add(onClose); + query.finally(onQueryFinish.bind(this, onClose)); + } + + #doRetry() { + if (this.adapter.closed) { + return; + } + // reset error and state + this.storedError = null; + this.state =
PooledConnectionState.pending; + // retry connection + this.#startConnection(); + } + close() { + try { + if (this.state === PooledConnectionState.connected) { + this.connection?.close(); + } + } catch {} + } + flush() { + this.connection?.flush(); + } + retry() { + // if pool is closed, we can't retry + if (this.adapter.closed) { + return false; + } + // we need to reconnect + // lets use a retry strategy + + // we can only retry if one day we are able to connect + if (this.flags & PooledConnectionFlags.canBeConnected) { + this.#doRetry(); + } else { + // analyse type of error to see if we can retry + switch (this.storedError?.code) { + case "ERR_MYSQL_PASSWORD_REQUIRED": + case "ERR_MYSQL_MISSING_AUTH_DATA": + case "ERR_MYSQL_FAILED_TO_ENCRYPT_PASSWORD": + case "ERR_MYSQL_INVALID_PUBLIC_KEY": + case "ERR_MYSQL_UNSUPPORTED_PROTOCOL_VERSION": + case "ERR_MYSQL_UNSUPPORTED_AUTH_PLUGIN": + case "ERR_MYSQL_AUTHENTICATION_FAILED": + // we can't retry these are authentication errors + return false; + default: + // we can retry + this.#doRetry(); + } + } + return true; + } +} + +export class MySQLAdapter + implements + DatabaseAdapter +{ + public readonly connectionInfo: Bun.SQL.__internal.DefinedMySQLOptions; + + public readonly connections: PooledMySQLConnection[]; + public readonly readyConnections: Set; + + public waitingQueue: Array<(err: Error | null, result: any) => void> = []; + public reservedQueue: Array<(err: Error | null, result: any) => void> = []; + + public poolStarted: boolean = false; + public closed: boolean = false; + public totalQueries: number = 0; + public onAllQueriesFinished: (() => void) | null = null; + + constructor(connectionInfo: Bun.SQL.__internal.DefinedMySQLOptions) { + this.connectionInfo = connectionInfo; + this.connections = new Array(connectionInfo.max); + this.readyConnections = new Set(); + } + + escapeIdentifier(str: string) { + return "`" + str.replaceAll("`", "``") + "`"; + } + + connectionClosedError() { + return new 
MySQLError("Connection closed", { + code: "ERR_MYSQL_CONNECTION_CLOSED", + }); + } + notTaggedCallError() { + return new MySQLError("Query not called as a tagged template literal", { + code: "ERR_MYSQL_NOT_TAGGED_CALL", + }); + } + queryCancelledError() { + return new MySQLError("Query cancelled", { + code: "ERR_MYSQL_QUERY_CANCELLED", + }); + } + invalidTransactionStateError(message: string) { + return new MySQLError(message, { + code: "ERR_MYSQL_INVALID_TRANSACTION_STATE", + }); + } + supportsReservedConnections() { + return true; + } + + getConnectionForQuery(pooledConnection: PooledMySQLConnection) { + return pooledConnection.connection; + } + + attachConnectionCloseHandler(connection: PooledMySQLConnection, handler: () => void): void { + if (connection.onClose) { + connection.onClose(handler); + } + } + + detachConnectionCloseHandler(connection: PooledMySQLConnection, handler: () => void): void { + if (connection.queries) { + connection.queries.delete(handler); + } + } + + getTransactionCommands(options?: string): import("./shared").TransactionCommands { + let BEGIN = "START TRANSACTION"; + if (options) { + BEGIN = `START TRANSACTION ${options}`; + } + + return { + BEGIN, + COMMIT: "COMMIT", + ROLLBACK: "ROLLBACK", + SAVEPOINT: "SAVEPOINT", + RELEASE_SAVEPOINT: "RELEASE SAVEPOINT", + ROLLBACK_TO_SAVEPOINT: "ROLLBACK TO SAVEPOINT", + }; + } + + getDistributedTransactionCommands(name: string): import("./shared").TransactionCommands | null { + if (!this.validateDistributedTransactionName(name).valid) { + return null; + } + + return { + BEGIN: `XA START '${name}'`, + COMMIT: `XA PREPARE '${name}'`, + ROLLBACK: `XA ROLLBACK '${name}'`, + SAVEPOINT: "SAVEPOINT", + RELEASE_SAVEPOINT: "RELEASE SAVEPOINT", + ROLLBACK_TO_SAVEPOINT: "ROLLBACK TO SAVEPOINT", + BEFORE_COMMIT_OR_ROLLBACK: `XA END '${name}'`, + }; + } + + validateTransactionOptions(_options: string): { valid: boolean; error?: string } { + return { valid: true }; + } + + 
validateDistributedTransactionName(name: string): { valid: boolean; error?: string } { + if (name.indexOf("'") !== -1) { + return { + valid: false, + error: "Distributed transaction name cannot contain single quotes.", + }; + } + return { valid: true }; + } + + getCommitDistributedSQL(name: string): string { + const validation = this.validateDistributedTransactionName(name); + if (!validation.valid) { + throw new Error(validation.error); + } + return `XA COMMIT '${name}'`; + } + + getRollbackDistributedSQL(name: string): string { + const validation = this.validateDistributedTransactionName(name); + if (!validation.valid) { + throw new Error(validation.error); + } + return `XA ROLLBACK '${name}'`; + } + + createQueryHandle(sql: string, values: unknown[], flags: number) { + if (!(flags & SQLQueryFlags.allowUnsafeTransaction)) { + if (this.connectionInfo.max !== 1) { + const upperCaseSqlString = sql.toUpperCase().trim(); + if (upperCaseSqlString.startsWith("BEGIN") || upperCaseSqlString.startsWith("START TRANSACTION")) { + throw new MySQLError("Only use sql.begin, sql.reserved or max: 1", { + code: "ERR_MYSQL_UNSAFE_TRANSACTION", + }); + } + } + } + + return createMySQLQuery( + sql, + values, + new SQLResultArray(), + undefined, + !!(flags & SQLQueryFlags.bigint), + !!(flags & SQLQueryFlags.simple), + ); + } + + maxDistribution() { + if (!this.waitingQueue.length) return 0; + const result = Math.ceil((this.waitingQueue.length + this.totalQueries) / this.connections.length); + return result ? 
result : 1; + } + + flushConcurrentQueries() { + const maxDistribution = this.maxDistribution(); + if (maxDistribution === 0) { + return; + } + + while (true) { + const nonReservedConnections = Array.from(this.readyConnections).filter( + c => !(c.flags & PooledConnectionFlags.preReserved) && c.queryCount < maxDistribution, + ); + if (nonReservedConnections.length === 0) { + return; + } + const orderedConnections = nonReservedConnections.sort((a, b) => a.queryCount - b.queryCount); + for (const connection of orderedConnections) { + const pending = this.waitingQueue.shift(); + if (!pending) { + return; + } + connection.queryCount++; + this.totalQueries++; + pending(null, connection); + } + } + } + + release(connection: PooledMySQLConnection, connectingEvent: boolean = false) { + if (!connectingEvent) { + connection.queryCount--; + this.totalQueries--; + } + const currentQueryCount = connection.queryCount; + if (currentQueryCount == 0) { + connection.flags &= ~PooledConnectionFlags.reserved; + connection.flags &= ~PooledConnectionFlags.preReserved; + } + if (this.onAllQueriesFinished) { + // we are waiting for all queries to finish, lets check if we can call it + if (!this.hasPendingQueries()) { + this.onAllQueriesFinished(); + } + } + + if (connection.state !== PooledConnectionState.connected) { + // connection is not ready + if (connection.storedError) { + // this connection got a error but maybe we can wait for another + + if (this.hasConnectionsAvailable()) { + return; + } + + const waitingQueue = this.waitingQueue; + const reservedQueue = this.reservedQueue; + + this.waitingQueue = []; + this.reservedQueue = []; + // we have no connections available so lets fails + for (const pending of waitingQueue) { + pending(connection.storedError, connection); + } + for (const pending of reservedQueue) { + pending(connection.storedError, connection); + } + } + return; + } + + if (currentQueryCount == 0) { + // ok we can actually bind reserved queries to it + const 
pendingReserved = this.reservedQueue.shift(); + if (pendingReserved) { + connection.flags |= PooledConnectionFlags.reserved; + connection.queryCount++; + this.totalQueries++; + // we have a connection waiting for a reserved connection lets prioritize it + pendingReserved(connection.storedError, connection); + return; + } + } + this.readyConnections.add(connection); + this.flushConcurrentQueries(); + } + + hasConnectionsAvailable() { + if (this.readyConnections.size > 0) return true; + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.state !== PooledConnectionState.closed) { + // some connection is connecting or connected + return true; + } + } + } + return false; + } + + hasPendingQueries() { + if (this.waitingQueue.length > 0 || this.reservedQueue.length > 0) return true; + if (this.poolStarted) { + return this.totalQueries > 0; + } + return false; + } + isConnected() { + if (this.readyConnections.size > 0) { + return true; + } + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.state === PooledConnectionState.connected) { + return true; + } + } + } + return false; + } + flush() { + if (this.closed) { + return; + } + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.state === PooledConnectionState.connected) { + connection.connection?.flush(); + } + } + } + } + + async #close() { + let pending; + while ((pending = this.waitingQueue.shift())) { + pending(this.connectionClosedError(), null); + } + while (this.reservedQueue.length > 0) { + const pendingReserved = this.reservedQueue.shift(); + if (pendingReserved) { + pendingReserved(this.connectionClosedError(), null); + } + } + + const promises: Array> = []; + + if 
(this.poolStarted) { + this.poolStarted = false; + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + switch (connection.state) { + case PooledConnectionState.pending: + { + const { promise, resolve } = Promise.withResolvers(); + connection.onFinish = resolve; + promises.push(promise); + connection.connection?.close(); + } + break; + + case PooledConnectionState.connected: + { + const { promise, resolve } = Promise.withResolvers(); + connection.onFinish = resolve; + promises.push(promise); + connection.connection?.close(); + } + break; + } + // clean connection reference + // @ts-ignore + this.connections[i] = null; + } + } + + this.readyConnections.clear(); + this.waitingQueue.length = 0; + return Promise.all(promises); + } + + async close(options?: { timeout?: number }) { + if (this.closed) { + return; + } + + let timeout = options?.timeout; + if (timeout) { + timeout = Number(timeout); + if (timeout > 2 ** 31 || timeout < 0 || timeout !== timeout) { + throw $ERR_INVALID_ARG_VALUE("options.timeout", timeout, "must be a non-negative integer less than 2^31"); + } + + this.closed = true; + if (timeout === 0 || !this.hasPendingQueries()) { + // close immediately + await this.#close(); + return; + } + + const { promise, resolve } = Promise.withResolvers(); + const timer = setTimeout(() => { + // timeout is reached, lets close and probably fail some queries + this.#close().finally(resolve); + }, timeout * 1000); + timer.unref(); // dont block the event loop + + this.onAllQueriesFinished = () => { + clearTimeout(timer); + // everything is closed, lets close the pool + this.#close().finally(resolve); + }; + + return promise; + } else { + this.closed = true; + if (!this.hasPendingQueries()) { + // close immediately + await this.#close(); + return; + } + + // gracefully close the pool + const { promise, resolve } = Promise.withResolvers(); + + this.onAllQueriesFinished = () => { + // everything is 
closed, lets close the pool + this.#close().finally(resolve); + }; + + return promise; + } + } + + /** + * @param {function} onConnected - The callback function to be called when the connection is established. + * @param {boolean} reserved - Whether the connection is reserved, if is reserved the connection will not be released until release is called, if not release will only decrement the queryCount counter + */ + connect(onConnected: (err: Error | null, result: any) => void, reserved: boolean = false) { + if (this.closed) { + return onConnected(this.connectionClosedError(), null); + } + + if (this.readyConnections.size === 0) { + // no connection ready lets make some + let retry_in_progress = false; + let all_closed = true; + let storedError: Error | null = null; + + if (this.poolStarted) { + // we already started the pool + // lets check if some connection is available to retry + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + // we need a new connection and we have some connections that can retry + if (connection.state === PooledConnectionState.closed) { + if (connection.retry()) { + // lets wait for connection to be released + if (!retry_in_progress) { + // avoid adding to the queue twice, we wanna to retry every available pool connection + retry_in_progress = true; + if (reserved) { + // we are not sure what connection will be available so we dont pre reserve + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + } + } + } else { + // we have some error, lets grab it and fail if unable to start a connection + storedError = connection.storedError; + } + } else { + // we have some pending or open connections + all_closed = false; + } + } + if (!all_closed && !retry_in_progress) { + // is possible to connect because we have some working connections, or we are just without network for some reason + // wait for connection to be released or fail + if 
(reserved) { + // we are not sure what connection will be available so we dont pre reserve + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + } + } else if (!retry_in_progress) { + // impossible to connect or retry + onConnected(storedError ?? this.connectionClosedError(), null); + } + return; + } + // we never started the pool, lets start it + if (reserved) { + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + } + this.poolStarted = true; + const pollSize = this.connections.length; + // pool is always at least 1 connection + const firstConnection = new PooledMySQLConnection(this.connectionInfo, this); + this.connections[0] = firstConnection; + if (reserved) { + firstConnection.flags |= PooledConnectionFlags.preReserved; // lets pre reserve the first connection + } + for (let i = 1; i < pollSize; i++) { + this.connections[i] = new PooledMySQLConnection(this.connectionInfo, this); + } + return; + } + if (reserved) { + let connectionWithLeastQueries: PooledMySQLConnection | null = null; + let leastQueries = Infinity; + for (const connection of this.readyConnections) { + if (connection.flags & PooledConnectionFlags.preReserved || connection.flags & PooledConnectionFlags.reserved) + continue; + const queryCount = connection.queryCount; + if (queryCount > 0) { + if (queryCount < leastQueries) { + leastQueries = queryCount; + connectionWithLeastQueries = connection; + } + continue; + } + connection.flags |= PooledConnectionFlags.reserved; + connection.queryCount++; + this.totalQueries++; + this.readyConnections.delete(connection); + onConnected(null, connection); + return; + } + + if (connectionWithLeastQueries) { + // lets mark the connection with the least queries as preReserved if any + connectionWithLeastQueries.flags |= PooledConnectionFlags.preReserved; + } + + // no connection available to be reserved lets wait for a connection to be released + this.reservedQueue.push(onConnected); + 
} else { + this.waitingQueue.push(onConnected); + this.flushConcurrentQueries(); + } + } + + normalizeQuery(strings: string | TemplateStringsArray, values: unknown[], binding_idx = 1): [string, unknown[]] { + if (typeof strings === "string") { + // identifier or unsafe query + return [strings, values || []]; + } + + if (!$isArray(strings)) { + // we should not hit this path + throw new SyntaxError("Invalid query: SQL Fragment cannot be executed or was misused"); + } + + const str_len = strings.length; + if (str_len === 0) { + return ["", []]; + } + + let binding_values: any[] = []; + let query = ""; + + for (let i = 0; i < str_len; i++) { + const string = strings[i]; + + if (typeof string === "string") { + query += string; + + if (values.length > i) { + const value = values[i]; + + if (value instanceof Query) { + const q = value as Query; + const [sub_query, sub_values] = this.normalizeQuery(q[_strings], q[_values], binding_idx); + + query += sub_query; + for (let j = 0; j < sub_values.length; j++) { + binding_values.push(sub_values[j]); + } + binding_idx += sub_values.length; + } else if (value instanceof SQLHelper) { + const command = detectCommand(query); + // only selectIn, insert, update, updateSet are allowed + if (command === SQLCommand.none || command === SQLCommand.where) { + throw new SyntaxError("Helpers are only allowed for INSERT, UPDATE and WHERE IN commands"); + } + const { columns, value: items } = value as SQLHelper; + const columnCount = columns.length; + if (columnCount === 0 && command !== SQLCommand.whereIn) { + throw new SyntaxError(`Cannot ${commandToString(command)} with no columns`); + } + const lastColumnIndex = columns.length - 1; + + if (command === SQLCommand.insert) { + // + // insert into users ${sql(users)} or insert into users ${sql(user)} + // + + query += "("; + for (let j = 0; j < columnCount; j++) { + query += this.escapeIdentifier(columns[j]); + if (j < lastColumnIndex) { + query += ", "; + } + } + query += ") VALUES"; + if 
($isArray(items)) { + const itemsCount = items.length; + const lastItemIndex = itemsCount - 1; + for (let j = 0; j < itemsCount; j++) { + query += "("; + const item = items[j]; + for (let k = 0; k < columnCount; k++) { + const column = columns[k]; + const columnValue = item[column]; + query += `?${k < lastColumnIndex ? ", " : ""}`; + if (typeof columnValue === "undefined") { + binding_values.push(null); + } else { + binding_values.push(columnValue); + } + } + if (j < lastItemIndex) { + query += "),"; + } else { + query += ") "; // the user can add RETURNING * or RETURNING id + } + } + } else { + query += "("; + const item = items; + for (let j = 0; j < columnCount; j++) { + const column = columns[j]; + const columnValue = item[column]; + query += `?${j < lastColumnIndex ? ", " : ""}`; + if (typeof columnValue === "undefined") { + binding_values.push(null); + } else { + binding_values.push(columnValue); + } + } + query += ") "; // the user can add RETURNING * or RETURNING id + } + } else if (command === SQLCommand.whereIn) { + // SELECT * FROM users WHERE id IN (${sql([1, 2, 3])}) + if (!$isArray(items)) { + throw new SyntaxError("An array of values is required for WHERE IN helper"); + } + const itemsCount = items.length; + const lastItemIndex = itemsCount - 1; + query += "("; + for (let j = 0; j < itemsCount; j++) { + query += `?${j < lastItemIndex ? 
", " : ""}`; + if (columnCount > 0) { + // we must use a key from a object + if (columnCount > 1) { + // we should not pass multiple columns here + throw new SyntaxError("Cannot use WHERE IN helper with multiple columns"); + } + // SELECT * FROM users WHERE id IN (${sql(users, "id")}) + const value = items[j]; + if (typeof value === "undefined") { + binding_values.push(null); + } else { + const value_from_key = value[columns[0]]; + + if (typeof value_from_key === "undefined") { + binding_values.push(null); + } else { + binding_values.push(value_from_key); + } + } + } else { + const value = items[j]; + if (typeof value === "undefined") { + binding_values.push(null); + } else { + binding_values.push(value); + } + } + } + query += ") "; // more conditions can be added after this + } else { + // UPDATE users SET ${sql({ name: "John", age: 31 })} WHERE id = 1 + let item; + if ($isArray(items)) { + if (items.length > 1) { + throw new SyntaxError("Cannot use array of objects for UPDATE"); + } + item = items[0]; + } else { + item = items; + } + // no need to include if is updateSet + if (command === SQLCommand.update) { + query += " SET "; + } + for (let i = 0; i < columnCount; i++) { + const column = columns[i]; + const columnValue = item[column]; + query += `${this.escapeIdentifier(column)} = ?${i < lastColumnIndex ? ", " : ""}`; + if (typeof columnValue === "undefined") { + binding_values.push(null); + } else { + binding_values.push(columnValue); + } + } + query += " "; // the user can add where clause after this + } + } else { + //TODO: handle sql.array parameters + query += `? 
`; + if (typeof value === "undefined") { + binding_values.push(null); + } else { + binding_values.push(value); + } + } + } + } else { + throw new SyntaxError("Invalid query: SQL Fragment cannot be executed or was misused"); + } + } + + return [query, binding_values]; + } +} + +export default { + MySQLAdapter, + SQLCommand, + commandToString, + detectCommand, +}; diff --git a/src/js/internal/sql/postgres.ts b/src/js/internal/sql/postgres.ts index 24f44e8cae..73f17dbb0e 100644 --- a/src/js/internal/sql/postgres.ts +++ b/src/js/internal/sql/postgres.ts @@ -1,13 +1,12 @@ +import type { PostgresErrorOptions } from "internal/sql/errors"; import type { Query } from "./query"; import type { DatabaseAdapter, SQLHelper, SQLResultArray, SSLMode } from "./shared"; - const { SQLHelper, SSLMode, SQLResultArray } = require("internal/sql/shared"); const { Query, SQLQueryFlags, symbols: { _strings, _values, _flags, _results, _handle }, } = require("internal/sql/query"); -const { escapeIdentifier, connectionClosedError } = require("internal/sql/utils"); const { PostgresError } = require("internal/sql/errors"); const { @@ -18,6 +17,13 @@ const { const cmds = ["", "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "MOVE", "FETCH", "COPY"]; +function wrapPostgresError(error: Error | PostgresErrorOptions) { + if (Error.isError(error)) { + return error; + } + return new PostgresError(error.message, error); +} + initPostgres( function onResolvePostgresQuery(query, result, commandTag, count, queries, is_last) { /// simple queries @@ -85,7 +91,12 @@ initPostgres( } catch {} }, - function onRejectPostgresQuery(query: Query, reject: Error, queries: Query[]) { + function onRejectPostgresQuery( + query: Query, + reject: Error | PostgresErrorOptions, + queries: Query[], + ) { + reject = wrapPostgresError(reject); if (queries) { const queriesIndex = queries.indexOf(query); if (queriesIndex !== -1) { @@ -94,7 +105,7 @@ initPostgres( } try { - query.reject(reject); + query.reject(reject as Error); } 
catch {} }, ); @@ -356,6 +367,9 @@ class PooledPostgresConnection { queryCount: number = 0; #onConnected(err, _) { + if (err) { + err = wrapPostgresError(err); + } const connectionInfo = this.connectionInfo; if (connectionInfo?.onconnect) { connectionInfo.onconnect(err); @@ -384,6 +398,9 @@ class PooledPostgresConnection { } #onClose(err) { + if (err) { + err = wrapPostgresError(err); + } const connectionInfo = this.connectionInfo; if (connectionInfo?.onclose) { connectionInfo.onclose(err); @@ -514,6 +531,30 @@ export class PostgresAdapter this.readyConnections = new Set(); } + escapeIdentifier(str: string) { + return '"' + str.replaceAll('"', '""').replaceAll(".", '"."') + '"'; + } + + connectionClosedError() { + return new PostgresError("Connection closed", { + code: "ERR_POSTGRES_CONNECTION_CLOSED", + }); + } + notTaggedCallError() { + return new PostgresError("Query not called as a tagged template literal", { + code: "ERR_POSTGRES_NOT_TAGGED_CALL", + }); + } + queryCancelledError(): Error { + return new PostgresError("Query cancelled", { + code: "ERR_POSTGRES_QUERY_CANCELLED", + }); + } + invalidTransactionStateError(message: string) { + return new PostgresError(message, { + code: "ERR_POSTGRES_INVALID_TRANSACTION_STATE", + }); + } supportsReservedConnections() { return true; } @@ -766,12 +807,12 @@ export class PostgresAdapter async #close() { let pending; while ((pending = this.waitingQueue.shift())) { - pending(connectionClosedError(), null); + pending(this.connectionClosedError(), null); } while (this.reservedQueue.length > 0) { const pendingReserved = this.reservedQueue.shift(); if (pendingReserved) { - pendingReserved(connectionClosedError(), null); + pendingReserved(this.connectionClosedError(), null); } } @@ -871,7 +912,7 @@ export class PostgresAdapter */ connect(onConnected: (err: Error | null, result: any) => void, reserved: boolean = false) { if (this.closed) { - return onConnected(connectionClosedError(), null); + return 
onConnected(this.connectionClosedError(), null); } if (this.readyConnections.size === 0) { @@ -920,7 +961,7 @@ export class PostgresAdapter } } else if (!retry_in_progress) { // impossible to connect or retry - onConnected(storedError ?? connectionClosedError(), null); + onConnected(storedError ?? this.connectionClosedError(), null); } return; } @@ -1035,7 +1076,7 @@ export class PostgresAdapter query += "("; for (let j = 0; j < columnCount; j++) { - query += escapeIdentifier(columns[j]); + query += this.escapeIdentifier(columns[j]); if (j < lastColumnIndex) { query += ", "; } @@ -1135,7 +1176,7 @@ export class PostgresAdapter for (let i = 0; i < columnCount; i++) { const column = columns[i]; const columnValue = item[column]; - query += `${escapeIdentifier(column)} = $${binding_idx++}${i < lastColumnIndex ? ", " : ""}`; + query += `${this.escapeIdentifier(column)} = $${binding_idx++}${i < lastColumnIndex ? ", " : ""}`; if (typeof columnValue === "undefined") { binding_values.push(null); } else { diff --git a/src/js/internal/sql/query.ts b/src/js/internal/sql/query.ts index dedd2016cd..3387f9edb2 100644 --- a/src/js/internal/sql/query.ts +++ b/src/js/internal/sql/query.ts @@ -1,5 +1,4 @@ import type { DatabaseAdapter } from "./shared.ts"; -const { escapeIdentifier, notTaggedCallError } = require("internal/sql/utils"); const _resolve = Symbol("resolve"); const _reject = Symbol("reject"); @@ -83,7 +82,7 @@ class Query> extends PublicPromise { if (!(flags & SQLQueryFlags.unsafe)) { // identifier (cannot be executed in safe mode) flags |= SQLQueryFlags.notTagged; - strings = escapeIdentifier(strings); + strings = adapter.escapeIdentifier(strings); } } @@ -110,7 +109,7 @@ class Query> extends PublicPromise { } if (this[_flags] & SQLQueryFlags.notTagged) { - this.reject(notTaggedCallError()); + this.reject(this[_adapter].notTaggedCallError()); return; } @@ -211,7 +210,7 @@ class Query> extends PublicPromise { async run() { if (this[_flags] & SQLQueryFlags.notTagged) { - 
throw notTaggedCallError(); + throw this[_adapter].notTaggedCallError(); } await this[_run](true); @@ -247,7 +246,7 @@ class Query> extends PublicPromise { then() { if (this[_flags] & SQLQueryFlags.notTagged) { - throw notTaggedCallError(); + throw this[_adapter].notTaggedCallError(); } this[_run](true); @@ -260,7 +259,7 @@ class Query> extends PublicPromise { catch() { if (this[_flags] & SQLQueryFlags.notTagged) { - throw notTaggedCallError(); + throw this[_adapter].notTaggedCallError(); } this[_run](true); @@ -273,7 +272,7 @@ class Query> extends PublicPromise { finally(_onfinally?: (() => void) | undefined | null) { if (this[_flags] & SQLQueryFlags.notTagged) { - throw notTaggedCallError(); + throw this[_adapter].notTaggedCallError(); } this[_run](true); diff --git a/src/js/internal/sql/shared.ts b/src/js/internal/sql/shared.ts index 81c7d81545..adabcbbcf2 100644 --- a/src/js/internal/sql/shared.ts +++ b/src/js/internal/sql/shared.ts @@ -200,16 +200,47 @@ function assertIsOptionsOfAdapter( } } +function hasProtocol(url: string) { + if (typeof url !== "string") return false; + const protocols: string[] = [ + "http", + "https", + "ftp", + "postgres", + "postgresql", + "mysql", + "mysql2", + "mariadb", + "file", + "sqlite", + ]; + for (const protocol of protocols) { + if (url.startsWith(protocol + "://")) { + return true; + } + } + return false; +} + +function defaultToPostgresIfNoProtocol(url: string | URL | null): URL { + if (url instanceof URL) { + return url; + } + if (hasProtocol(url as string)) { + return new URL(url as string); + } + return new URL("postgres://" + url); +} function parseOptions( stringOrUrlOrOptions: Bun.SQL.Options | string | URL | undefined, definitelyOptionsButMaybeEmpty: Bun.SQL.Options, ): Bun.SQL.__internal.DefinedOptions { const env = Bun.env; - let [stringOrUrl = env.POSTGRES_URL || env.DATABASE_URL || env.PGURL || env.PG_URL || null, options]: [ - string | URL | null, - Bun.SQL.Options, - ] = + let [ + stringOrUrl = env.POSTGRES_URL 
|| env.DATABASE_URL || env.PGURL || env.PG_URL || env.MYSQL_URL || null, + options, + ]: [string | URL | null, Bun.SQL.Options] = typeof stringOrUrlOrOptions === "string" || stringOrUrlOrOptions instanceof URL ? [stringOrUrlOrOptions, definitelyOptionsButMaybeEmpty] : stringOrUrlOrOptions @@ -250,17 +281,15 @@ function parseOptions( return parseSQLiteOptionsWithQueryParams(sqliteOptions, stringOrUrl); } - if (options.adapter !== undefined && options.adapter !== "postgres" && options.adapter !== "postgresql") { - options.adapter satisfies never; // This will type error if we support a new adapter in the future, which will let us know to update this check - throw new Error(`Unsupported adapter: ${options.adapter}. Supported adapters: "postgres", "sqlite"`); + if (!stringOrUrl) { + const url = options?.url; + if (typeof url === "string") { + stringOrUrl = defaultToPostgresIfNoProtocol(url); + } else if (url instanceof URL) { + stringOrUrl = url; + } } - // @ts-expect-error Compatibility - if (options.adapter === "postgresql") options.adapter = "postgres"; - if (options.adapter === undefined) options.adapter = "postgres"; - - assertIsOptionsOfAdapter(options, "postgres"); - let hostname: string | undefined, port: number | string | undefined, username: string | null | undefined, @@ -276,7 +305,8 @@ function parseOptions( onclose: ((client: Bun.SQL) => void) | undefined, max: number | null | undefined, bigint: boolean | undefined, - path: string | string[]; + path: string | string[], + adapter: Bun.SQL.__internal.Adapter; let prepare = true; let sslMode: SSLMode = SSLMode.disable; @@ -311,7 +341,7 @@ function parseOptions( } else if (options?.url) { const _url = options.url; if (typeof _url === "string") { - url = new URL(_url); + url = defaultToPostgresIfNoProtocol(_url); } else if (_url && typeof _url === "object" && _url instanceof URL) { url = _url; } @@ -322,7 +352,7 @@ function parseOptions( } } else if (typeof stringOrUrl === "string") { try { - url = new 
URL(stringOrUrl); + url = defaultToPostgresIfNoProtocol(stringOrUrl); } catch (e) { throw new Error(`Invalid URL '${stringOrUrl}' for postgres. Did you mean to specify \`{ adapter: "sqlite" }\`?`, { cause: e, @@ -330,14 +360,18 @@ function parseOptions( } } query = ""; - + adapter = options.adapter; if (url) { - ({ hostname, port, username, password } = options); + ({ hostname, port, username, password, adapter } = options); // object overrides url hostname ||= url.hostname; port ||= url.port; username ||= decodeIfValid(url.username); password ||= decodeIfValid(url.password); + adapter ||= url.protocol as Bun.SQL.__internal.Adapter; + if (adapter && adapter[adapter.length - 1] === ":") { + adapter = adapter.slice(0, -1) as Bun.SQL.__internal.Adapter; + } const queryObject = url.searchParams.toJSON(); for (const key in queryObject) { @@ -355,20 +389,57 @@ function parseOptions( } query = query.trim(); } + if (adapter) { + switch (adapter) { + case "http": + case "https": + case "ftp": + case "postgres": + case "postgresql": + adapter = "postgres"; + break; + case "mysql": + case "mysql2": + case "mariadb": + adapter = "mysql"; + break; + case "file": + case "sqlite": + adapter = "sqlite"; + break; + default: + options.adapter satisfies never; // This will type error if we support a new adapter in the future, which will let us know to update this check + throw new Error(`Unsupported adapter: ${options.adapter}. Supported adapters: "postgres", "sqlite", "mysql"`); + } + } else { + adapter = "postgres"; + } + options.adapter = adapter; + assertIsOptionsOfAdapter(options, adapter); hostname ||= options.hostname || options.host || env.PGHOST || "localhost"; - port ||= Number(options.port || env.PGPORT || 5432); + port ||= Number(options.port || env.PGPORT || (adapter === "mysql" ? 
3306 : 5432)); path ||= (options as { path?: string }).path || ""; // add /.s.PGSQL.${port} if it doesn't exist - if (path && path?.indexOf("/.s.PGSQL.") === -1) { + if (path && path?.indexOf("/.s.PGSQL.") === -1 && adapter === "postgres") { path = `${path}/.s.PGSQL.${port}`; } username ||= - options.username || options.user || env.PGUSERNAME || env.PGUSER || env.USER || env.USERNAME || "postgres"; + options.username || + options.user || + env.PGUSERNAME || + env.PGUSER || + env.USER || + env.USERNAME || + (adapter === "mysql" ? "root" : "postgres"); // default username for mysql is root and for postgres is postgres; database ||= - options.database || options.db || decodeIfValid((url?.pathname ?? "").slice(1)) || env.PGDATABASE || username; + options.database || + options.db || + decodeIfValid((url?.pathname ?? "").slice(1)) || + env.PGDATABASE || + (adapter === "mysql" ? "mysql" : username); // default database; password ||= options.password || options.pass || env.PGPASSWORD || ""; const connection = options.connection; if (connection && $isObject(connection)) { @@ -393,6 +464,9 @@ function parseOptions( bigint ??= options.bigint; // we need to explicitly set prepare to false if it is false if (options.prepare === false) { + if (adapter === "mysql") { + throw $ERR_INVALID_ARG_VALUE("options.prepare", false, "prepared: false is not supported in MySQL"); + } prepare = false; } @@ -470,8 +544,8 @@ function parseOptions( throw $ERR_INVALID_ARG_VALUE("port", port, "must be a non-negative integer between 1 and 65535"); } - const ret: Bun.SQL.__internal.DefinedPostgresOptions = { - adapter: "postgres", + const ret: Bun.SQL.__internal.DefinedOptions = { + adapter, hostname, port, username, @@ -545,6 +619,11 @@ export interface DatabaseAdapter { getCommitDistributedSQL?(name: string): string; getRollbackDistributedSQL?(name: string): string; + escapeIdentifier(name: string): string; + notTaggedCallError(): Error; + connectionClosedError(): Error; + queryCancelledError(): 
Error; + invalidTransactionStateError(message: string): Error; } export default { diff --git a/src/js/internal/sql/sqlite.ts b/src/js/internal/sql/sqlite.ts index 42b7cc439a..11304a7e87 100644 --- a/src/js/internal/sql/sqlite.ts +++ b/src/js/internal/sql/sqlite.ts @@ -8,7 +8,6 @@ const { SQLQueryResultMode, symbols: { _strings, _values }, } = require("internal/sql/query"); -const { escapeIdentifier, connectionClosedError } = require("internal/sql/utils"); const { SQLiteError } = require("internal/sql/errors"); let lazySQLiteModule: typeof BunSQLiteModule; @@ -447,7 +446,33 @@ export class SQLiteAdapter createQueryHandle(sql: string, values: unknown[] | undefined | null = []): SQLiteQueryHandle { return new SQLiteQueryHandle(sql, values ?? []); } - + escapeIdentifier(str: string) { + return '"' + str.replaceAll('"', '""').replaceAll(".", '"."') + '"'; + } + connectionClosedError() { + return new SQLiteError("Connection closed", { + code: "ERR_SQLITE_CONNECTION_CLOSED", + errno: 0, + }); + } + notTaggedCallError() { + return new SQLiteError("Query not called as a tagged template literal", { + code: "ERR_SQLITE_NOT_TAGGED_CALL", + errno: 0, + }); + } + queryCancelledError() { + return new SQLiteError("Query cancelled", { + code: "ERR_SQLITE_QUERY_CANCELLED", + errno: 0, + }); + } + invalidTransactionStateError(message: string) { + return new SQLiteError(message, { + code: "ERR_SQLITE_INVALID_TRANSACTION_STATE", + errno: 0, + }); + } normalizeQuery(strings: string | TemplateStringsArray, values: unknown[], binding_idx = 1): [string, unknown[]] { if (typeof strings === "string") { // identifier or unsafe query @@ -511,7 +536,7 @@ export class SQLiteAdapter query += "("; for (let j = 0; j < columnCount; j++) { - query += escapeIdentifier(columns[j]); + query += this.escapeIdentifier(columns[j]); if (j < lastColumnIndex) { query += ", "; } @@ -615,7 +640,7 @@ export class SQLiteAdapter const column = columns[i]; const columnValue = item[column]; // SQLite uses ? 
for placeholders - query += `${escapeIdentifier(column)} = ?${i < lastColumnIndex ? ", " : ""}`; + query += `${this.escapeIdentifier(column)} = ?${i < lastColumnIndex ? ", " : ""}`; if (typeof columnValue === "undefined") { binding_values.push(null); } else { @@ -644,7 +669,7 @@ export class SQLiteAdapter connect(onConnected: OnConnected, reserved?: boolean) { if (this._closed) { - return onConnected(connectionClosedError(), null); + return onConnected(this.connectionClosedError(), null); } // SQLite doesn't support reserved connections since it doesn't have a connection pool @@ -659,7 +684,7 @@ export class SQLiteAdapter } else if (this.db) { onConnected(null, this.db); } else { - onConnected(connectionClosedError(), null); + onConnected(this.connectionClosedError(), null); } } diff --git a/src/js/internal/sql/utils.ts b/src/js/internal/sql/utils.ts deleted file mode 100644 index 8b2e0b68ad..0000000000 --- a/src/js/internal/sql/utils.ts +++ /dev/null @@ -1,26 +0,0 @@ -const { hideFromStack } = require("../shared.ts"); -const { PostgresError } = require("./errors"); - -function connectionClosedError() { - return new PostgresError("Connection closed", { - code: "ERR_POSTGRES_CONNECTION_CLOSED", - }); -} -hideFromStack(connectionClosedError); - -function notTaggedCallError() { - return new PostgresError("Query not called as a tagged template literal", { - code: "ERR_POSTGRES_NOT_TAGGED_CALL", - }); -} -hideFromStack(notTaggedCallError); - -function escapeIdentifier(str: string) { - return '"' + str.replaceAll('"', '""').replaceAll(".", '"."') + '"'; -} - -export default { - connectionClosedError, - notTaggedCallError, - escapeIdentifier, -}; diff --git a/src/js/private.d.ts b/src/js/private.d.ts index 3f46c32c9f..738743c842 100644 --- a/src/js/private.d.ts +++ b/src/js/private.d.ts @@ -31,7 +31,9 @@ declare module "bun" { query: string; }; - type DefinedOptions = DefinedSQLiteOptions | DefinedPostgresOptions; + type DefinedMySQLOptions = DefinedPostgresOptions; + + 
type DefinedOptions = DefinedSQLiteOptions | DefinedPostgresOptions | DefinedMySQLOptions; + } } diff --git a/src/sql/mysql.zig b/src/sql/mysql.zig new file mode 100644 index 0000000000..ad391c73ab --- /dev/null +++ b/src/sql/mysql.zig @@ -0,0 +1,28 @@ +pub fn createBinding(globalObject: *jsc.JSGlobalObject) JSValue { + const binding = JSValue.createEmptyObjectWithNullPrototype(globalObject); + binding.put(globalObject, ZigString.static("MySQLConnection"), MySQLConnection.js.getConstructor(globalObject)); + binding.put(globalObject, ZigString.static("init"), jsc.JSFunction.create(globalObject, "init", MySQLContext.init, 0, .{})); + binding.put( + globalObject, + ZigString.static("createQuery"), + jsc.JSFunction.create(globalObject, "createQuery", MySQLQuery.call, 6, .{}), + ); + + binding.put( + globalObject, + ZigString.static("createConnection"), + jsc.JSFunction.create(globalObject, "createConnection", MySQLConnection.call, 2, .{}), + ); + + return binding; +} + +pub const MySQLConnection = @import("./mysql/MySQLConnection.zig"); +pub const MySQLContext = @import("./mysql/MySQLContext.zig"); +pub const MySQLQuery = @import("./mysql/MySQLQuery.zig"); + +const bun = @import("bun"); + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; +const ZigString = jsc.ZigString; diff --git a/src/sql/mysql/AuthMethod.zig b/src/sql/mysql/AuthMethod.zig new file mode 100644 index 0000000000..35374e3ca3 --- /dev/null +++ b/src/sql/mysql/AuthMethod.zig @@ -0,0 +1,37 @@ +// MySQL authentication methods +pub const AuthMethod = enum { + mysql_native_password, + caching_sha2_password, + sha256_password, + + pub fn scramble(this: AuthMethod, password: []const u8, auth_data: []const u8, buf: *[32]u8) ![]u8 { + if (password.len == 0) { + return &.{}; + } + + const len = scrambleLength(this); + + switch (this) { + .mysql_native_password => @memcpy(buf[0..len], &try Auth.mysql_native_password.scramble(password, auth_data)), + .caching_sha2_password => @memcpy(buf[0..len], &try
Auth.caching_sha2_password.scramble(password, auth_data)), + .sha256_password => @memcpy(buf[0..len], &try Auth.mysql_native_password.scramble(password, auth_data)), + } + + return buf[0..len]; + } + + pub fn scrambleLength(this: AuthMethod) usize { + return switch (this) { + .mysql_native_password => 20, + .caching_sha2_password => 32, + .sha256_password => 20, + }; + } + + const Map = bun.ComptimeEnumMap(AuthMethod); + + pub const fromString = Map.get; +}; + +const Auth = @import("./protocol/Auth.zig"); +const bun = @import("bun"); diff --git a/src/sql/mysql/Capabilities.zig b/src/sql/mysql/Capabilities.zig new file mode 100644 index 0000000000..3ccfa1c44b --- /dev/null +++ b/src/sql/mysql/Capabilities.zig @@ -0,0 +1,205 @@ +// MySQL capability flags +const Capabilities = @This(); +CLIENT_LONG_PASSWORD: bool = false, +CLIENT_FOUND_ROWS: bool = false, +CLIENT_LONG_FLAG: bool = false, +CLIENT_CONNECT_WITH_DB: bool = false, +CLIENT_NO_SCHEMA: bool = false, +CLIENT_COMPRESS: bool = false, +CLIENT_ODBC: bool = false, +CLIENT_LOCAL_FILES: bool = false, +CLIENT_IGNORE_SPACE: bool = false, +CLIENT_PROTOCOL_41: bool = false, +CLIENT_INTERACTIVE: bool = false, +CLIENT_SSL: bool = false, +CLIENT_IGNORE_SIGPIPE: bool = false, +CLIENT_TRANSACTIONS: bool = false, +CLIENT_RESERVED: bool = false, +CLIENT_SECURE_CONNECTION: bool = false, +CLIENT_MULTI_STATEMENTS: bool = false, +CLIENT_MULTI_RESULTS: bool = false, +CLIENT_PS_MULTI_RESULTS: bool = false, +CLIENT_PLUGIN_AUTH: bool = false, +CLIENT_CONNECT_ATTRS: bool = false, +CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: bool = false, +CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: bool = false, +CLIENT_SESSION_TRACK: bool = false, +CLIENT_DEPRECATE_EOF: bool = false, +CLIENT_OPTIONAL_RESULTSET_METADATA: bool = false, +CLIENT_ZSTD_COMPRESSION_ALGORITHM: bool = false, +CLIENT_QUERY_ATTRIBUTES: bool = false, +MULTI_FACTOR_AUTHENTICATION: bool = false, +CLIENT_CAPABILITY_EXTENSION: bool = false, +CLIENT_SSL_VERIFY_SERVER_CERT: bool = false, 
+CLIENT_REMEMBER_OPTIONS: bool = false, + +// Constants with correct shift values from MySQL protocol +const _CLIENT_LONG_PASSWORD = 1; // 1 << 0 +const _CLIENT_FOUND_ROWS = 2; // 1 << 1 +const _CLIENT_LONG_FLAG = 4; // 1 << 2 +const _CLIENT_CONNECT_WITH_DB = 8; // 1 << 3 +const _CLIENT_NO_SCHEMA = 16; // 1 << 4 +const _CLIENT_COMPRESS = 32; // 1 << 5 +const _CLIENT_ODBC = 64; // 1 << 6 +const _CLIENT_LOCAL_FILES = 128; // 1 << 7 +const _CLIENT_IGNORE_SPACE = 256; // 1 << 8 +const _CLIENT_PROTOCOL_41 = 512; // 1 << 9 +const _CLIENT_INTERACTIVE = 1024; // 1 << 10 +const _CLIENT_SSL = 2048; // 1 << 11 +const _CLIENT_IGNORE_SIGPIPE = 4096; // 1 << 12 +const _CLIENT_TRANSACTIONS = 8192; // 1 << 13 +const _CLIENT_RESERVED = 16384; // 1 << 14 +const _CLIENT_SECURE_CONNECTION = 32768; // 1 << 15 +const _CLIENT_MULTI_STATEMENTS = 65536; // 1 << 16 +const _CLIENT_MULTI_RESULTS = 131072; // 1 << 17 +const _CLIENT_PS_MULTI_RESULTS = 262144; // 1 << 18 +const _CLIENT_PLUGIN_AUTH = 524288; // 1 << 19 +const _CLIENT_CONNECT_ATTRS = 1048576; // 1 << 20 +const _CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 2097152; // 1 << 21 +const _CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 4194304; // 1 << 22 +const _CLIENT_SESSION_TRACK = 8388608; // 1 << 23 +const _CLIENT_DEPRECATE_EOF = 16777216; // 1 << 24 +const _CLIENT_OPTIONAL_RESULTSET_METADATA = 33554432; // 1 << 25 +const _CLIENT_ZSTD_COMPRESSION_ALGORITHM = 67108864; // 1 << 26 +const _CLIENT_QUERY_ATTRIBUTES = 134217728; // 1 << 27 +const _MULTI_FACTOR_AUTHENTICATION = 268435456; // 1 << 28 +const _CLIENT_CAPABILITY_EXTENSION = 536870912; // 1 << 29 +const _CLIENT_SSL_VERIFY_SERVER_CERT = 1073741824; // 1 << 30 +const _CLIENT_REMEMBER_OPTIONS = 2147483648; // 1 << 31 + +comptime { + _ = .{ + .CLIENT_LONG_PASSWORD = _CLIENT_LONG_PASSWORD, + .CLIENT_FOUND_ROWS = _CLIENT_FOUND_ROWS, + .CLIENT_LONG_FLAG = _CLIENT_LONG_FLAG, + .CLIENT_CONNECT_WITH_DB = _CLIENT_CONNECT_WITH_DB, + .CLIENT_NO_SCHEMA = _CLIENT_NO_SCHEMA, + .CLIENT_COMPRESS = 
_CLIENT_COMPRESS, + .CLIENT_ODBC = _CLIENT_ODBC, + .CLIENT_LOCAL_FILES = _CLIENT_LOCAL_FILES, + .CLIENT_IGNORE_SPACE = _CLIENT_IGNORE_SPACE, + .CLIENT_PROTOCOL_41 = _CLIENT_PROTOCOL_41, + .CLIENT_INTERACTIVE = _CLIENT_INTERACTIVE, + .CLIENT_SSL = _CLIENT_SSL, + .CLIENT_IGNORE_SIGPIPE = _CLIENT_IGNORE_SIGPIPE, + .CLIENT_TRANSACTIONS = _CLIENT_TRANSACTIONS, + .CLIENT_RESERVED = _CLIENT_RESERVED, + .CLIENT_SECURE_CONNECTION = _CLIENT_SECURE_CONNECTION, + .CLIENT_MULTI_STATEMENTS = _CLIENT_MULTI_STATEMENTS, + .CLIENT_MULTI_RESULTS = _CLIENT_MULTI_RESULTS, + .CLIENT_PS_MULTI_RESULTS = _CLIENT_PS_MULTI_RESULTS, + .CLIENT_PLUGIN_AUTH = _CLIENT_PLUGIN_AUTH, + .CLIENT_CONNECT_ATTRS = _CLIENT_CONNECT_ATTRS, + .CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = _CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA, + .CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = _CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS, + .CLIENT_SESSION_TRACK = _CLIENT_SESSION_TRACK, + .CLIENT_DEPRECATE_EOF = _CLIENT_DEPRECATE_EOF, + .CLIENT_OPTIONAL_RESULTSET_METADATA = _CLIENT_OPTIONAL_RESULTSET_METADATA, + .CLIENT_ZSTD_COMPRESSION_ALGORITHM = _CLIENT_ZSTD_COMPRESSION_ALGORITHM, + .CLIENT_QUERY_ATTRIBUTES = _CLIENT_QUERY_ATTRIBUTES, + .MULTI_FACTOR_AUTHENTICATION = _MULTI_FACTOR_AUTHENTICATION, + .CLIENT_CAPABILITY_EXTENSION = _CLIENT_CAPABILITY_EXTENSION, + .CLIENT_SSL_VERIFY_SERVER_CERT = _CLIENT_SSL_VERIFY_SERVER_CERT, + .CLIENT_REMEMBER_OPTIONS = _CLIENT_REMEMBER_OPTIONS, + }; +} + +pub fn reject(this: *Capabilities) void { + this.CLIENT_ZSTD_COMPRESSION_ALGORITHM = false; + this.MULTI_FACTOR_AUTHENTICATION = false; + this.CLIENT_CAPABILITY_EXTENSION = false; + this.CLIENT_SSL_VERIFY_SERVER_CERT = false; + this.CLIENT_REMEMBER_OPTIONS = false; + this.CLIENT_COMPRESS = false; + this.CLIENT_INTERACTIVE = false; + this.CLIENT_IGNORE_SIGPIPE = false; + this.CLIENT_NO_SCHEMA = false; + this.CLIENT_ODBC = false; + this.CLIENT_LOCAL_FILES = false; + this.CLIENT_OPTIONAL_RESULTSET_METADATA = false; + this.CLIENT_QUERY_ATTRIBUTES = false; +} + 
/// std.fmt formatter: prints the names of all enabled capability flags,
/// comma-separated (used by debug logging).
pub fn format(self: @This(), comptime _: []const u8, _: anytype, writer: anytype) !void {
    var first = true;
    inline for (comptime std.meta.fieldNames(Capabilities)) |field| {
        if (@TypeOf(@field(self, field)) == bool) {
            if (@field(self, field)) {
                if (!first) {
                    try writer.writeAll(", ");
                }
                first = false;
                try writer.writeAll(field);
            }
        }
    }
}

/// Serialize the enabled flags into the u32 capability bitmask sent in the
/// handshake response packet.
pub fn toInt(this: Capabilities) u32 {
    var value: u32 = 0;
    // Mirror `fromInt`: derive the flag list from the struct's own fields
    // instead of maintaining a duplicated hand-written list of all 32 names
    // (the previous list could silently drift out of sync with the fields).
    // `fromInt` already requires every field to have a matching `_`-prefixed
    // constant, so this lookup cannot fail to resolve.
    inline for (comptime std.meta.fieldNames(Capabilities)) |field| {
        if (@field(this, field)) {
            value |= @field(Capabilities, "_" ++ field);
        }
    }

    return value;
}

/// Parse a server-provided capability bitmask into per-flag booleans.
pub fn fromInt(flags: u32) Capabilities {
    var this: Capabilities = .{};
    inline for (comptime std.meta.fieldNames(Capabilities)) |field| {
        @field(this, field) = (@field(Capabilities, "_" ++ field) & flags) != 0;
    }
    return this;
}

/// Baseline capabilities requested when opening a connection.
/// `ssl` toggles CLIENT_SSL; `has_db_name` toggles CLIENT_CONNECT_WITH_DB.
pub fn getDefaultCapabilities(ssl: bool, has_db_name: bool) Capabilities {
    return .{
        .CLIENT_PROTOCOL_41 = true,
        .CLIENT_PLUGIN_AUTH = true,
        .CLIENT_SECURE_CONNECTION = true,
        .CLIENT_CONNECT_WITH_DB = has_db_name,
        .CLIENT_DEPRECATE_EOF = true,
        .CLIENT_SSL = ssl,
        .CLIENT_MULTI_STATEMENTS = true,
        .CLIENT_MULTI_RESULTS = true,
    };
}

const std = @import("std");
diff --git a/src/sql/mysql/ConnectionState.zig b/src/sql/mysql/ConnectionState.zig
new file mode 100644
index 0000000000..d39aef7582
--- /dev/null
+++ b/src/sql/mysql/ConnectionState.zig
@@ -0,0 +1,9 @@
/// Lifecycle of a MySQL connection, from TCP connect through auth to steady state.
pub const ConnectionState = enum {
    disconnected,
    connecting,
    handshaking,
    authenticating,
    authentication_awaiting_pk,
    connected,
    failed,
};
diff --git a/src/sql/mysql/MySQLConnection.zig b/src/sql/mysql/MySQLConnection.zig
new file mode 100644
index 0000000000..81e1b226d2
--- /dev/null
+++ b/src/sql/mysql/MySQLConnection.zig
@@ -0,0 +1,1949 @@
const MySQLConnection = @This();

socket: Socket,
status: ConnectionState = .disconnected,
ref_count: RefCount = RefCount.init(),

write_buffer: bun.OffsetByteList = .{},
read_buffer: bun.OffsetByteList = .{},
last_message_start: u32 = 0,
sequence_id: u8 = 0,

requests: Queue = Queue.init(bun.default_allocator),
// number of pipelined requests (Bind/Execute/Prepared statements)
pipelined_requests: u32 = 0,
// number of non-pipelined requests (Simple/Copy)
nonpipelinable_requests: u32 = 0,

statements: PreparedStatementsMap = .{},

poll_ref: bun.Async.KeepAlive = .{},
globalObject: *jsc.JSGlobalObject,
vm: *jsc.VirtualMachine,

pending_activity_count: std.atomic.Value(u32) = std.atomic.Value(u32).init(0),
js_value: JSValue = .js_undefined,

server_version: bun.ByteList = .{},
connection_id: u32 = 0,
capabilities: Capabilities = .{},
character_set: CharacterSet = CharacterSet.default,
status_flags: StatusFlags = .{},

auth_plugin: ?AuthMethod = null,
auth_state: AuthState = .{ .pending = {} },

auth_data: []const u8 = "",
database: []const u8 = "",
user: []const u8 = "",
password: []const u8 = "",
options: []const u8 = "",
options_buf: []const u8 = "",

tls_ctx: ?*uws.SocketContext = null,
tls_config: jsc.API.ServerConfig.SSLConfig = .{},
tls_status: TLSStatus =
.none,
ssl_mode: SSLMode = .disable,

idle_timeout_interval_ms: u32 = 0,
connection_timeout_ms: u32 = 0,

flags: ConnectionFlags = .{},

/// Before being connected, this is a connection timeout timer.
/// After being connected, this is an idle timeout timer.
timer: bun.api.Timer.EventLoopTimer = .{
    .tag = .MySQLConnectionTimeout,
    .next = .{
        .sec = 0,
        .nsec = 0,
    },
},

/// This timer controls the maximum lifetime of a connection.
/// It starts when the connection successfully starts (i.e. after handshake is complete).
/// It stops when the connection is closed.
max_lifetime_interval_ms: u32 = 0,
max_lifetime_timer: bun.api.Timer.EventLoopTimer = .{
    .tag = .MySQLConnectionMaxLifetime,
    .next = .{
        .sec = 0,
        .nsec = 0,
    },
},

auto_flusher: AutoFlusher = .{},

pub const ref = RefCount.ref;
pub const deref = RefCount.deref;

/// Deferred-microtask callback: drains the write buffer once per event-loop
/// tick. Returns whether the flusher should stay registered.
pub fn onAutoFlush(this: *@This()) bool {
    if (this.flags.has_backpressure) {
        debug("onAutoFlush: has backpressure", .{});
        this.auto_flusher.registered = false;
        // if we have backpressure, wait for onWritable
        return false;
    }
    // keep `this` alive across drainInternal, which may run JS callbacks
    this.ref();
    defer this.deref();
    debug("onAutoFlush: draining", .{});
    // drain as much as we can
    this.drainInternal();

    // if we dont have backpressure and if we still have data to send, return true otherwise return false and wait for onWritable
    const keep_flusher_registered = !this.flags.has_backpressure and this.write_buffer.len() > 0;
    debug("onAutoFlush: keep_flusher_registered: {}", .{keep_flusher_registered});
    this.auto_flusher.registered = keep_flusher_registered;
    return keep_flusher_registered;
}

/// True when another prepared-statement execute can be appended to the write
/// buffer without waiting for in-flight work to finish.
pub fn canPipeline(this: *@This()) bool {
    if (bun.getRuntimeFeatureFlag(.BUN_FEATURE_FLAG_DISABLE_SQL_AUTO_PIPELINING)) {
        @branchHint(.unlikely);
        return false;
    }
    return this.status == .connected and
        this.nonpipelinable_requests == 0 and // need to wait for non pipelinable requests to finish
        !this.flags.use_unnamed_prepared_statements and // unnamed statements are not pipelinable
        !this.flags.waiting_to_prepare and // cannot pipeline when waiting prepare
        !this.flags.has_backpressure and // dont make sense to buffer more if we have backpressure
        this.write_buffer.len() < MAX_PIPELINE_SIZE; // buffer is too big need to flush before pipeline more
}
/// Authentication sub-state machine (native password vs. caching_sha2).
pub const AuthState = union(enum) {
    pending: void,
    native_password: void,
    caching_sha2: CachingSha2,
    ok: void,

    pub const CachingSha2 = union(enum) {
        fast_auth,
        full_auth,
        waiting_key,
    };
};

pub fn hasPendingActivity(this: *MySQLConnection) bool {
    return this.pending_activity_count.load(.acquire) > 0;
}

/// Recompute the GC-visible activity count: pending requests + "not disconnected".
fn updateHasPendingActivity(this: *MySQLConnection) void {
    const a: u32 = if (this.requests.readableLength() > 0) 1 else 0;
    const b: u32 = if (this.status != .disconnected) 1 else 0;
    this.pending_activity_count.store(a + b, .release);
}

/// True when the write buffer is non-empty or the head request still needs
/// to be serialized.
fn hasDataToSend(this: *@This()) bool {
    if (this.write_buffer.len() > 0) {
        return true;
    }
    if (this.current()) |request| {
        switch (request.status) {
            .pending, .binding => return true,
            else => return false,
        }
    }
    return false;
}

/// Register the deferred-microtask auto flusher, if it is useful right now.
fn registerAutoFlusher(this: *@This()) void {
    const has_data_to_send = this.hasDataToSend();
    debug("registerAutoFlusher: backpressure: {} registered: {} has_data_to_send: {}", .{ this.flags.has_backpressure, this.auto_flusher.registered, has_data_to_send });

    if (!this.auto_flusher.registered and // should not be registered
        !this.flags.has_backpressure and // if has backpressure we need to wait for onWritable event
        has_data_to_send and // we need data to send
        this.status == .connected //and we need to be connected
    ) {
        AutoFlusher.registerDeferredMicrotaskWithTypeUnchecked(@This(), this, this.vm);
        this.auto_flusher.registered = true;
    }
}
pub fn flushDataAndResetTimeout(this: *@This()) void {
    this.resetConnectionTimeout();
    // defer flushing, so if many queries are running in parallel in the same connection, we
don't flush more than once
    this.registerAutoFlusher();
}

/// Remove the deferred-microtask auto flusher, if registered.
fn unregisterAutoFlusher(this: *@This()) void {
    debug("unregisterAutoFlusher registered: {}", .{this.auto_flusher.registered});
    if (this.auto_flusher.registered) {
        AutoFlusher.unregisterDeferredMicrotaskWithType(@This(), this, this.vm);
        this.auto_flusher.registered = false;
    }
}

/// Which timeout applies to the current state: idle once connected, 0 (none)
/// once failed, otherwise the connection timeout.
fn getTimeoutInterval(this: *const @This()) u32 {
    return switch (this.status) {
        .connected => this.idle_timeout_interval_ms,
        .failed => 0,
        else => this.connection_timeout_ms,
    };
}
pub fn disableConnectionTimeout(this: *@This()) void {
    if (this.timer.state == .ACTIVE) {
        this.vm.timer.remove(&this.timer);
    }
    this.timer.state = .CANCELLED;
}
/// Re-arm the connection/idle timer for the current state's interval.
pub fn resetConnectionTimeout(this: *@This()) void {
    // if we are processing data, don't reset the timeout, wait for the data to be processed
    if (this.flags.is_processing_data) return;
    const interval = this.getTimeoutInterval();
    if (this.timer.state == .ACTIVE) {
        this.vm.timer.remove(&this.timer);
    }
    if (interval == 0) {
        return;
    }

    this.timer.next = bun.timespec.msFromNow(@intCast(interval));
    this.vm.timer.insert(&this.timer);
}

/// Arm the max-lifetime timer once, if configured (0 disables it).
fn setupMaxLifetimeTimerIfNecessary(this: *@This()) void {
    if (this.max_lifetime_interval_ms == 0) return;
    if (this.max_lifetime_timer.state == .ACTIVE) return;

    this.max_lifetime_timer.next = bun.timespec.msFromNow(@intCast(this.max_lifetime_interval_ms));
    this.vm.timer.insert(&this.max_lifetime_timer);
}

/// Timer callback for the connection/idle timer: fails the connection with a
/// state-appropriate message.
pub fn onConnectionTimeout(this: *@This()) bun.api.Timer.EventLoopTimer.Arm {
    debug("onConnectionTimeout", .{});

    this.timer.state = .FIRED;
    if (this.flags.is_processing_data) {
        return .disarm;
    }

    if (this.getTimeoutInterval() == 0) {
        this.resetConnectionTimeout();
        return .disarm;
    }

    // NOTE(review): the `else` prong here precedes the explicit auth-phase
    // prongs below, so it only covers .disconnected/.connecting/.failed —
    // confirm this prong ordering is intentional.
    switch (this.status) {
        .connected => {
            this.failFmt(error.IdleTimeout, "Idle timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.idle_timeout_interval_ms) *| std.time.ns_per_ms)});
        },
        else => {
            this.failFmt(error.ConnectionTimedOut, "Connection timeout after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)});
        },
        .handshaking,
        .authenticating,
        .authentication_awaiting_pk,
        => {
            this.failFmt(error.ConnectionTimedOut, "Connection timed out after {} (during authentication)", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.connection_timeout_ms) *| std.time.ns_per_ms)});
        },
    }
    return .disarm;
}

/// Timer callback for the max-lifetime timer: fails the connection unless it
/// already failed.
pub fn onMaxLifetimeTimeout(this: *@This()) bun.api.Timer.EventLoopTimer.Arm {
    debug("onMaxLifetimeTimeout", .{});
    this.max_lifetime_timer.state = .FIRED;
    if (this.status == .failed) return .disarm;
    this.failFmt(error.LifetimeTimeout, "Max lifetime timeout reached after {}", .{bun.fmt.fmtDurationOneDecimal(@as(u64, this.max_lifetime_interval_ms) *| std.time.ns_per_ms)});
    return .disarm;
}
/// Flush buffered writes and, if the socket accepted everything, queue more
/// pipelined work and flush again.
fn drainInternal(this: *@This()) void {
    debug("drainInternal", .{});
    if (this.vm.isShuttingDown()) return this.close();

    const event_loop = this.vm.eventLoop();
    event_loop.enter();
    defer event_loop.exit();

    this.flushData();

    if (!this.flags.has_backpressure) {
        // no backpressure yet so pipeline more if possible and flush again
        this.advance();
        this.flushData();
    }
}
/// GC finalizer for the JS wrapper object.
pub fn finalize(this: *MySQLConnection) void {
    this.stopTimers();
    debug("MySQLConnection finalize", .{});

    // Ensure we disconnect before finalizing
    if (this.status != .disconnected) {
        this.disconnect();
    }

    this.js_value = .zero;
    this.deref();
}

pub fn doRef(this: *@This(), _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue {
    this.poll_ref.ref(this.vm);
    this.updateHasPendingActivity();
    return .js_undefined;
}

pub fn doUnref(this: *@This(), _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue {
    this.poll_ref.unref(this.vm);
    this.updateHasPendingActivity();
    return .js_undefined;
}

pub fn doFlush(this: *MySQLConnection,
_: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue {
    // JS `flush()`: schedules a drain on the next microtask instead of
    // writing synchronously.
    this.registerAutoFlusher();
    return .js_undefined;
}

pub fn createQuery(this: *MySQLConnection, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue {
    _ = callframe;
    _ = globalObject;
    _ = this;

    return .js_undefined;
}

pub fn getConnected(this: *MySQLConnection, _: *jsc.JSGlobalObject) JSValue {
    return JSValue.jsBoolean(this.status == .connected);
}

pub fn doClose(this: *MySQLConnection, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue {
    _ = globalObject;
    this.disconnect();
    this.write_buffer.deinit(bun.default_allocator);

    return .js_undefined;
}

/// `new MySQLConnection(...)` from JS is not allowed; instances are created
/// via `call` below.
pub fn constructor(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!*MySQLConnection {
    _ = callframe;

    return globalObject.throw("MySQLConnection cannot be constructed directly", .{});
}

/// Write as much of the buffered output as the socket accepts; records
/// backpressure when the socket takes less than the whole chunk.
pub fn flushData(this: *@This()) void {
    // we know we still have backpressure so just return we will flush later
    if (this.flags.has_backpressure) {
        debug("flushData: has backpressure", .{});
        return;
    }

    const chunk = this.write_buffer.remaining();
    if (chunk.len == 0) {
        debug("flushData: no data to flush", .{});
        return;
    }

    const wrote = this.socket.write(chunk);
    this.flags.has_backpressure = wrote < chunk.len;
    debug("flushData: wrote {d}/{d} bytes", .{ wrote, chunk.len });
    if (wrote > 0) {
        SocketMonitor.write(chunk[0..@intCast(wrote)]);
        this.write_buffer.consume(@intCast(wrote));
    }
}

pub fn stopTimers(this: *@This()) void {
    if (this.timer.state == .ACTIVE) {
        this.vm.timer.remove(&this.timer);
    }
    if (this.max_lifetime_timer.state == .ACTIVE) {
        this.vm.timer.remove(&this.max_lifetime_timer);
    }
}

pub fn getQueriesArray(this: *const @This()) JSValue {
    return js.queriesGetCached(this.js_value) orelse .zero;
}
/// Format an error message and fail the connection with it.
pub fn failFmt(this: *@This(), error_code: AnyMySQLError.Error, comptime fmt: [:0]const u8, args: anytype) void {
    const message = std.fmt.allocPrint(bun.default_allocator, fmt, args) catch bun.outOfMemory();
    defer bun.default_allocator.free(message);

    const err = AnyMySQLError.mysqlErrorToJS(this.globalObject, message, error_code);
    this.failWithJSValue(err);
}
/// Move the connection to `.failed`, invoke the JS `onclose` callback, then
/// close the socket and reject every pending request (idempotent).
pub fn failWithJSValue(this: *MySQLConnection, value: JSValue) void {
    defer this.updateHasPendingActivity();
    this.stopTimers();
    if (this.status == .failed) return;
    this.setStatus(.failed);

    this.ref();
    defer this.deref();
    // we defer the refAndClose so the on_close will be called first before we reject the pending requests
    defer this.refAndClose(value);
    const on_close = this.consumeOnCloseCallback(this.globalObject) orelse return;

    const loop = this.vm.eventLoop();
    loop.enter();
    defer loop.exit();
    _ = on_close.call(
        this.globalObject,
        this.js_value,
        &[_]JSValue{
            value,
            this.getQueriesArray(),
        },
    ) catch |e| this.globalObject.reportActiveExceptionAsUnhandled(e);
}

pub fn fail(this: *MySQLConnection, message: []const u8, err: AnyMySQLError.Error) void {
    debug("failed: {s}: {s}", .{ message, @errorName(err) });
    const instance = AnyMySQLError.mysqlErrorToJS(this.globalObject, message, err);
    this.failWithJSValue(instance);
}

pub fn onClose(this: *MySQLConnection) void {
    var vm = this.vm;
    defer vm.drainMicrotasks();
    this.fail("Connection closed", error.ConnectionClosed);
}

fn refAndClose(this: *@This(), js_reason: ?jsc.JSValue) void {
    // refAndClose is always called when we want to disconnect or when we are closed

    if (!this.socket.isClosed()) {
        // event loop need to be alive to close the socket
        this.poll_ref.ref(this.vm);
        // will unref on socket close
        this.socket.close();
    }

    // cleanup requests
    this.cleanUpRequests(js_reason);
}

/// Close a connected connection: fails pending requests with CR_SERVER_LOST
/// and closes the socket. No-op unless currently `.connected`.
pub fn disconnect(this: *@This()) void {
    this.stopTimers();
    if (this.status == .connected) {
        this.setStatus(.disconnected);
        this.poll_ref.disable();

        const requests =
this.requests.readableSlice(0);
        this.requests.head = 0;
        this.requests.count = 0;

        // Fail any pending requests
        for (requests) |request| {
            this.finishRequest(request);
            request.onError(.{
                .error_code = 2013, // CR_SERVER_LOST
                .error_message = .{ .temporary = "Lost connection to MySQL server" },
            }, this.globalObject);
        }

        this.socket.close();
    }
}

/// Undo the per-request bookkeeping (pipeline counters / waiting_to_prepare
/// flag) when a request leaves the queue.
fn finishRequest(this: *@This(), item: *MySQLQuery) void {
    switch (item.status) {
        .running, .binding, .partial_response => {
            if (item.flags.simple) {
                this.nonpipelinable_requests -= 1;
            } else if (item.flags.pipelined) {
                this.pipelined_requests -= 1;
            }
        },
        .success, .fail, .pending => {
            if (this.flags.waiting_to_prepare) {
                this.flags.waiting_to_prepare = false;
            }
        },
    }
}

/// The request at the head of the queue, if any.
fn current(this: *@This()) ?*MySQLQuery {
    if (this.requests.readableLength() == 0) {
        return null;
    }

    return this.requests.peekItem(0);
}

pub fn canExecuteQuery(this: *@This()) bool {
    if (this.status != .connected) return false;
    return this.flags.is_ready_for_query and this.current() == null;
}
pub fn canPrepareQuery(this: *@This()) bool {
    return this.flags.is_ready_for_query and !this.flags.waiting_to_prepare and this.pipelined_requests == 0;
}

/// Reject and drain every queued request (connection closing or failed).
/// `js_reason`, when present, is the JS error to reject with.
fn cleanUpRequests(this: *@This(), js_reason: ?jsc.JSValue) void {
    while (this.current()) |request| {
        switch (request.status) {
            // pending we will fail the request and the stmt will be marked as error ConnectionClosed too
            .pending => {
                // BUGFIX: this previously did `const stmt = request.statement
                // orelse continue;`. `continue` skipped the deref/discard at
                // the bottom of the loop, so a pending request without a
                // prepared statement (e.g. a simple query) stayed at the head
                // of the queue and the `while (this.current())` loop never
                // terminated. Fail the request regardless, and only mark the
                // statement when one exists.
                if (request.statement) |stmt| {
                    stmt.status = .failed;
                }
                if (!this.vm.isShuttingDown()) {
                    if (js_reason) |reason| {
                        request.onJSError(reason, this.globalObject);
                    } else {
                        request.onError(.{
                            .error_code = 2013,
                            .error_message = .{ .temporary = "Connection closed" },
                        }, this.globalObject);
                    }
                }
            },
            // in the middle of running
            .binding,
            .running,
            .partial_response,
            => {
                this.finishRequest(request);
                if (!this.vm.isShuttingDown()) {
                    if (js_reason) |reason| {
                        request.onJSError(reason, this.globalObject);
                    } else {
                        request.onError(.{
                            .error_code = 2013,
                            .error_message = .{ .temporary = "Connection closed" },
                        }, this.globalObject);
                    }
                }
            },
            // just ignore success and fail cases
            .success, .fail => {},
        }
        request.deref();
        this.requests.discard(1);
    }
}
/// Serialize as many queued requests into the write buffer as the pipeline
/// rules allow. `offset` walks past requests that must stay queued.
fn advance(this: *@This()) void {
    var offset: usize = 0;
    debug("advance", .{});
    defer {
        while (this.requests.readableLength() > 0) {
            const result = this.requests.peekItem(0);
            // An item may be in the success or failed state and still be inside the queue (see deinit later comments)
            // so we do the cleanup here
            switch (result.status) {
                .success => {
                    result.deref();
                    this.requests.discard(1);
                    continue;
                },
                .fail => {
                    result.deref();
                    this.requests.discard(1);
                    continue;
                },
                else => break, // truly the current item
            }
        }
    }

    while (this.requests.readableLength() > offset and !this.flags.has_backpressure) {
        if (this.vm.isShuttingDown()) return this.close();
        var req: *MySQLQuery = this.requests.peekItem(offset);
        switch (req.status) {
            .pending => {
                if (req.flags.simple) {
                    if (this.pipelined_requests > 0 or !this.flags.is_ready_for_query) {
                        debug("cannot execute simple query, pipelined_requests: {d}, is_ready_for_query: {}", .{ this.pipelined_requests, this.flags.is_ready_for_query });
                        // need to wait for the previous request to finish before starting simple queries
                        return;
                    }

                    var query_str = req.query.toUTF8(bun.default_allocator);
                    defer query_str.deinit();

                    debug("execute simple query: {d} {s}", .{ this.sequence_id, query_str.slice() });

                    MySQLRequest.executeQuery(query_str.slice(), MySQLConnection.Writer, this.writer()) catch |err| {
                        if (this.globalObject.tryTakeException()) |err_| {
                            req.onJSError(err_, this.globalObject);
                        } else {
                            req.onWriteFail(err, this.globalObject, this.getQueriesArray());
                        }
                        if (offset == 0) {
                            req.deref();
                            this.requests.discard(1);
                        } else
{
                            // deinit later
                            req.status = .fail;
                        }
                        debug("executeQuery failed: {s}", .{@errorName(err)});
                        offset += 1;
                        continue;
                    };
                    this.nonpipelinable_requests += 1;
                    this.flags.is_ready_for_query = false;
                    req.status = .running;
                    this.flushDataAndResetTimeout();
                    // simple queries are not pipelined: stop after serializing one
                    return;
                } else {
                    if (req.statement) |statement| {
                        switch (statement.status) {
                            .failed => {
                                debug("stmt failed", .{});
                                req.onError(statement.error_response, this.globalObject);
                                if (offset == 0) {
                                    req.deref();
                                    this.requests.discard(1);
                                } else {
                                    // deinit later
                                    req.status = .fail;
                                    offset += 1;
                                }
                                continue;
                            },
                            .prepared => {
                                req.bindAndExecute(this.writer(), statement, this.globalObject) catch |err| {
                                    if (this.globalObject.tryTakeException()) |err_| {
                                        req.onJSError(err_, this.globalObject);
                                    } else {
                                        req.onWriteFail(err, this.globalObject, this.getQueriesArray());
                                    }
                                    if (offset == 0) {
                                        req.deref();
                                        this.requests.discard(1);
                                    } else {
                                        // deinit later
                                        req.status = .fail;
                                        offset += 1;
                                    }
                                    debug("executeQuery failed: {s}", .{@errorName(err)});
                                    continue;
                                };

                                req.flags.pipelined = true;
                                this.pipelined_requests += 1;
                                this.flags.is_ready_for_query = false;
                                this.flushDataAndResetTimeout();
                                if (this.flags.use_unnamed_prepared_statements or !this.canPipeline()) {
                                    debug("cannot pipeline more stmt", .{});
                                    return;
                                }
                                offset += 1;
                                continue;
                            },
                            .pending => {
                                if (!this.canPrepareQuery()) {
                                    debug("need to wait to finish the pipeline before starting a new query preparation", .{});
                                    // need to wait to finish the pipeline before starting a new query preparation
                                    return;
                                }
                                // We're waiting for prepare response
                                req.statement.?.status = .parsing;
                                var query_str = req.query.toUTF8(bun.default_allocator);
                                defer query_str.deinit();
                                MySQLRequest.prepareRequest(query_str.slice(), Writer, this.writer()) catch |err| {
                                    if (this.globalObject.tryTakeException()) |err_| {
                                        req.onJSError(err_, this.globalObject);
                                    } else {
                                        req.onWriteFail(err, this.globalObject, this.getQueriesArray());
                                    }
                                    if (offset == 0) {
                                        req.deref();
                                        this.requests.discard(1);
                                    } else {
                                        // deinit later
                                        req.status = .fail;
                                        offset += 1;
                                    }
                                    debug("executeQuery failed: {s}", .{@errorName(err)});
                                    continue;
                                };
                                this.flags.waiting_to_prepare = true;
                                this.flags.is_ready_for_query = false;
                                this.flushDataAndResetTimeout();
                                return;
                            },
                            .parsing => {
                                // we are still parsing, lets wait for it to be prepared or failed
                                offset += 1;
                                continue;
                            },
                        }
                    }
                }
            },
            .binding, .running, .partial_response => {
                offset += 1;
                continue;
            },
            .success => {
                if (offset > 0) {
                    // deinit later
                    // NOTE(review): a successful request is re-tagged `.fail`
                    // here so the defer-block above derefs/discards it later —
                    // presumably intentional (status is only used for cleanup
                    // at this point); confirm.
                    req.status = .fail;
                    offset += 1;
                    continue;
                }
                req.deref();
                this.requests.discard(1);
                continue;
            },
            .fail => {
                if (offset > 0) {
                    // deinit later
                    offset += 1;
                    continue;
                }
                req.deref();
                this.requests.discard(1);
                continue;
            },
        }
    }
}

/// uSockets callback adapter; instantiated once for TLS and once for TCP.
/// Each callback forwards to the corresponding MySQLConnection method.
fn SocketHandler(comptime ssl: bool) type {
    return struct {
        const SocketType = uws.NewSocketHandler(ssl);
        fn _socket(s: SocketType) Socket {
            if (comptime ssl) {
                return Socket{ .SocketTLS = s };
            }

            return Socket{ .SocketTCP = s };
        }
        pub fn onOpen(this: *MySQLConnection, socket: SocketType) void {
            this.onOpen(_socket(socket));
        }

        fn onHandshake_(this: *MySQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
            this.onHandshake(success, ssl_error);
        }

        pub const onHandshake = if (ssl) onHandshake_ else null;

        pub fn onClose(this: *MySQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void {
            _ = socket;
            this.onClose();
        }

        pub fn onEnd(this: *MySQLConnection, socket: SocketType) void {
            _ = socket;
            this.onClose();
        }

        pub fn onConnectError(this: *MySQLConnection, socket: SocketType, _: i32) void {
            _ = socket;
            this.onClose();
        }

        pub fn onTimeout(this: *MySQLConnection, socket: SocketType) void {
            _ = socket;
            this.onTimeout();
        }

        pub fn
onData(this: *MySQLConnection, socket: SocketType, data: []const u8) void {
            _ = socket;
            this.onData(data);
        }

        pub fn onWritable(this: *MySQLConnection, socket: SocketType) void {
            _ = socket;
            this.onDrain();
        }
    };
}

pub fn onTimeout(this: *MySQLConnection) void {
    this.fail("Connection timed out", error.ConnectionTimedOut);
}

/// Socket became writable again: clear backpressure and resume draining.
pub fn onDrain(this: *MySQLConnection) void {
    debug("onDrain", .{});
    this.flags.has_backpressure = false;
    this.drainInternal();
}

/// JS entry point: create a MySQLConnection and start connecting.
/// arguments: hostname, port, username, password, database, sslMode, tls,
/// options, path, onConnect, onClose, idleTimeout, connectionTimeout,
/// maxLifetime, useUnnamedPreparedStatements.
pub fn call(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue {
    var vm = globalObject.bunVM();
    const arguments = callframe.arguments();
    const hostname_str = try arguments[0].toBunString(globalObject);
    defer hostname_str.deref();
    const port = try arguments[1].coerce(i32, globalObject);

    const username_str = try arguments[2].toBunString(globalObject);
    defer username_str.deref();
    const password_str = try arguments[3].toBunString(globalObject);
    defer password_str.deref();
    const database_str = try arguments[4].toBunString(globalObject);
    defer database_str.deref();
    // TODO: update this to match MySQL.
    const ssl_mode: SSLMode = switch (arguments[5].toInt32()) {
        0 => .disable,
        1 => .prefer,
        2 => .require,
        3 => .verify_ca,
        4 => .verify_full,
        else => .disable,
    };

    const tls_object = arguments[6];

    var tls_config: jsc.API.ServerConfig.SSLConfig = .{};
    var tls_ctx: ?*uws.SocketContext = null;
    if (ssl_mode != .disable) {
        tls_config = if (tls_object.isBoolean() and tls_object.toBoolean())
            .{}
        else if (tls_object.isObject())
            (jsc.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{}
        else {
            return globalObject.throwInvalidArguments("tls must be a boolean or an object", .{});
        };

        if (globalObject.hasException()) {
            tls_config.deinit();
            return .zero;
        }

        // we always request the cert so we can verify it and also we manually abort the connection if the hostname doesn't match
        const original_reject_unauthorized = tls_config.reject_unauthorized;
        tls_config.reject_unauthorized = 0;
        tls_config.request_cert = 1;

        // We create it right here so we can throw errors early.
        const context_options = tls_config.asUSockets();
        var err: uws.create_bun_socket_error_t = .none;
        tls_ctx = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*@This()), context_options, &err) orelse {
            // BUGFIX: these two branches were previously swapped — when
            // uSockets reported a specific error we threw the generic
            // message, and when it reported none we called `err.toJS` on
            // `.none`. Surface the specific error when we have one.
            if (err != .none) {
                return globalObject.throwValue(err.toJS(globalObject));
            }
            return globalObject.throw("failed to create TLS context", .{});
        };

        // restore the original reject_unauthorized
        tls_config.reject_unauthorized = original_reject_unauthorized;
        if (err != .none) {
            tls_config.deinit();
            if (tls_ctx) |ctx| {
                ctx.deinit(true);
            }
            return globalObject.throwValue(err.toJS(globalObject));
        }

        uws.NewSocketHandler(true).configure(tls_ctx.?, true, *@This(), SocketHandler(true));
    }

    var username: []const u8 = "";
    var password: []const u8 = "";
    var database: []const u8 = "";
    var options: []const u8 = "";
    var path: []const u8 = "";

    const options_str = try arguments[7].toBunString(globalObject);
    defer options_str.deref();

    const path_str = try arguments[8].toBunString(globalObject);
    defer path_str.deref();

    // Pack every connection string into a single owned buffer; the slices
    // below (username/password/database/options/path) point into it.
    const options_buf: []u8 = brk: {
        var b = bun.StringBuilder{};
        b.cap += username_str.utf8ByteLength() + 1 + password_str.utf8ByteLength() + 1 + database_str.utf8ByteLength() + 1 + options_str.utf8ByteLength() + 1 + path_str.utf8ByteLength() + 1;

        // BUGFIX: allocation failure was silently swallowed (`catch {}`),
        // which would let the `append` calls below run against an
        // unallocated builder. Treat OOM the way the rest of this file does.
        b.allocate(bun.default_allocator) catch bun.outOfMemory();
        var u = username_str.toUTF8WithoutRef(bun.default_allocator);
        defer u.deinit();
        username = b.append(u.slice());

        var p = password_str.toUTF8WithoutRef(bun.default_allocator);
        defer p.deinit();
        password = b.append(p.slice());

        var d = database_str.toUTF8WithoutRef(bun.default_allocator);
        defer d.deinit();
        database = b.append(d.slice());

        var o = options_str.toUTF8WithoutRef(bun.default_allocator);
        defer o.deinit();
        options = b.append(o.slice());

        var _path = path_str.toUTF8WithoutRef(bun.default_allocator);
        defer _path.deinit();
        path = b.append(_path.slice());

        break :brk b.allocatedSlice();
    };

    const on_connect = arguments[9];
    const on_close = arguments[10];
    const idle_timeout = arguments[11].toInt32();
    const connection_timeout = arguments[12].toInt32();
    const max_lifetime = arguments[13].toInt32();
    const use_unnamed_prepared_statements = arguments[14].asBoolean();

    var ptr = try bun.default_allocator.create(MySQLConnection);

    ptr.* = MySQLConnection{
        .globalObject = globalObject,
        .vm = vm,
        .database = database,
        .user = username,
        .password = password,
        .options = options,
        .options_buf = options_buf,
        .socket = .{ .SocketTCP = .{ .socket = .{ .detached = {} } } },
        .requests = Queue.init(bun.default_allocator),
        .statements = PreparedStatementsMap{},
        .tls_config = tls_config,
        .tls_ctx = tls_ctx,
        .ssl_mode = ssl_mode,
        .tls_status = if (ssl_mode != .disable) .pending else .none,
        .idle_timeout_interval_ms = @intCast(idle_timeout),
        .connection_timeout_ms = @intCast(connection_timeout),
        .max_lifetime_interval_ms = @intCast(max_lifetime),
        .character_set = CharacterSet.default,
        .flags = .{
            .use_unnamed_prepared_statements = use_unnamed_prepared_statements,
        },
    };

    {
        const hostname = hostname_str.toUTF8(bun.default_allocator);
        defer hostname.deinit();

        // Lazily create (and cache in rareData) the shared non-SSL context.
        const ctx = vm.rareData().mysql_context.tcp orelse brk: {
            const ctx_ = uws.SocketContext.createNoSSLContext(vm.uwsLoop(), @sizeOf(*@This())).?;
            uws.NewSocketHandler(false).configure(ctx_, true, *@This(), SocketHandler(false));
            vm.rareData().mysql_context.tcp = ctx_;
            break :brk ctx_;
        };

        if (path.len > 0) {
            ptr.socket = .{
                .SocketTCP = uws.SocketTCP.connectUnixAnon(path, ctx, ptr, false) catch |err| {
                    tls_config.deinit();
                    if (tls_ctx) |tls| {
                        tls.deinit(true);
                    }
                    ptr.deinit();
                    // BUGFIX: message previously said "postgresql" (copy-paste
                    // from the Postgres implementation).
                    return globalObject.throwError(err, "failed to connect to mysql");
                },
            };
        } else {
            ptr.socket = .{
                .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| {
tls_config.deinit();
                    if (tls_ctx) |tls| {
                        tls.deinit(true);
                    }
                    ptr.deinit();
                    return globalObject.throwError(err, "failed to connect to mysql");
                },
            };
        }
    }
    ptr.setStatus(.connecting);
    ptr.updateHasPendingActivity();
    ptr.resetConnectionTimeout();
    ptr.poll_ref.ref(vm);
    const js_value = ptr.toJS(globalObject);
    js_value.ensureStillAlive();
    ptr.js_value = js_value;
    js.onconnectSetCached(js_value, globalObject, on_connect);
    js.oncloseSetCached(js_value, globalObject, on_close);

    return js_value;
}

/// Final teardown: rejects queued requests, frees all buffers/maps/TLS state,
/// and destroys the allocation.
pub fn deinit(this: *MySQLConnection) void {
    this.disconnect();
    this.stopTimers();
    debug("MySQLConnection deinit", .{});

    var requests = this.requests;
    defer requests.deinit();
    this.requests = Queue.init(bun.default_allocator);

    // Clear any pending requests first
    for (requests.readableSlice(0)) |request| {
        this.finishRequest(request);
        request.onError(.{
            .error_code = 2013,
            .error_message = .{ .temporary = "Connection closed" },
        }, this.globalObject);
    }
    this.write_buffer.deinit(bun.default_allocator);
    this.read_buffer.deinit(bun.default_allocator);
    this.statements.deinit(bun.default_allocator);
    bun.default_allocator.free(this.auth_data);
    this.auth_data = "";
    this.tls_config.deinit();
    if (this.tls_ctx) |ctx| {
        ctx.deinit(true);
    }
    bun.default_allocator.free(this.options_buf);
    bun.default_allocator.destroy(this);
}

/// TCP/unix socket is open: store it and move to the handshaking state
/// (MySQL servers speak first, so we wait for the server greeting).
pub fn onOpen(this: *MySQLConnection, socket: Socket) void {
    this.setupMaxLifetimeTimerIfNecessary();
    this.resetConnectionTimeout();
    this.socket = socket;
    this.setStatus(.handshaking);
    this.poll_ref.ref(this.vm);
    this.updateHasPendingActivity();
}

/// TLS handshake finished. Verifies the certificate/hostname when the
/// ssl_mode requires it; a server-side rejection always fails the connection.
pub fn onHandshake(this: *MySQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
    debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no });
    const handshake_success = if (success == 1) true else false;
    if (handshake_success) {
        if (this.tls_config.reject_unauthorized != 0) {
            // only reject the connection if reject_unauthorized == true
            switch (this.ssl_mode) {
                // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279

                .verify_ca, .verify_full => {
                    if (ssl_error.error_no != 0) {
                        this.failWithJSValue(ssl_error.toJS(this.globalObject));
                        return;
                    }

                    const ssl_ptr: *BoringSSL.c.SSL = @ptrCast(this.socket.getNativeHandle());
                    if (BoringSSL.c.SSL_get_servername(ssl_ptr, 0)) |servername| {
                        const hostname = servername[0..bun.len(servername)];
                        if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) {
                            this.failWithJSValue(ssl_error.toJS(this.globalObject));
                        }
                    }
                },
                else => {
                    return;
                },
            }
        }
    } else {
        // if we got here it is because the server rejected us, and error_no is the cause;
        // reject_unauthorized does not matter because we were disconnected by the server
        this.failWithJSValue(ssl_error.toJS(this.globalObject));
    }
}

/// Socket data callback: parses MySQL packets, buffering partial reads in
/// `read_buffer` when a packet is split across TCP segments.
pub fn onData(this: *MySQLConnection, data: []const u8) void {
    this.ref();
    this.flags.is_processing_data = true;
    const vm = this.vm;
    // Clear the timeout.
    this.socket.setTimeout(0);

    defer {
        if (this.status == .connected and this.requests.readableLength() == 0 and this.write_buffer.remaining().len == 0) {
            // Don't keep the process alive when there's nothing to do.
            this.poll_ref.unref(vm);
        } else if (this.status == .connected) {
            // Keep the process alive if there's something to do.
+ this.poll_ref.ref(vm); + } + // reset the connection timeout after we're done processing the data + this.flags.is_processing_data = false; + this.resetConnectionTimeout(); + this.deref(); + } + + const event_loop = vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + + SocketMonitor.read(data); + + if (this.read_buffer.remaining().len == 0) { + var consumed: usize = 0; + var offset: usize = 0; + const reader = StackReader.init(data, &consumed, &offset); + this.processPackets(StackReader, reader) catch |err| { + debug("processPackets without buffer: {s}", .{@errorName(err)}); + if (err == error.ShortRead) { + if (comptime bun.Environment.allow_assert) { + debug("Received short read: last_message_start: {d}, head: {d}, len: {d}", .{ + offset, + consumed, + data.len, + }); + } + + this.read_buffer.head = 0; + this.last_message_start = 0; + this.read_buffer.byte_list.len = 0; + this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer"); + } else { + if (comptime bun.Environment.allow_assert) { + bun.handleErrorReturnTrace(err, @errorReturnTrace()); + } + this.fail("Failed to read data", err); + } + }; + return; + } + + { + this.read_buffer.head = this.last_message_start; + + this.read_buffer.write(bun.default_allocator, data) catch @panic("failed to write to read buffer"); + this.processPackets(Reader, this.bufferedReader()) catch |err| { + debug("processPackets with buffer: {s}", .{@errorName(err)}); + if (err != error.ShortRead) { + if (comptime bun.Environment.allow_assert) { + if (@errorReturnTrace()) |trace| { + debug("Error: {s}\n{}", .{ @errorName(err), trace }); + } + } + this.fail("Failed to read data", err); + return; + } + + if (comptime bun.Environment.allow_assert) { + debug("Received short read: last_message_start: {d}, head: {d}, len: {d}", .{ + this.last_message_start, + this.read_buffer.head, + this.read_buffer.byte_list.len, + }); + } + + return; + }; + + this.last_message_start = 0; + 
this.read_buffer.head = 0; + } +} + +pub fn processPackets(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context)) AnyMySQLError.Error!void { + while (true) { + reader.markMessageStart(); + + // Read packet header + const header = PacketHeader.decode(reader.peek()) orelse return AnyMySQLError.Error.ShortRead; + const header_length = header.length; + debug("sequence_id: {d} header: {d}", .{ this.sequence_id, header_length }); + // Ensure we have the full packet + reader.ensureCapacity(header_length + PacketHeader.size) catch return AnyMySQLError.Error.ShortRead; + // always skip the full packet, we dont care about padding or unreaded bytes + defer reader.setOffsetFromStart(header_length + PacketHeader.size); + reader.skip(PacketHeader.size); + + // Update sequence id + this.sequence_id = header.sequence_id +% 1; + + // Process packet based on connection state + switch (this.status) { + .handshaking => try this.handleHandshake(Context, reader), + .authenticating, .authentication_awaiting_pk => try this.handleAuth(Context, reader, header_length), + .connected => try this.handleCommand(Context, reader, header_length), + else => { + debug("Unexpected packet in state {s}", .{@tagName(this.status)}); + return error.UnexpectedPacket; + }, + } + } +} + +pub fn handleHandshake(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context)) AnyMySQLError.Error!void { + var handshake = HandshakeV10{}; + try handshake.decode(reader); + defer handshake.deinit(); + + // Store server info + this.server_version = try handshake.server_version.toOwned(); + this.connection_id = handshake.connection_id; + // this.capabilities = handshake.capability_flags; + this.capabilities = Capabilities.getDefaultCapabilities(this.ssl_mode != .disable, this.database.len > 0); + + // Override with utf8mb4 instead of using server's default + this.character_set = CharacterSet.default; + this.status_flags = handshake.status_flags; + + debug( + \\Handshake + \\ Server 
Version: {s} + \\ Connection ID: {d} + \\ Character Set: {d} ({s}) + \\ Server Capabilities: [ {} ] 0x{x:0>8} + \\ Status Flags: [ {} ] + \\ + , .{ + this.server_version.slice(), + this.connection_id, + this.character_set, + this.character_set.label(), + this.capabilities, + this.capabilities.toInt(), + this.status_flags, + }); + + if (this.auth_data.len > 0) { + bun.default_allocator.free(this.auth_data); + this.auth_data = ""; + } + + // Store auth data + const auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len); + @memcpy(auth_data[0..8], &handshake.auth_plugin_data_part_1); + @memcpy(auth_data[8..], handshake.auth_plugin_data_part_2); + this.auth_data = auth_data; + + // Get auth plugin + if (handshake.auth_plugin_name.slice().len > 0) { + this.auth_plugin = AuthMethod.fromString(handshake.auth_plugin_name.slice()) orelse { + this.fail("Unsupported auth plugin", error.UnsupportedAuthPlugin); + return; + }; + } + + // Update status + this.setStatus(.authenticating); + + // Send auth response + try this.sendHandshakeResponse(); +} + +fn handleHandshakeDecodePublicKey(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context)) !void { + var response = Auth.caching_sha2_password.PublicKeyResponse{}; + try response.decode(reader); + defer response.deinit(); + // revert back to authenticating since we received the public key + this.setStatus(.authenticating); + + var encrypted_password = Auth.caching_sha2_password.EncryptedPassword{ + .password = this.password, + .public_key = response.data.slice(), + .nonce = this.auth_data, + .sequence_id = this.sequence_id, + }; + try encrypted_password.write(this.writer()); + this.flushData(); +} + +pub fn consumeOnConnectCallback(this: *const @This(), globalObject: *jsc.JSGlobalObject) ?jsc.JSValue { + debug("consumeOnConnectCallback", .{}); + const on_connect = js.onconnectGetCached(this.js_value) orelse return null; + 
debug("consumeOnConnectCallback exists", .{}); + + js.onconnectSetCached(this.js_value, globalObject, .zero); + return on_connect; +} + +pub fn consumeOnCloseCallback(this: *const @This(), globalObject: *jsc.JSGlobalObject) ?jsc.JSValue { + debug("consumeOnCloseCallback", .{}); + const on_close = js.oncloseGetCached(this.js_value) orelse return null; + debug("consumeOnCloseCallback exists", .{}); + js.oncloseSetCached(this.js_value, globalObject, .zero); + return on_close; +} + +pub fn setStatus(this: *@This(), status: ConnectionState) void { + if (this.status == status) return; + defer this.updateHasPendingActivity(); + + this.status = status; + this.resetConnectionTimeout(); + if (this.vm.isShuttingDown()) return; + + switch (status) { + .connected => { + const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return; + const js_value = this.js_value; + js_value.ensureStillAlive(); + this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ JSValue.jsNull(), js_value }); + this.poll_ref.unref(this.vm); + }, + else => {}, + } +} + +pub fn updateRef(this: *@This()) void { + this.updateHasPendingActivity(); + if (this.pending_activity_count.raw > 0) { + this.poll_ref.ref(this.vm); + } else { + this.poll_ref.unref(this.vm); + } +} +pub fn handleAuth(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context), header_length: u24) !void { + const first_byte = try reader.int(u8); + reader.skip(-1); + + debug("Auth packet: 0x{x:0>2}", .{first_byte}); + + switch (first_byte) { + @intFromEnum(PacketType.OK) => { + var ok = OKPacket{ + .packet_size = header_length, + }; + try ok.decode(reader); + defer ok.deinit(); + + this.setStatus(.connected); + defer this.updateRef(); + this.status_flags = ok.status_flags; + this.flags.is_ready_for_query = true; + this.advance(); + + this.registerAutoFlusher(); + }, + + @intFromEnum(PacketType.ERROR) => { + var err = ErrorPacket{}; + try err.decode(reader); + defer err.deinit(); + + 
this.failWithJSValue(err.toJS(this.globalObject)); + return error.AuthenticationFailed; + }, + + @intFromEnum(PacketType.MORE_DATA) => { + // Handle various MORE_DATA cases + if (this.auth_plugin) |plugin| { + switch (plugin) { + .caching_sha2_password => { + reader.skip(1); + + if (this.status == .authentication_awaiting_pk) { + return this.handleHandshakeDecodePublicKey(Context, reader); + } + + var response = Auth.caching_sha2_password.Response{}; + try response.decode(reader); + defer response.deinit(); + + switch (response.status) { + .success => { + debug("success", .{}); + this.setStatus(.connected); + defer this.updateRef(); + this.flags.is_ready_for_query = true; + this.advance(); + this.registerAutoFlusher(); + }, + .continue_auth => { + debug("continue auth", .{}); + + if (this.ssl_mode == .disable) { + // we are in plain TCP so we need to request the public key + this.setStatus(.authentication_awaiting_pk); + var packet = try this.writer().start(this.sequence_id); + + var request = Auth.caching_sha2_password.PublicKeyRequest{}; + try request.write(this.writer()); + try packet.end(); + this.flushData(); + } else { + // SSL mode is enabled, send password as is + var packet = try this.writer().start(this.sequence_id); + try this.writer().write(this.password); + try packet.end(); + this.flushData(); + } + }, + else => { + this.fail("Authentication failed", error.AuthenticationFailed); + }, + } + }, + else => { + debug("Unexpected auth continuation for plugin: {s}", .{@tagName(plugin)}); + return error.UnexpectedPacket; + }, + } + } else if (first_byte == @intFromEnum(PacketType.LOCAL_INFILE)) { + // Handle LOCAL INFILE request + var infile = LocalInfileRequest{ + .packet_size = header_length, + }; + try infile.decode(reader); + defer infile.deinit(); + + // We don't support LOCAL INFILE for security reasons + this.fail("LOCAL INFILE not supported", error.LocalInfileNotSupported); + return; + } else { + debug("Received auth continuation without plugin", 
.{}); + return error.UnexpectedPacket; + } + }, + + PacketType.AUTH_SWITCH => { + var auth_switch = AuthSwitchRequest{ + .packet_size = header_length, + }; + try auth_switch.decode(reader); + defer auth_switch.deinit(); + + // Update auth plugin and data + const auth_method = AuthMethod.fromString(auth_switch.plugin_name.slice()) orelse { + this.fail("Unsupported auth plugin", error.UnsupportedAuthPlugin); + return; + }; + + // Send new auth response + try this.sendAuthSwitchResponse(auth_method, auth_switch.plugin_data.slice()); + }, + + else => { + debug("Unexpected auth packet: 0x{x:0>2}", .{first_byte}); + return error.UnexpectedPacket; + }, + } +} + +pub fn handleCommand(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context), header_length: u24) !void { + // Get the current request if any + const request = this.current() orelse { + debug("Received unexpected command response", .{}); + return error.UnexpectedPacket; + }; + + debug("handleCommand", .{}); + if (request.flags.simple) { + // Regular query response + return try this.handleResultSet(Context, reader, header_length); + } + + // Handle based on request type + if (request.statement) |statement| { + switch (statement.status) { + .pending => { + return error.UnexpectedPacket; + }, + .parsing => { + // We're waiting for prepare response + try this.handlePreparedStatement(Context, reader, header_length); + }, + .prepared => { + // We're waiting for execute response + try this.handleResultSet(Context, reader, header_length); + }, + .failed => { + defer { + this.advance(); + this.registerAutoFlusher(); + } + this.flags.is_ready_for_query = true; + this.finishRequest(request); + // Statement failed, clean up + request.onError(statement.error_response, this.globalObject); + }, + } + } +} + +pub fn sendHandshakeResponse(this: *MySQLConnection) AnyMySQLError.Error!void { + // Only require password for caching_sha2_password when connecting for the first time + if (this.auth_plugin) |plugin| { + 
const requires_password = switch (plugin) { + .caching_sha2_password => false, // Allow empty password, server will handle auth flow + .sha256_password => true, // Always requires password + .mysql_native_password => false, // Allows empty password + }; + + if (requires_password and this.password.len == 0) { + this.fail("Password required for authentication", error.PasswordRequired); + return; + } + } + + var response = HandshakeResponse41{ + .capability_flags = this.capabilities, + .max_packet_size = 0, //16777216, + .character_set = CharacterSet.default, + .username = .{ .temporary = this.user }, + .database = .{ .temporary = this.database }, + .auth_plugin_name = .{ + .temporary = if (this.auth_plugin) |plugin| + switch (plugin) { + .mysql_native_password => "mysql_native_password", + .caching_sha2_password => "caching_sha2_password", + .sha256_password => "sha256_password", + } + else + "", + }, + .auth_response = .{ .empty = {} }, + }; + defer response.deinit(); + + // Add some basic connect attributes like mysql2 + try response.connect_attrs.put(bun.default_allocator, try bun.default_allocator.dupe(u8, "_client_name"), try bun.default_allocator.dupe(u8, "Bun")); + try response.connect_attrs.put(bun.default_allocator, try bun.default_allocator.dupe(u8, "_client_version"), try bun.default_allocator.dupe(u8, bun.Global.package_json_version_with_revision)); + + // Generate auth response based on plugin + var scrambled_buf: [32]u8 = undefined; + if (this.auth_plugin) |plugin| { + if (this.auth_data.len == 0) { + this.fail("Missing auth data from server", error.MissingAuthData); + return; + } + + response.auth_response = .{ .temporary = try plugin.scramble(this.password, this.auth_data, &scrambled_buf) }; + } + response.capability_flags.reject(); + try response.write(this.writer()); + this.capabilities = response.capability_flags; + this.flushData(); +} + +pub fn sendAuthSwitchResponse(this: *MySQLConnection, auth_method: AuthMethod, plugin_data: []const u8) !void 
{ + var response = AuthSwitchResponse{}; + defer response.deinit(); + + var scrambled_buf: [32]u8 = undefined; + + response.auth_response = .{ + .temporary = try auth_method.scramble(this.password, plugin_data, &scrambled_buf), + }; + + try response.write(this.writer()); + this.flushData(); +} + +pub const Writer = struct { + connection: *MySQLConnection, + + pub fn write(this: Writer, data: []const u8) AnyMySQLError.Error!void { + var buffer = &this.connection.write_buffer; + try buffer.write(bun.default_allocator, data); + } + + pub fn pwrite(this: Writer, data: []const u8, index: usize) AnyMySQLError.Error!void { + @memcpy(this.connection.write_buffer.byte_list.slice()[index..][0..data.len], data); + } + + pub fn offset(this: Writer) usize { + return this.connection.write_buffer.len(); + } +}; + +pub fn writer(this: *MySQLConnection) NewWriter(Writer) { + return .{ + .wrapped = .{ + .connection = this, + }, + }; +} + +pub const Reader = struct { + connection: *MySQLConnection, + + pub fn markMessageStart(this: Reader) void { + this.connection.last_message_start = this.connection.read_buffer.head; + } + + pub fn setOffsetFromStart(this: Reader, offset: usize) void { + this.connection.read_buffer.head = this.connection.last_message_start + @as(u32, @truncate(offset)); + } + + pub const ensureLength = ensureCapacity; + + pub fn peek(this: Reader) []const u8 { + return this.connection.read_buffer.remaining(); + } + + pub fn skip(this: Reader, count: isize) void { + if (count < 0) { + const abs_count = @abs(count); + if (abs_count > this.connection.read_buffer.head) { + this.connection.read_buffer.head = 0; + return; + } + this.connection.read_buffer.head -= @intCast(abs_count); + return; + } + + const ucount: usize = @intCast(count); + if (this.connection.read_buffer.head + ucount > this.connection.read_buffer.byte_list.len) { + this.connection.read_buffer.head = this.connection.read_buffer.byte_list.len; + return; + } + + this.connection.read_buffer.head += 
@intCast(ucount); + } + + pub fn ensureCapacity(this: Reader, count: usize) bool { + return this.connection.read_buffer.remaining().len >= count; + } + + pub fn read(this: Reader, count: usize) AnyMySQLError.Error!Data { + const remaining = this.peek(); + if (remaining.len < count) { + return AnyMySQLError.Error.ShortRead; + } + + this.skip(@intCast(count)); + return Data{ + .temporary = remaining[0..count], + }; + } + + pub fn readZ(this: Reader) AnyMySQLError.Error!Data { + const remaining = this.peek(); + if (bun.strings.indexOfChar(remaining, 0)) |zero| { + this.skip(@intCast(zero + 1)); + return Data{ + .temporary = remaining[0..zero], + }; + } + + return error.ShortRead; + } +}; + +pub fn bufferedReader(this: *MySQLConnection) NewReader(Reader) { + return .{ + .wrapped = .{ + .connection = this, + }, + }; +} + +fn checkIfPreparedStatementIsDone(this: *MySQLConnection, statement: *MySQLStatement) void { + debug("checkIfPreparedStatementIsDone: {d} {d} {d} {d}", .{ statement.columns_received, statement.params_received, statement.columns.len, statement.params.len }); + if (statement.columns_received == statement.columns.len and statement.params_received == statement.params.len) { + statement.status = .prepared; + this.flags.waiting_to_prepare = false; + this.flags.is_ready_for_query = true; + statement.reset(); + this.advance(); + this.registerAutoFlusher(); + } +} + +pub fn handlePreparedStatement(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context), header_length: u24) !void { + debug("handlePreparedStatement", .{}); + const first_byte = try reader.int(u8); + reader.skip(-1); + + const request = this.current() orelse { + debug("Unexpected prepared statement packet missing request", .{}); + return error.UnexpectedPacket; + }; + const statement = request.statement orelse { + debug("Unexpected prepared statement packet missing statement", .{}); + return error.UnexpectedPacket; + }; + if (statement.statement_id > 0) { + if 
(statement.params_received < statement.params.len) { + var column = ColumnDefinition41{}; + defer column.deinit(); + try column.decode(reader); + statement.params[statement.params_received] = .{ + .type = column.column_type, + .flags = column.flags, + }; + statement.params_received += 1; + } else if (statement.columns_received < statement.columns.len) { + try statement.columns[statement.columns_received].decode(reader); + statement.columns_received += 1; + } + this.checkIfPreparedStatementIsDone(statement); + return; + } + + switch (@as(PacketType, @enumFromInt(first_byte))) { + .OK => { + var ok = StmtPrepareOKPacket{ + .packet_length = header_length, + }; + try ok.decode(reader); + + // Get the current request + + statement.statement_id = ok.statement_id; + + // Read parameter definitions if any + if (ok.num_params > 0) { + statement.params = try bun.default_allocator.alloc(MySQLStatement.Param, ok.num_params); + statement.params_received = 0; + } + + // Read column definitions if any + if (ok.num_columns > 0) { + statement.columns = try bun.default_allocator.alloc(ColumnDefinition41, ok.num_columns); + statement.columns_received = 0; + } + + this.checkIfPreparedStatementIsDone(statement); + }, + + .ERROR => { + var err = ErrorPacket{}; + try err.decode(reader); + defer err.deinit(); + defer { + this.advance(); + this.registerAutoFlusher(); + } + this.flags.is_ready_for_query = true; + this.finishRequest(request); + statement.status = .failed; + statement.error_response = err; + request.onError(err, this.globalObject); + }, + + else => { + debug("Unexpected prepared statement packet: 0x{x:0>2}", .{first_byte}); + return error.UnexpectedPacket; + }, + } +} + +fn handleResultSetOK(this: *MySQLConnection, request: *MySQLQuery, statement: *MySQLStatement, status_flags: StatusFlags) void { + this.status_flags = status_flags; + this.flags.is_ready_for_query = !status_flags.has(.SERVER_MORE_RESULTS_EXISTS); + debug("handleResultSetOK: {d} {}", .{ status_flags.toInt(), 
status_flags.has(.SERVER_MORE_RESULTS_EXISTS) }); + defer { + this.advance(); + this.registerAutoFlusher(); + } + if (this.flags.is_ready_for_query) { + this.finishRequest(request); + } + request.onResult(statement.result_count, this.globalObject, this.js_value, this.flags.is_ready_for_query); + statement.reset(); +} + +pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: NewReader(Context), header_length: u24) !void { + const first_byte = try reader.int(u8); + debug("handleResultSet: {x:0>2}", .{first_byte}); + + reader.skip(-1); + + var request = this.current() orelse { + debug("Unexpected result set packet", .{}); + return error.UnexpectedPacket; + }; + var ok = OKPacket{ + .packet_size = header_length, + }; + switch (@as(PacketType, @enumFromInt(first_byte))) { + .ERROR => { + var err = ErrorPacket{}; + try err.decode(reader); + defer err.deinit(); + defer { + this.advance(); + this.registerAutoFlusher(); + } + if (request.statement) |statement| { + statement.reset(); + } + + this.flags.is_ready_for_query = true; + this.finishRequest(request); + request.onError(err, this.globalObject); + }, + + else => |packet_type| { + const statement = request.statement orelse { + debug("Unexpected result set packet", .{}); + return error.UnexpectedPacket; + }; + if (!statement.execution_flags.header_received) { + if (packet_type == .OK) { + // if packet type is OK it means the query is done and no results are returned + try ok.decode(reader); + defer ok.deinit(); + this.handleResultSetOK(request, statement, ok.status_flags); + return; + } + + var header = ResultSetHeader{}; + try header.decode(reader); + if (header.field_count == 0) { + // Can't be 0 + return error.UnexpectedPacket; + } + if (statement.columns.len != header.field_count) { + debug("header field count mismatch: {d} != {d}", .{ statement.columns.len, header.field_count }); + statement.cached_structure.deinit(); + statement.cached_structure = .{}; + if (statement.columns.len > 0) { + 
for (statement.columns) |*column| { + column.deinit(); + } + bun.default_allocator.free(statement.columns); + } + statement.columns = try bun.default_allocator.alloc(ColumnDefinition41, header.field_count); + statement.columns_received = 0; + } + statement.execution_flags.needs_duplicate_check = true; + statement.execution_flags.header_received = true; + return; + } else if (statement.columns_received < statement.columns.len) { + try statement.columns[statement.columns_received].decode(reader); + statement.columns_received += 1; + } else { + if (packet_type == .OK or packet_type == .EOF) { + if (request.flags.simple) { + // if we are using the text protocol for sure this is a OK packet otherwise will be OK packet with 0xFE code + try ok.decode(reader); + defer ok.deinit(); + + this.handleResultSetOK(request, statement, ok.status_flags); + return; + } else if (packet_type == .EOF) { + // this is actually a OK packet but with the flag EOF + try ok.decode(reader); + defer ok.deinit(); + this.handleResultSetOK(request, statement, ok.status_flags); + return; + } + } + + var stack_fallback = std.heap.stackFallback(4096, bun.default_allocator); + const allocator = stack_fallback.get(); + var row = ResultSet.Row{ + .globalObject = this.globalObject, + .columns = statement.columns, + .binary = request.flags.binary, + .raw = request.flags.result_mode == .raw, + .bigint = request.flags.bigint, + }; + var structure: JSValue = .js_undefined; + var cached_structure: ?CachedStructure = null; + switch (request.flags.result_mode) { + .objects => { + cached_structure = statement.structure(this.js_value, this.globalObject); + structure = cached_structure.?.jsValue() orelse .js_undefined; + }, + .raw, .values => { + // no need to check for duplicate fields or structure + }, + } + defer row.deinit(allocator); + try row.decode(allocator, reader); + + const pending_value = MySQLQuery.js.pendingValueGetCached(request.thisValue.get()) orelse .zero; + + // Process row data + const row_value 
= row.toJS( + this.globalObject, + pending_value, + structure, + statement.fields_flags, + request.flags.result_mode, + cached_structure, + ); + if (this.globalObject.tryTakeException()) |err| { + this.finishRequest(request); + request.onJSError(err, this.globalObject); + return error.JSError; + } + statement.result_count += 1; + + if (pending_value == .zero) { + MySQLQuery.js.pendingValueSetCached(request.thisValue.get(), this.globalObject, row_value); + } + } + }, + } +} + +fn close(this: *@This()) void { + this.disconnect(); + this.unregisterAutoFlusher(); + this.write_buffer.deinit(bun.default_allocator); +} + +pub fn closeStatement(this: *MySQLConnection, statement: *MySQLStatement) !void { + var _close = PreparedStatement.Close{ + .statement_id = statement.statement_id, + }; + + try _close.write(this.writer()); + this.flushData(); + this.registerAutoFlusher(); +} + +pub fn resetStatement(this: *MySQLConnection, statement: *MySQLStatement) !void { + var reset = PreparedStatement.Reset{ + .statement_id = statement.statement_id, + }; + + try reset.write(this.writer()); + this.flushData(); + this.registerAutoFlusher(); +} + +pub fn getQueries(_: *@This(), thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject) bun.JSError!jsc.JSValue { + if (js.queriesGetCached(thisValue)) |value| { + return value; + } + + const array = try jsc.JSValue.createEmptyArray(globalObject, 0); + js.queriesSetCached(thisValue, globalObject, array); + + return array; +} + +pub fn getOnConnect(_: *@This(), thisValue: jsc.JSValue, _: *jsc.JSGlobalObject) jsc.JSValue { + if (js.onconnectGetCached(thisValue)) |value| { + return value; + } + + return .js_undefined; +} + +pub fn setOnConnect(_: *@This(), thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: jsc.JSValue) void { + js.onconnectSetCached(thisValue, globalObject, value); +} + +pub fn getOnClose(_: *@This(), thisValue: jsc.JSValue, _: *jsc.JSGlobalObject) jsc.JSValue { + if (js.oncloseGetCached(thisValue)) |value| { + 
return value; + } + + return .js_undefined; +} + +pub fn setOnClose(_: *@This(), thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: jsc.JSValue) void { + js.oncloseSetCached(thisValue, globalObject, value); +} + +pub const js = jsc.Codegen.JSMySQLConnection; +pub const fromJS = js.fromJS; +pub const fromJSDirect = js.fromJSDirect; +pub const toJS = js.toJS; +const MAX_PIPELINE_SIZE = std.math.maxInt(u16); // about 64KB per connection + +const PreparedStatementsMap = std.HashMapUnmanaged(u64, *MySQLStatement, bun.IdentityContext(u64), 80); +const debug = bun.Output.scoped(.MySQLConnection, .visible); +const RefCount = bun.ptr.RefCount(@This(), "ref_count", deinit, .{}); +const Queue = std.fifo.LinearFifo(*MySQLQuery, .Dynamic); + +const AnyMySQLError = @import("./protocol/AnyMySQLError.zig"); +const Auth = @import("./protocol/Auth.zig"); +const AuthSwitchRequest = @import("./protocol/AuthSwitchRequest.zig"); +const AuthSwitchResponse = @import("./protocol/AuthSwitchResponse.zig"); +const CachedStructure = @import("../shared/CachedStructure.zig"); +const Capabilities = @import("./Capabilities.zig"); +const ColumnDefinition41 = @import("./protocol/ColumnDefinition41.zig"); +const ErrorPacket = @import("./protocol/ErrorPacket.zig"); +const HandshakeResponse41 = @import("./protocol/HandshakeResponse41.zig"); +const HandshakeV10 = @import("./protocol/HandshakeV10.zig"); +const LocalInfileRequest = @import("./protocol/LocalInfileRequest.zig"); +const MySQLQuery = @import("./MySQLQuery.zig"); +const MySQLRequest = @import("./MySQLRequest.zig"); +const MySQLStatement = @import("./MySQLStatement.zig"); +const OKPacket = @import("./protocol/OKPacket.zig"); +const PacketHeader = @import("./protocol/PacketHeader.zig"); +const PreparedStatement = @import("./protocol/PreparedStatement.zig"); +const ResultSet = @import("./protocol/ResultSet.zig"); +const ResultSetHeader = @import("./protocol/ResultSetHeader.zig"); +const SocketMonitor = 
@import("../postgres/SocketMonitor.zig"); +const StackReader = @import("./protocol/StackReader.zig"); +const StmtPrepareOKPacket = @import("./protocol/StmtPrepareOKPacket.zig"); +const std = @import("std"); +const AuthMethod = @import("./AuthMethod.zig").AuthMethod; +const CharacterSet = @import("./protocol/CharacterSet.zig").CharacterSet; +const ConnectionFlags = @import("../shared/ConnectionFlags.zig").ConnectionFlags; +const ConnectionState = @import("./ConnectionState.zig").ConnectionState; +const Data = @import("../shared/Data.zig").Data; +const NewReader = @import("./protocol/NewReader.zig").NewReader; +const NewWriter = @import("./protocol/NewWriter.zig").NewWriter; +const PacketType = @import("./protocol/PacketType.zig").PacketType; +const SSLMode = @import("./SSLMode.zig").SSLMode; +const StatusFlags = @import("./StatusFlags.zig").StatusFlags; +const TLSStatus = @import("./TLSStatus.zig").TLSStatus; + +const bun = @import("bun"); +const BoringSSL = bun.BoringSSL; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; +const AutoFlusher = jsc.WebCore.AutoFlusher; + +const uws = bun.uws; +const Socket = uws.AnySocket; diff --git a/src/sql/mysql/MySQLContext.zig b/src/sql/mysql/MySQLContext.zig new file mode 100644 index 0000000000..fa80904c5a --- /dev/null +++ b/src/sql/mysql/MySQLContext.zig @@ -0,0 +1,22 @@ +tcp: ?*uws.SocketContext = null, + +onQueryResolveFn: JSC.Strong.Optional = .empty, +onQueryRejectFn: JSC.Strong.Optional = .empty, + +pub fn init(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + var ctx = &globalObject.bunVM().rareData().mysql_context; + ctx.onQueryResolveFn.set(globalObject, callframe.argument(0)); + ctx.onQueryRejectFn.set(globalObject, callframe.argument(1)); + + return .js_undefined; +} + +comptime { + @export(&JSC.toJSHostFn(init), .{ .name = "MySQLContext__init" }); +} + +const bun = @import("bun"); +const uws = bun.uws; + +const JSC = bun.jsc; +const JSValue = JSC.JSValue; diff --git 
a/src/sql/mysql/MySQLQuery.zig b/src/sql/mysql/MySQLQuery.zig new file mode 100644 index 0000000000..292922afd1 --- /dev/null +++ b/src/sql/mysql/MySQLQuery.zig @@ -0,0 +1,545 @@ +const MySQLQuery = @This(); +const RefCount = bun.ptr.ThreadSafeRefCount(@This(), "ref_count", deinit, .{}); + +statement: ?*MySQLStatement = null, +query: bun.String = bun.String.empty, +cursor_name: bun.String = bun.String.empty, +thisValue: JSRef = JSRef.empty(), + +status: Status = Status.pending, + +ref_count: RefCount = RefCount.init(), + +flags: packed struct(u8) { + is_done: bool = false, + binary: bool = false, + bigint: bool = false, + simple: bool = false, + pipelined: bool = false, + result_mode: SQLQueryResultMode = .objects, + _padding: u1 = 0, +} = .{}, + +pub const ref = RefCount.ref; +pub const deref = RefCount.deref; + +pub const Status = enum(u8) { + /// The query was just enqueued, statement status can be checked for more details + pending, + /// The query is being bound to the statement + binding, + /// The query is running + running, + /// The query is waiting for a partial response + partial_response, + /// The query was successful + success, + /// The query failed + fail, + + pub fn isRunning(this: Status) bool { + return @intFromEnum(this) > @intFromEnum(Status.pending) and @intFromEnum(this) < @intFromEnum(Status.success); + } +}; + +pub fn hasPendingActivity(this: *@This()) bool { + return this.ref_count.load(.monotonic) > 1; +} + +pub fn deinit(this: *@This()) void { + this.thisValue.deinit(); + if (this.statement) |statement| { + statement.deref(); + } + this.query.deref(); + this.cursor_name.deref(); + + bun.default_allocator.destroy(this); +} + +pub fn finalize(this: *@This()) void { + debug("MySQLQuery finalize", .{}); + + // Clean up any statement reference + if (this.statement) |statement| { + statement.deref(); + this.statement = null; + } + + if (this.thisValue == .weak) { + // clean up if is a weak reference, if is a strong reference we need to wait 
until the query is done + // if we are a strong reference, here is probably a bug because GC'd should not happen + this.thisValue.weak = .zero; + } + this.deref(); +} + +pub fn onWriteFail( + this: *@This(), + err: AnyMySQLError.Error, + globalObject: *jsc.JSGlobalObject, + queries_array: JSValue, +) void { + this.status = .fail; + const thisValue = this.thisValue.get(); + defer this.thisValue.deinit(); + const targetValue = this.getTarget(globalObject, true); + if (thisValue == .zero or targetValue == .zero) { + return; + } + + const instance = AnyMySQLError.mysqlErrorToJS(globalObject, "Failed to bind query", err); + + const vm = jsc.VirtualMachine.get(); + const function = vm.rareData().mysql_context.onQueryRejectFn.get().?; + const event_loop = vm.eventLoop(); + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + // TODO: add mysql error to JS + // postgresErrorToJS(globalObject, null, err), + instance, + queries_array, + }); +} + +pub fn bindAndExecute(this: *MySQLQuery, writer: anytype, statement: *MySQLStatement, globalObject: *jsc.JSGlobalObject) AnyMySQLError.Error!void { + debug("bindAndExecute", .{}); + bun.assertf(statement.params.len == statement.params_received and statement.statement_id > 0, "statement is not prepared", .{}); + if (statement.signature.fields.len != statement.params.len) { + return error.WrongNumberOfParametersProvided; + } + var packet = try writer.start(0); + var execute = PreparedStatement.Execute{ + .statement_id = statement.statement_id, + .param_types = statement.signature.fields, + .new_params_bind_flag = statement.execution_flags.need_to_send_params, + .iteration_count = 1, + }; + statement.execution_flags.need_to_send_params = false; + defer execute.deinit(); + try this.bind(&execute, globalObject); + try execute.write(writer); + try packet.end(); + this.status = .running; +} + +fn bind(this: *MySQLQuery, execute: *PreparedStatement.Execute, globalObject: *jsc.JSGlobalObject) 
AnyMySQLError.Error!void { + const thisValue = this.thisValue.get(); + const binding_value = js.bindingGetCached(thisValue) orelse .zero; + const columns_value = js.columnsGetCached(thisValue) orelse .zero; + + var iter = try QueryBindingIterator.init(binding_value, columns_value, globalObject); + + var i: u32 = 0; + var params = try bun.default_allocator.alloc(Value, execute.param_types.len); + errdefer { + for (params[0..i]) |*param| { + param.deinit(bun.default_allocator); + } + bun.default_allocator.free(params); + } + while (try iter.next()) |js_value| { + const param = execute.param_types[i]; + debug("param: {s} unsigned? {}", .{ @tagName(param.type), param.flags.UNSIGNED }); + params[i] = try Value.fromJS( + js_value, + globalObject, + param.type, + param.flags.UNSIGNED, + ); + i += 1; + } + + if (iter.anyFailed()) { + return error.InvalidQueryBinding; + } + + this.status = .binding; + execute.params = params; +} + +pub fn onError(this: *@This(), err: ErrorPacket, globalObject: *jsc.JSGlobalObject) void { + debug("onError", .{}); + this.onJSError(err.toJS(globalObject), globalObject); +} + +pub fn onJSError(this: *@This(), err: jsc.JSValue, globalObject: *jsc.JSGlobalObject) void { + this.ref(); + defer this.deref(); + this.status = .fail; + const thisValue = this.thisValue.get(); + defer this.thisValue.deinit(); + const targetValue = this.getTarget(globalObject, true); + if (thisValue == .zero or targetValue == .zero) { + return; + } + + var vm = jsc.VirtualMachine.get(); + const function = vm.rareData().mysql_context.onQueryRejectFn.get().?; + const event_loop = vm.eventLoop(); + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + err, + }); +} +pub fn getTarget(this: *@This(), globalObject: *jsc.JSGlobalObject, clean_target: bool) jsc.JSValue { + const thisValue = this.thisValue.tryGet() orelse return .zero; + const target = js.targetGetCached(thisValue) orelse return .zero; + if (clean_target) { + js.targetSetCached(thisValue, 
globalObject, .zero); + } + return target; +} + +fn consumePendingValue(thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject) ?JSValue { + const pending_value = js.pendingValueGetCached(thisValue) orelse return null; + js.pendingValueSetCached(thisValue, globalObject, .zero); + return pending_value; +} + +pub fn allowGC(thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject) void { + if (thisValue == .zero) { + return; + } + + defer thisValue.ensureStillAlive(); + js.bindingSetCached(thisValue, globalObject, .zero); + js.pendingValueSetCached(thisValue, globalObject, .zero); + js.targetSetCached(thisValue, globalObject, .zero); +} + +pub fn onResult(this: *@This(), result_count: u64, globalObject: *jsc.JSGlobalObject, connection: jsc.JSValue, is_last: bool) void { + this.ref(); + defer this.deref(); + + const thisValue = this.thisValue.get(); + const targetValue = this.getTarget(globalObject, is_last); + if (is_last) { + this.status = .success; + } else { + this.status = .partial_response; + } + defer if (is_last) { + allowGC(thisValue, globalObject); + this.thisValue.deinit(); + }; + if (thisValue == .zero or targetValue == .zero) { + return; + } + + const vm = jsc.VirtualMachine.get(); + const function = vm.rareData().mysql_context.onQueryResolveFn.get().?; + const event_loop = vm.eventLoop(); + const tag: CommandTag = .{ .SELECT = result_count }; + + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + consumePendingValue(thisValue, globalObject) orelse .js_undefined, + tag.toJSTag(globalObject), + tag.toJSNumber(), + if (connection == .zero) .js_undefined else MySQLConnection.js.queriesGetCached(connection) orelse .js_undefined, + JSValue.jsBoolean(is_last), + }); +} + +pub fn constructor(globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!*MySQLQuery { + _ = callframe; + return globalThis.throw("MySQLQuery cannot be constructed directly", .{}); +} + +pub fn estimatedSize(this: *MySQLQuery) usize { + _ = 
this; + return @sizeOf(MySQLQuery); +} + +pub fn call(globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + const arguments = callframe.arguments(); + var args = jsc.CallFrame.ArgumentsSlice.init(globalThis.bunVM(), arguments); + defer args.deinit(); + const query = args.nextEat() orelse { + return globalThis.throw("query must be a string", .{}); + }; + const values = args.nextEat() orelse { + return globalThis.throw("values must be an array", .{}); + }; + + if (!query.isString()) { + return globalThis.throw("query must be a string", .{}); + } + + if (values.jsType() != .Array) { + return globalThis.throw("values must be an array", .{}); + } + + const pending_value: JSValue = args.nextEat() orelse .js_undefined; + const columns: JSValue = args.nextEat() orelse .js_undefined; + const js_bigint: JSValue = args.nextEat() orelse .false; + const js_simple: JSValue = args.nextEat() orelse .false; + + const bigint = js_bigint.isBoolean() and js_bigint.asBoolean(); + const simple = js_simple.isBoolean() and js_simple.asBoolean(); + if (simple) { + if (try values.getLength(globalThis) > 0) { + return globalThis.throwInvalidArguments("simple query cannot have parameters", .{}); + } + if (try query.getLength(globalThis) >= std.math.maxInt(i32)) { + return globalThis.throwInvalidArguments("query is too long", .{}); + } + } + if (!pending_value.jsType().isArrayLike()) { + return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); + } + + var ptr = bun.default_allocator.create(MySQLQuery) catch |err| { + return globalThis.throwError(err, "failed to allocate query"); + }; + + const this_value = ptr.toJS(globalThis); + this_value.ensureStillAlive(); + + ptr.* = .{ + .query = try query.toBunString(globalThis), + .thisValue = JSRef.initWeak(this_value), + .flags = .{ + .bigint = bigint, + .simple = simple, + }, + }; + ptr.query.ref(); + + js.bindingSetCached(this_value, globalThis, values); + js.pendingValueSetCached(this_value, 
globalThis, pending_value); + if (!columns.isUndefined()) { + js.columnsSetCached(this_value, globalThis, columns); + } + + return this_value; +} +pub fn setPendingValue(this: *@This(), globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + const result = callframe.argument(0); + const thisValue = this.thisValue.tryGet() orelse return .js_undefined; + js.pendingValueSetCached(thisValue, globalObject, result); + return .js_undefined; +} +pub fn setMode(this: *@This(), globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + const js_mode = callframe.argument(0); + if (js_mode.isEmptyOrUndefinedOrNull() or !js_mode.isNumber()) { + return globalObject.throwInvalidArgumentType("setMode", "mode", "Number"); + } + + const mode = try js_mode.coerce(i32, globalObject); + this.flags.result_mode = std.meta.intToEnum(SQLQueryResultMode, mode) catch { + return globalObject.throwInvalidArgumentTypeValue("mode", "Number", js_mode); + }; + return .js_undefined; +} + +pub fn doDone(this: *@This(), globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue { + _ = globalObject; + this.flags.is_done = true; + return .js_undefined; +} + +pub fn doCancel(this: *MySQLQuery, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + _ = callframe; + _ = globalObject; + _ = this; + + return .js_undefined; +} + +pub fn doRun(this: *MySQLQuery, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + debug("doRun", .{}); + var arguments = callframe.arguments(); + const connection: *MySQLConnection = arguments[0].as(MySQLConnection) orelse { + return globalObject.throw("connection must be a MySQLConnection", .{}); + }; + + connection.poll_ref.ref(globalObject.bunVM()); + var query = arguments[1]; + + if (!query.isObject()) { + return globalObject.throwInvalidArgumentType("run", "query", "Query"); + } + + const this_value = callframe.this(); + const binding_value 
= js.bindingGetCached(this_value) orelse .zero; + var query_str = this.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + const writer = connection.writer(); + // We need a strong reference to the query so that it doesn't get GC'd + this.ref(); + const can_execute = connection.canExecuteQuery(); + if (this.flags.simple) { + // simple queries are always text in MySQL + this.flags.binary = false; + debug("executeQuery", .{}); + + const stmt = bun.default_allocator.create(MySQLStatement) catch { + this.deref(); + return globalObject.throwOutOfMemory(); + }; + // Query is simple and it's the only owner of the statement + stmt.* = .{ + .signature = Signature.empty(), + .status = .parsing, + }; + this.statement = stmt; + + if (can_execute) { + connection.sequence_id = 0; + MySQLRequest.executeQuery(query_str.slice(), MySQLConnection.Writer, writer) catch |err| { + debug("executeQuery failed: {s}", .{@errorName(err)}); + // fail to run do cleanup + this.statement = null; + bun.default_allocator.destroy(stmt); + this.deref(); + + if (!globalObject.hasException()) + return globalObject.throwValue(AnyMySQLError.mysqlErrorToJS(globalObject, "failed to execute query", err)); + return error.JSError; + }; + connection.flags.is_ready_for_query = false; + connection.nonpipelinable_requests += 1; + this.status = .running; + } else { + this.status = .pending; + } + connection.requests.writeItem(this) catch { + // fail to run do cleanup + this.statement = null; + bun.default_allocator.destroy(stmt); + this.deref(); + + return globalObject.throwOutOfMemory(); + }; + debug("doRun: wrote query to queue", .{}); + + this.thisValue.upgrade(globalObject); + js.targetSetCached(this_value, globalObject, query); + connection.flushDataAndResetTimeout(); + return .js_undefined; + } + // prepared statements are always binary in MySQL + this.flags.binary = true; + + const columns_value = js.columnsGetCached(callframe.this()) orelse .js_undefined; + + var signature = 
Signature.generate(globalObject, query_str.slice(), binding_value, columns_value) catch |err| { + this.deref(); + if (!globalObject.hasException()) + return globalObject.throwValue(AnyMySQLError.mysqlErrorToJS(globalObject, "failed to generate signature", err)); + return error.JSError; + }; + errdefer signature.deinit(); + + const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { + this.deref(); + return globalObject.throwError(err, "failed to allocate statement"); + }; + + var did_write = false; + + enqueue: { + if (entry.found_existing) { + const stmt = entry.value_ptr.*; + this.statement = stmt; + stmt.ref(); + signature.deinit(); + signature = Signature{}; + switch (stmt.status) { + .failed => { + this.statement = null; + const error_response = stmt.error_response.toJS(globalObject); + stmt.deref(); + this.deref(); + // If the statement failed, we need to throw the error + return globalObject.throwValue(error_response); + }, + .prepared => { + if (can_execute or connection.canPipeline()) { + debug("doRun: binding and executing query", .{}); + this.bindAndExecute(writer, this.statement.?, globalObject) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwValue(AnyMySQLError.mysqlErrorToJS(globalObject, "failed to bind and execute query", err)); + return error.JSError; + }; + connection.sequence_id = 0; + this.flags.pipelined = true; + connection.pipelined_requests += 1; + connection.flags.is_ready_for_query = false; + did_write = true; + } + }, + + .parsing, .pending => {}, + } + + break :enqueue; + } + + const stmt = bun.default_allocator.create(MySQLStatement) catch |err| { + this.deref(); + return globalObject.throwError(err, "failed to allocate statement"); + }; + stmt.* = .{ + .signature = signature, + .ref_count = .initExactRefs(2), + .status = .pending, + .statement_id = 0, + }; + this.statement = stmt; + entry.value_ptr.* = stmt; + } + + this.status = if (did_write) .running else 
.pending; + try connection.requests.writeItem(this); + this.thisValue.upgrade(globalObject); + + js.targetSetCached(this_value, globalObject, query); + if (!did_write and can_execute) { + debug("doRun: preparing query", .{}); + if (connection.canPrepareQuery()) { + this.statement.?.status = .parsing; + MySQLRequest.prepareRequest(query_str.slice(), MySQLConnection.Writer, writer) catch |err| { + this.deref(); + return globalObject.throwError(err, "failed to prepare query"); + }; + connection.flags.waiting_to_prepare = true; + connection.flags.is_ready_for_query = false; + } + } + connection.flushDataAndResetTimeout(); + + return .js_undefined; +} + +comptime { + @export(&jsc.toJSHostFn(call), .{ .name = "MySQLQuery__createInstance" }); +} + +pub const js = jsc.Codegen.JSMySQLQuery; +pub const fromJS = js.fromJS; +pub const fromJSDirect = js.fromJSDirect; +pub const toJS = js.toJS; + +const debug = bun.Output.scoped(.MySQLQuery, .visible); +// TODO: move to shared IF POSSIBLE + +const AnyMySQLError = @import("./protocol/AnyMySQLError.zig"); +const ErrorPacket = @import("./protocol/ErrorPacket.zig"); +const MySQLConnection = @import("./MySQLConnection.zig"); +const MySQLRequest = @import("./MySQLRequest.zig"); +const MySQLStatement = @import("./MySQLStatement.zig"); +const PreparedStatement = @import("./protocol/PreparedStatement.zig"); +const Signature = @import("./protocol/Signature.zig"); +const bun = @import("bun"); +const std = @import("std"); +const CommandTag = @import("../postgres/CommandTag.zig").CommandTag; +const QueryBindingIterator = @import("../shared/QueryBindingIterator.zig").QueryBindingIterator; +const SQLQueryResultMode = @import("../shared/SQLQueryResultMode.zig").SQLQueryResultMode; +const Value = @import("./MySQLTypes.zig").Value; + +const jsc = bun.jsc; +const JSRef = jsc.JSRef; +const JSValue = jsc.JSValue; diff --git a/src/sql/mysql/MySQLRequest.zig b/src/sql/mysql/MySQLRequest.zig new file mode 100644 index 0000000000..336d15f4fc --- 
/dev/null +++ b/src/sql/mysql/MySQLRequest.zig @@ -0,0 +1,31 @@ +pub fn executeQuery( + query: []const u8, + comptime Context: type, + writer: NewWriter(Context), +) !void { + debug("executeQuery len: {d} {s}", .{ query.len, query }); + // resets the sequence id to zero every time we send a query + var packet = try writer.start(0); + try writer.int1(@intFromEnum(CommandType.COM_QUERY)); + try writer.write(query); + + try packet.end(); +} +pub fn prepareRequest( + query: []const u8, + comptime Context: type, + writer: NewWriter(Context), +) !void { + debug("prepareRequest {s}", .{query}); + var packet = try writer.start(0); + try writer.int1(@intFromEnum(CommandType.COM_STMT_PREPARE)); + try writer.write(query); + + try packet.end(); +} + +const debug = bun.Output.scoped(.MySQLRequest, .visible); + +const bun = @import("bun"); +const CommandType = @import("./protocol/CommandType.zig").CommandType; +const NewWriter = @import("./protocol/NewWriter.zig").NewWriter; diff --git a/src/sql/mysql/MySQLStatement.zig b/src/sql/mysql/MySQLStatement.zig new file mode 100644 index 0000000000..437389b141 --- /dev/null +++ b/src/sql/mysql/MySQLStatement.zig @@ -0,0 +1,178 @@ +const MySQLStatement = @This(); +const RefCount = bun.ptr.RefCount(@This(), "ref_count", deinit, .{}); + +cached_structure: CachedStructure = .{}, +ref_count: RefCount = RefCount.init(), +statement_id: u32 = 0, +params: []Param = &[_]Param{}, +params_received: u32 = 0, + +columns: []ColumnDefinition41 = &[_]ColumnDefinition41{}, +columns_received: u32 = 0, + +signature: Signature, +status: Status = Status.parsing, +error_response: ErrorPacket = .{ .error_code = 0 }, +execution_flags: ExecutionFlags = .{}, +fields_flags: SQLDataCell.Flags = .{}, +result_count: u64 = 0, + +pub const ExecutionFlags = packed struct(u8) { + header_received: bool = false, + needs_duplicate_check: bool = true, + need_to_send_params: bool = true, + _: u5 = 0, +}; + +pub const Status = enum { + pending, + parsing, + prepared, + 
failed, +}; + +pub const ref = RefCount.ref; +pub const deref = RefCount.deref; + +pub fn reset(this: *MySQLStatement) void { + this.result_count = 0; + this.columns_received = 0; + this.execution_flags = .{}; +} + +pub fn deinit(this: *MySQLStatement) void { + debug("MySQLStatement deinit", .{}); + + for (this.columns) |*column| { + column.deinit(); + } + if (this.columns.len > 0) { + bun.default_allocator.free(this.columns); + } + if (this.params.len > 0) { + bun.default_allocator.free(this.params); + } + this.cached_structure.deinit(); + this.error_response.deinit(); + this.signature.deinit(); + bun.default_allocator.destroy(this); +} + +pub fn checkForDuplicateFields(this: *@This()) void { + if (!this.execution_flags.needs_duplicate_check) return; + this.execution_flags.needs_duplicate_check = false; + + var seen_numbers = std.ArrayList(u32).init(bun.default_allocator); + defer seen_numbers.deinit(); + var seen_fields = bun.StringHashMap(void).init(bun.default_allocator); + seen_fields.ensureUnusedCapacity(@intCast(this.columns.len)) catch bun.outOfMemory(); + defer seen_fields.deinit(); + + // iterate backwards + var remaining = this.columns.len; + var flags: SQLDataCell.Flags = .{}; + while (remaining > 0) { + remaining -= 1; + const field: *ColumnDefinition41 = &this.columns[remaining]; + switch (field.name_or_index) { + .name => |*name| { + const seen = seen_fields.getOrPut(name.slice()) catch unreachable; + if (seen.found_existing) { + field.name_or_index = .duplicate; + flags.has_duplicate_columns = true; + } + + flags.has_named_columns = true; + }, + .index => |index| { + if (std.mem.indexOfScalar(u32, seen_numbers.items, index) != null) { + field.name_or_index = .duplicate; + flags.has_duplicate_columns = true; + } else { + seen_numbers.append(index) catch bun.outOfMemory(); + } + + flags.has_indexed_columns = true; + }, + .duplicate => { + flags.has_duplicate_columns = true; + }, + } + } + + this.fields_flags = flags; +} + +pub fn structure(this: 
*MySQLStatement, owner: JSValue, globalObject: *jsc.JSGlobalObject) CachedStructure { + if (this.cached_structure.has()) { + return this.cached_structure; + } + this.checkForDuplicateFields(); + + // lets avoid most allocations + var stack_ids: [70]jsc.JSObject.ExternColumnIdentifier = [_]jsc.JSObject.ExternColumnIdentifier{.{ .tag = 0, .value = .{ .index = 0 } }} ** 70; + // lets de duplicate the fields early + var nonDuplicatedCount = this.columns.len; + for (this.columns) |*column| { + if (column.name_or_index == .duplicate) { + nonDuplicatedCount -= 1; + } + } + const ids = if (nonDuplicatedCount <= jsc.JSObject.maxInlineCapacity()) stack_ids[0..nonDuplicatedCount] else bun.default_allocator.alloc(jsc.JSObject.ExternColumnIdentifier, nonDuplicatedCount) catch bun.outOfMemory(); + + var i: usize = 0; + for (this.columns) |*column| { + if (column.name_or_index == .duplicate) continue; + + var id: *jsc.JSObject.ExternColumnIdentifier = &ids[i]; + switch (column.name_or_index) { + .name => |name| { + id.value.name = String.createAtomIfPossible(name.slice()); + }, + .index => |index| { + id.value.index = index; + }, + .duplicate => unreachable, + } + + id.tag = switch (column.name_or_index) { + .name => 2, + .index => 1, + .duplicate => 0, + }; + + i += 1; + } + + if (nonDuplicatedCount > jsc.JSObject.maxInlineCapacity()) { + this.cached_structure.set(globalObject, null, ids); + } else { + this.cached_structure.set(globalObject, jsc.JSObject.createStructure( + globalObject, + owner, + @truncate(ids.len), + ids.ptr, + ), null); + } + + return this.cached_structure; +} +pub const Param = struct { + type: types.FieldType, + flags: ColumnDefinition41.ColumnFlags, +}; +const debug = bun.Output.scoped(.MySQLStatement, .hidden); + +const CachedStructure = @import("../shared/CachedStructure.zig"); +const ColumnDefinition41 = @import("./protocol/ColumnDefinition41.zig"); +const ErrorPacket = @import("./protocol/ErrorPacket.zig"); +const Signature = 
@import("./protocol/Signature.zig"); +const std = @import("std"); +const types = @import("./MySQLTypes.zig"); +const SQLDataCell = @import("../shared/SQLDataCell.zig").SQLDataCell; + +const bun = @import("bun"); +const String = bun.String; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; diff --git a/src/sql/mysql/MySQLTypes.zig b/src/sql/mysql/MySQLTypes.zig new file mode 100644 index 0000000000..915dd0ffda --- /dev/null +++ b/src/sql/mysql/MySQLTypes.zig @@ -0,0 +1,877 @@ +pub const CharacterSet = enum(u8) { + big5_chinese_ci = 1, + latin2_czech_cs = 2, + dec8_swedish_ci = 3, + cp850_general_ci = 4, + latin1_german1_ci = 5, + hp8_english_ci = 6, + koi8r_general_ci = 7, + latin1_swedish_ci = 8, + latin2_general_ci = 9, + swe7_swedish_ci = 10, + ascii_general_ci = 11, + ujis_japanese_ci = 12, + sjis_japanese_ci = 13, + cp1251_bulgarian_ci = 14, + latin1_danish_ci = 15, + hebrew_general_ci = 16, + tis620_thai_ci = 18, + euckr_korean_ci = 19, + latin7_estonian_cs = 20, + latin2_hungarian_ci = 21, + koi8u_general_ci = 22, + cp1251_ukrainian_ci = 23, + gb2312_chinese_ci = 24, + greek_general_ci = 25, + cp1250_general_ci = 26, + latin2_croatian_ci = 27, + gbk_chinese_ci = 28, + cp1257_lithuanian_ci = 29, + latin5_turkish_ci = 30, + latin1_german2_ci = 31, + armscii8_general_ci = 32, + utf8mb3_general_ci = 33, + cp1250_czech_cs = 34, + ucs2_general_ci = 35, + cp866_general_ci = 36, + keybcs2_general_ci = 37, + macce_general_ci = 38, + macroman_general_ci = 39, + cp852_general_ci = 40, + latin7_general_ci = 41, + latin7_general_cs = 42, + macce_bin = 43, + cp1250_croatian_ci = 44, + utf8mb4_general_ci = 45, + utf8mb4_bin = 46, + latin1_bin = 47, + latin1_general_ci = 48, + latin1_general_cs = 49, + cp1251_bin = 50, + cp1251_general_ci = 51, + cp1251_general_cs = 52, + macroman_bin = 53, + utf16_general_ci = 54, + utf16_bin = 55, + utf16le_general_ci = 56, + cp1256_general_ci = 57, + cp1257_bin = 58, + cp1257_general_ci = 59, + utf32_general_ci = 60, + utf32_bin = 61, 
+ utf16le_bin = 62, + binary = 63, + armscii8_bin = 64, + ascii_bin = 65, + cp1250_bin = 66, + cp1256_bin = 67, + cp866_bin = 68, + dec8_bin = 69, + greek_bin = 70, + hebrew_bin = 71, + hp8_bin = 72, + keybcs2_bin = 73, + koi8r_bin = 74, + koi8u_bin = 75, + utf8mb3_tolower_ci = 76, + latin2_bin = 77, + latin5_bin = 78, + latin7_bin = 79, + cp850_bin = 80, + cp852_bin = 81, + swe7_bin = 82, + utf8mb3_bin = 83, + big5_bin = 84, + euckr_bin = 85, + gb2312_bin = 86, + gbk_bin = 87, + sjis_bin = 88, + tis620_bin = 89, + ucs2_bin = 90, + ujis_bin = 91, + geostd8_general_ci = 92, + geostd8_bin = 93, + latin1_spanish_ci = 94, + cp932_japanese_ci = 95, + cp932_bin = 96, + eucjpms_japanese_ci = 97, + eucjpms_bin = 98, + cp1250_polish_ci = 99, + utf16_unicode_ci = 101, + utf16_icelandic_ci = 102, + utf16_latvian_ci = 103, + utf16_romanian_ci = 104, + utf16_slovenian_ci = 105, + utf16_polish_ci = 106, + utf16_estonian_ci = 107, + utf16_spanish_ci = 108, + utf16_swedish_ci = 109, + utf16_turkish_ci = 110, + utf16_czech_ci = 111, + utf16_danish_ci = 112, + utf16_lithuanian_ci = 113, + utf16_slovak_ci = 114, + utf16_spanish2_ci = 115, + utf16_roman_ci = 116, + utf16_persian_ci = 117, + utf16_esperanto_ci = 118, + utf16_hungarian_ci = 119, + utf16_sinhala_ci = 120, + utf16_german2_ci = 121, + utf16_croatian_ci = 122, + utf16_unicode_520_ci = 123, + utf16_vietnamese_ci = 124, + ucs2_unicode_ci = 128, + ucs2_icelandic_ci = 129, + ucs2_latvian_ci = 130, + ucs2_romanian_ci = 131, + ucs2_slovenian_ci = 132, + ucs2_polish_ci = 133, + ucs2_estonian_ci = 134, + ucs2_spanish_ci = 135, + ucs2_swedish_ci = 136, + ucs2_turkish_ci = 137, + ucs2_czech_ci = 138, + ucs2_danish_ci = 139, + ucs2_lithuanian_ci = 140, + ucs2_slovak_ci = 141, + ucs2_spanish2_ci = 142, + ucs2_roman_ci = 143, + ucs2_persian_ci = 144, + ucs2_esperanto_ci = 145, + ucs2_hungarian_ci = 146, + ucs2_sinhala_ci = 147, + ucs2_german2_ci = 148, + ucs2_croatian_ci = 149, + ucs2_unicode_520_ci = 150, + ucs2_vietnamese_ci = 151, + 
ucs2_general_mysql500_ci = 159, + utf32_unicode_ci = 160, + utf32_icelandic_ci = 161, + utf32_latvian_ci = 162, + utf32_romanian_ci = 163, + utf32_slovenian_ci = 164, + utf32_polish_ci = 165, + utf32_estonian_ci = 166, + utf32_spanish_ci = 167, + utf32_swedish_ci = 168, + utf32_turkish_ci = 169, + utf32_czech_ci = 170, + utf32_danish_ci = 171, + utf32_lithuanian_ci = 172, + utf32_slovak_ci = 173, + utf32_spanish2_ci = 174, + utf32_roman_ci = 175, + utf32_persian_ci = 176, + utf32_esperanto_ci = 177, + utf32_hungarian_ci = 178, + utf32_sinhala_ci = 179, + utf32_german2_ci = 180, + utf32_croatian_ci = 181, + utf32_unicode_520_ci = 182, + utf32_vietnamese_ci = 183, + utf8mb3_unicode_ci = 192, + utf8mb3_icelandic_ci = 193, + utf8mb3_latvian_ci = 194, + utf8mb3_romanian_ci = 195, + utf8mb3_slovenian_ci = 196, + utf8mb3_polish_ci = 197, + utf8mb3_estonian_ci = 198, + utf8mb3_spanish_ci = 199, + utf8mb3_swedish_ci = 200, + utf8mb3_turkish_ci = 201, + utf8mb3_czech_ci = 202, + utf8mb3_danish_ci = 203, + utf8mb3_lithuanian_ci = 204, + utf8mb3_slovak_ci = 205, + utf8mb3_spanish2_ci = 206, + utf8mb3_roman_ci = 207, + utf8mb3_persian_ci = 208, + utf8mb3_esperanto_ci = 209, + utf8mb3_hungarian_ci = 210, + utf8mb3_sinhala_ci = 211, + utf8mb3_german2_ci = 212, + utf8mb3_croatian_ci = 213, + utf8mb3_unicode_520_ci = 214, + utf8mb3_vietnamese_ci = 215, + utf8mb3_general_mysql500_ci = 223, + utf8mb4_unicode_ci = 224, + utf8mb4_icelandic_ci = 225, + utf8mb4_latvian_ci = 226, + utf8mb4_romanian_ci = 227, + utf8mb4_slovenian_ci = 228, + utf8mb4_polish_ci = 229, + utf8mb4_estonian_ci = 230, + utf8mb4_spanish_ci = 231, + utf8mb4_swedish_ci = 232, + utf8mb4_turkish_ci = 233, + utf8mb4_czech_ci = 234, + utf8mb4_danish_ci = 235, + utf8mb4_lithuanian_ci = 236, + utf8mb4_slovak_ci = 237, + utf8mb4_spanish2_ci = 238, + utf8mb4_roman_ci = 239, + utf8mb4_persian_ci = 240, + utf8mb4_esperanto_ci = 241, + utf8mb4_hungarian_ci = 242, + utf8mb4_sinhala_ci = 243, + utf8mb4_german2_ci = 244, + 
utf8mb4_croatian_ci = 245, + utf8mb4_unicode_520_ci = 246, + utf8mb4_vietnamese_ci = 247, + gb18030_chinese_ci = 248, + gb18030_bin = 249, + gb18030_unicode_520_ci = 250, + _, + + pub const default = CharacterSet.utf8mb4_general_ci; + + pub fn label(this: CharacterSet) []const u8 { + if (@intFromEnum(this) < 100 and @intFromEnum(this) > 0) { + return @tagName(this); + } + + return "(unknown)"; + } +}; + +// MySQL field types +// https://dev.mysql.com/doc/dev/mysql-server/latest/binary__log__types_8h.html#a8935f33b06a3a88ba403c63acd806920 +pub const FieldType = enum(u8) { + MYSQL_TYPE_DECIMAL = 0x00, + MYSQL_TYPE_TINY = 0x01, + MYSQL_TYPE_SHORT = 0x02, + MYSQL_TYPE_LONG = 0x03, + MYSQL_TYPE_FLOAT = 0x04, + MYSQL_TYPE_DOUBLE = 0x05, + MYSQL_TYPE_NULL = 0x06, + MYSQL_TYPE_TIMESTAMP = 0x07, + MYSQL_TYPE_LONGLONG = 0x08, + MYSQL_TYPE_INT24 = 0x09, + MYSQL_TYPE_DATE = 0x0a, + MYSQL_TYPE_TIME = 0x0b, + MYSQL_TYPE_DATETIME = 0x0c, + MYSQL_TYPE_YEAR = 0x0d, + MYSQL_TYPE_NEWDATE = 0x0e, + MYSQL_TYPE_VARCHAR = 0x0f, + MYSQL_TYPE_BIT = 0x10, + MYSQL_TYPE_TIMESTAMP2 = 0x11, + MYSQL_TYPE_DATETIME2 = 0x12, + MYSQL_TYPE_TIME2 = 0x13, + MYSQL_TYPE_JSON = 0xf5, + MYSQL_TYPE_NEWDECIMAL = 0xf6, + MYSQL_TYPE_ENUM = 0xf7, + MYSQL_TYPE_SET = 0xf8, + MYSQL_TYPE_TINY_BLOB = 0xf9, + MYSQL_TYPE_MEDIUM_BLOB = 0xfa, + MYSQL_TYPE_LONG_BLOB = 0xfb, + MYSQL_TYPE_BLOB = 0xfc, + MYSQL_TYPE_VAR_STRING = 0xfd, + MYSQL_TYPE_STRING = 0xfe, + MYSQL_TYPE_GEOMETRY = 0xff, + _, + + pub fn fromJS(globalObject: *JSC.JSGlobalObject, value: JSValue, unsigned: *bool) bun.JSError!FieldType { + if (value.isEmptyOrUndefinedOrNull()) { + return .MYSQL_TYPE_NULL; + } + + if (value.isCell()) { + const tag = value.jsType(); + if (tag.isStringLike()) { + return .MYSQL_TYPE_STRING; + } + + if (tag == .JSDate) { + return .MYSQL_TYPE_DATETIME; + } + + if (tag.isTypedArrayOrArrayBuffer()) { + return .MYSQL_TYPE_BLOB; + } + + if (tag == .HeapBigInt) { + if (value.isBigIntInInt64Range(std.math.minInt(i64), 
std.math.maxInt(i64))) { + return .MYSQL_TYPE_LONGLONG; + } + if (value.isBigIntInUInt64Range(0, std.math.maxInt(u64))) { + unsigned.* = true; + return .MYSQL_TYPE_LONGLONG; + } + return globalObject.ERR(.OUT_OF_RANGE, "The value is out of range. It must be >= {d} and <= {d}.", .{ std.math.minInt(i64), std.math.maxInt(u64) }).throw(); + } + + if (globalObject.hasException()) return error.JSError; + + // Ban these types: + if (tag == .NumberObject) { + return error.JSError; + } + + if (tag == .BooleanObject) { + return error.JSError; + } + + // It's something internal + if (!tag.isIndexable()) { + return error.JSError; + } + + // We will JSON.stringify anything else. + if (tag.isObject()) { + return .MYSQL_TYPE_JSON; + } + } + + if (value.isAnyInt()) { + const int = value.toInt64(); + + if (int >= 0) { + if (int <= std.math.maxInt(i32)) { + return .MYSQL_TYPE_LONG; + } + if (int <= std.math.maxInt(u32)) { + unsigned.* = true; + return .MYSQL_TYPE_LONG; + } + if (int >= std.math.maxInt(i64)) { + unsigned.* = true; + return .MYSQL_TYPE_LONGLONG; + } + return .MYSQL_TYPE_LONGLONG; + } + if (int >= std.math.minInt(i32)) { + return .MYSQL_TYPE_LONG; + } + return .MYSQL_TYPE_LONGLONG; + } + + if (value.isNumber()) { + return .MYSQL_TYPE_DOUBLE; + } + + if (value.isBoolean()) { + return .MYSQL_TYPE_TINY; + } + + return .MYSQL_TYPE_VARCHAR; + } + + pub fn isBinaryFormatSupported(this: FieldType) bool { + return switch (this) { + .MYSQL_TYPE_TINY, + .MYSQL_TYPE_SHORT, + .MYSQL_TYPE_LONG, + .MYSQL_TYPE_LONGLONG, + .MYSQL_TYPE_FLOAT, + .MYSQL_TYPE_DOUBLE, + .MYSQL_TYPE_TIME, + .MYSQL_TYPE_DATE, + .MYSQL_TYPE_DATETIME, + .MYSQL_TYPE_TIMESTAMP, + => true, + else => false, + }; + } +}; + +// Add this near the top of the file +pub const Value = union(enum) { + null, + bool: bool, + short: i16, + ushort: u16, + int: i32, + uint: u32, + long: i64, + ulong: u64, + float: f32, + double: f64, + + string: JSC.ZigString.Slice, + string_data: Data, + bytes: JSC.ZigString.Slice, + 
bytes_data: Data, + date: DateTime, + time: Time, + // decimal: Decimal, + + pub fn deinit(this: *Value, _: std.mem.Allocator) void { + switch (this.*) { + inline .string, .bytes => |*slice| slice.deinit(), + inline .string_data, .bytes_data => |*data| data.deinit(), + // .decimal => |*decimal| decimal.deinit(allocator), + else => {}, + } + } + + pub fn toData( + this: *const Value, + field_type: FieldType, + ) AnyMySQLError.Error!Data { + var buffer: [15]u8 = undefined; // Large enough for all fixed-size types + var stream = std.io.fixedBufferStream(&buffer); + var writer = stream.writer(); + switch (this.*) { + .null => return Data{ .empty = {} }, + .bool => |b| writer.writeByte(if (b) 1 else 0) catch undefined, + .short => |s| writer.writeInt(i16, s, .little) catch undefined, + .ushort => |s| writer.writeInt(u16, s, .little) catch undefined, + .int => |i| writer.writeInt(i32, i, .little) catch undefined, + .uint => |i| writer.writeInt(u32, i, .little) catch undefined, + .long => |l| writer.writeInt(i64, l, .little) catch undefined, + .ulong => |l| writer.writeInt(u64, l, .little) catch undefined, + .float => |f| writer.writeInt(u32, @bitCast(f), .little) catch undefined, + .double => |d| writer.writeInt(u64, @bitCast(d), .little) catch undefined, + inline .date, .time => |d| { + stream.pos = d.toBinary(field_type, &buffer); + }, + // .decimal => |dec| return try dec.toBinary(field_type), + .string_data, .bytes_data => |data| return data, + .string, .bytes => |slice| return if (slice.len > 0) Data{ .temporary = slice.slice() } else Data{ .empty = {} }, + } + + return try Data.create(buffer[0..stream.pos], bun.default_allocator); + } + + pub fn fromJS(value: JSC.JSValue, globalObject: *JSC.JSGlobalObject, field_type: FieldType, unsigned: bool) AnyMySQLError.Error!Value { + if (value.isEmptyOrUndefinedOrNull()) { + return Value{ .null = {} }; + } + return switch (field_type) { + .MYSQL_TYPE_TINY => Value{ .bool = value.toBoolean() }, + .MYSQL_TYPE_SHORT => { + if 
(unsigned) { + return Value{ .ushort = try globalObject.validateIntegerRange(value, u16, 0, .{ .min = std.math.minInt(u16), .max = std.math.maxInt(u16), .field_name = "u16" }) }; + } + return Value{ .short = try globalObject.validateIntegerRange(value, i16, 0, .{ .min = std.math.minInt(i16), .max = std.math.maxInt(i16), .field_name = "i16" }) }; + }, + .MYSQL_TYPE_LONG => { + if (unsigned) { + return Value{ .uint = try globalObject.validateIntegerRange(value, u32, 0, .{ .min = std.math.minInt(u32), .max = std.math.maxInt(u32), .field_name = "u32" }) }; + } + return Value{ .int = try globalObject.validateIntegerRange(value, i32, 0, .{ .min = std.math.minInt(i32), .max = std.math.maxInt(i32), .field_name = "i32" }) }; + }, + .MYSQL_TYPE_LONGLONG => { + if (unsigned) { + return Value{ .ulong = try globalObject.validateBigIntRange(value, u64, 0, .{ .field_name = "u64", .min = 0, .max = std.math.maxInt(u64) }) }; + } + return Value{ .long = try globalObject.validateBigIntRange(value, i64, 0, .{ .min = std.math.minInt(i64), .max = std.math.maxInt(i64), .field_name = "i64" }) }; + }, + + .MYSQL_TYPE_FLOAT => Value{ .float = @floatCast(try value.coerce(f64, globalObject)) }, + .MYSQL_TYPE_DOUBLE => Value{ .double = try value.coerce(f64, globalObject) }, + .MYSQL_TYPE_TIME => Value{ .time = try Time.fromJS(value, globalObject) }, + .MYSQL_TYPE_DATE, .MYSQL_TYPE_TIMESTAMP, .MYSQL_TYPE_DATETIME => Value{ .date = try DateTime.fromJS(value, globalObject) }, + .MYSQL_TYPE_TINY_BLOB, .MYSQL_TYPE_MEDIUM_BLOB, .MYSQL_TYPE_LONG_BLOB, .MYSQL_TYPE_BLOB => { + if (value.asArrayBuffer(globalObject)) |array_buffer| { + return Value{ .bytes = JSC.ZigString.Slice.fromUTF8NeverFree(array_buffer.slice()) }; + } + + if (value.as(JSC.WebCore.Blob)) |blob| { + if (blob.needsToReadFile()) { + return globalObject.throwInvalidArguments("File blobs are not supported", .{}); + } + return Value{ .bytes = JSC.ZigString.Slice.fromUTF8NeverFree(blob.sharedView()) }; + } + + if (value.isString()) { + 
const str = try bun.String.fromJS(value, globalObject); + defer str.deref(); + return Value{ .string = str.toUTF8(bun.default_allocator) }; + } + + return globalObject.throwInvalidArguments("Expected a string, blob, or array buffer", .{}); + }, + + .MYSQL_TYPE_JSON => { + var str: bun.String = bun.String.empty; + try value.jsonStringify(globalObject, 0, &str); + defer str.deref(); + return Value{ .string = str.toUTF8(bun.default_allocator) }; + }, + + // .MYSQL_TYPE_VARCHAR, .MYSQL_TYPE_VAR_STRING, .MYSQL_TYPE_STRING => { + else => { + const str = try bun.String.fromJS(value, globalObject); + defer str.deref(); + return Value{ .string = str.toUTF8(bun.default_allocator) }; + }, + }; + } + + pub const DateTime = struct { + year: u16 = 0, + month: u8 = 0, + day: u8 = 0, + hour: u8 = 0, + minute: u8 = 0, + second: u8 = 0, + microsecond: u32 = 0, + + pub fn fromData(data: *const Data) !DateTime { + return fromBinary(data.slice()); + } + + pub fn fromBinary(val: []const u8) DateTime { + switch (val.len) { + 4 => { + // Byte 1: [year LSB] (8 bits of year) + // Byte 2: [year MSB] (8 bits of year) + // Byte 3: [month] (8-bit unsigned integer, 1-12) + // Byte 4: [day] (8-bit unsigned integer, 1-31) + return .{ + .year = std.mem.readInt(u16, val[0..2], .little), + .month = val[2], + .day = val[3], + }; + }, + 7 => { + // Byte 1: [year LSB] (8 bits of year) + // Byte 2: [year MSB] (8 bits of year) + // Byte 3: [month] (8-bit unsigned integer, 1-12) + // Byte 4: [day] (8-bit unsigned integer, 1-31) + // Byte 5: [hour] (8-bit unsigned integer, 0-23) + // Byte 6: [minute] (8-bit unsigned integer, 0-59) + // Byte 7: [second] (8-bit unsigned integer, 0-59) + return .{ + .year = std.mem.readInt(u16, val[0..2], .little), + .month = val[2], + .day = val[3], + .hour = val[4], + .minute = val[5], + .second = val[6], + }; + }, + 11 => { + // Byte 1: [year LSB] (8 bits of year) + // Byte 2: [year MSB] (8 bits of year) + // Byte 3: [month] (8-bit unsigned integer, 1-12) + // Byte 4: [day] 
(8-bit unsigned integer, 1-31) + // Byte 5: [hour] (8-bit unsigned integer, 0-23) + // Byte 6: [minute] (8-bit unsigned integer, 0-59) + // Byte 7: [second] (8-bit unsigned integer, 0-59) + // Byte 8-11: [microseconds] (32-bit little-endian unsigned integer) + return .{ + .year = std.mem.readInt(u16, val[0..2], .little), + .month = val[2], + .day = val[3], + .hour = val[4], + .minute = val[5], + .second = val[6], + .microsecond = std.mem.readInt(u32, val[7..11], .little), + }; + }, + else => bun.Output.panic("Invalid datetime length: {d}", .{val.len}), + } + } + + pub fn toBinary(this: *const DateTime, field_type: FieldType, buffer: []u8) u8 { + switch (field_type) { + .MYSQL_TYPE_YEAR => { + buffer[0] = 2; + std.mem.writeInt(u16, buffer[1..3], this.year, .little); + return 3; + }, + .MYSQL_TYPE_DATE => { + buffer[0] = 4; + std.mem.writeInt(u16, buffer[1..3], this.year, .little); + buffer[3] = this.month; + buffer[4] = this.day; + return 5; + }, + .MYSQL_TYPE_DATETIME => { + buffer[0] = if (this.microsecond == 0) 7 else 11; + std.mem.writeInt(u16, buffer[1..3], this.year, .little); + buffer[3] = this.month; + buffer[4] = this.day; + buffer[5] = this.hour; + buffer[6] = this.minute; + buffer[7] = this.second; + if (this.microsecond == 0) { + return 8; + } else { + std.mem.writeInt(u32, buffer[8..12], this.microsecond, .little); + return 12; + } + }, + else => return 0, + } + } + + pub fn toJSTimestamp(this: *const DateTime, globalObject: *JSC.JSGlobalObject) bun.JSError!f64 { + return globalObject.gregorianDateTimeToMS( + this.year, + this.month, + this.day, + this.hour, + this.minute, + this.second, + if (this.microsecond > 0) @intCast(@divFloor(this.microsecond, 1000)) else 0, + ); + } + + pub fn fromUnixTimestamp(timestamp: i64, microseconds: u32) DateTime { + var ts = timestamp; + const days = @divFloor(ts, 86400); + ts = @mod(ts, 86400); + + const hour = @divFloor(ts, 3600); + ts = @mod(ts, 3600); + + const minute = @divFloor(ts, 60); + const second = @mod(ts, 
60); + + const date = gregorianDate(@intCast(days)); + return .{ + .year = date.year, + .month = date.month, + .day = date.day, + .hour = @intCast(hour), + .minute = @intCast(minute), + .second = @intCast(second), + .microsecond = microseconds, + }; + } + + pub fn toJS(this: DateTime, globalObject: *JSC.JSGlobalObject) JSValue { + return JSValue.fromDateNumber(globalObject, this.toJSTimestamp()); + } + + pub fn fromJS(value: JSValue, globalObject: *JSC.JSGlobalObject) !DateTime { + if (value.isDate()) { + // this is actually ms not seconds + const total_ms = value.getUnixTimestamp(); + const ts: i64 = @intFromFloat(@divFloor(total_ms, 1000)); + const ms: u32 = @intFromFloat(total_ms - (@as(f64, @floatFromInt(ts)) * 1000)); + return DateTime.fromUnixTimestamp(ts, ms * 1000); + } + + if (value.isNumber()) { + const total_ms = value.asNumber(); + const ts: i64 = @intFromFloat(@divFloor(total_ms, 1000)); + const ms: u32 = @intFromFloat(total_ms - (@as(f64, @floatFromInt(ts)) * 1000)); + return DateTime.fromUnixTimestamp(ts, ms * 1000); + } + + return globalObject.throwInvalidArguments("Expected a date or number", .{}); + } + }; + + pub const Time = struct { + negative: bool = false, + days: u32 = 0, + hours: u8 = 0, + minutes: u8 = 0, + seconds: u8 = 0, + microseconds: u32 = 0, + + pub fn fromJS(value: JSValue, globalObject: *JSC.JSGlobalObject) !Time { + if (value.isDate()) { + const total_ms = value.getUnixTimestamp(); + const ts: i64 = @intFromFloat(@divFloor(total_ms, 1000)); + const ms: u32 = @intFromFloat(total_ms - (@as(f64, @floatFromInt(ts)) * 1000)); + return Time.fromUnixTimestamp(ts, ms * 1000); + } else if (value.isNumber()) { + const total_ms = value.asNumber(); + const ts: i64 = @intFromFloat(@divFloor(total_ms, 1000)); + const ms: u32 = @intFromFloat(total_ms - (@as(f64, @floatFromInt(ts)) * 1000)); + return Time.fromUnixTimestamp(ts, ms * 1000); + } else { + return globalObject.throwInvalidArguments("Expected a date or number", .{}); + } + } + + pub fn 
fromUnixTimestamp(timestamp: i64, microseconds: u32) Time { + const days = @divFloor(timestamp, 86400); + const hours = @divFloor(@mod(timestamp, 86400), 3600); + const minutes = @divFloor(@mod(timestamp, 3600), 60); + const seconds = @mod(timestamp, 60); + return .{ + .negative = timestamp < 0, + .days = @intCast(days), + .hours = @intCast(hours), + .minutes = @intCast(minutes), + .seconds = @intCast(seconds), + .microseconds = microseconds, + }; + } + + pub fn toUnixTimestamp(this: *const Time) i64 { + var total_ms: i64 = 0; + total_ms +|= @as(i64, this.days) *| 86400000; + total_ms +|= @as(i64, this.hours) *| 3600000; + total_ms +|= @as(i64, this.minutes) *| 60000; + total_ms +|= @as(i64, this.seconds) *| 1000; + return total_ms; + } + + pub fn fromData(data: *const Data) !Time { + return fromBinary(data.slice()); + } + + pub fn fromBinary(val: []const u8) Time { + if (val.len == 0) { + return Time{}; + } + + var time = Time{}; + if (val.len >= 8) { + time.negative = val[0] != 0; + time.days = std.mem.readInt(u32, val[1..5], .little); + time.hours = val[5]; + time.minutes = val[6]; + time.seconds = val[7]; + } + + if (val.len > 8) { + time.microseconds = std.mem.readInt(u32, val[8..12], .little); + } + + return time; + } + pub fn toJSTimestamp(this: *const Time) f64 { + var total_ms: i64 = 0; + total_ms +|= @as(i64, this.days) * 86400000; + total_ms +|= @as(i64, this.hours) * 3600000; + total_ms +|= @as(i64, this.minutes) * 60000; + total_ms +|= @as(i64, this.seconds) * 1000; + total_ms +|= @divFloor(this.microseconds, 1000); + + if (this.negative) { + total_ms = -total_ms; + } + + return @as(f64, @floatFromInt(total_ms)); + } + pub fn toJS(this: Time, _: *JSC.JSGlobalObject) JSValue { + return JSValue.jsDoubleNumber(this.toJSTimestamp()); + } + + pub fn toBinary(this: *const Time, field_type: FieldType, buffer: []u8) u8 { + switch (field_type) { + .MYSQL_TYPE_TIME, .MYSQL_TYPE_TIME2 => { + buffer[1] = if (this.negative) 1 else 0; + std.mem.writeInt(u32, 
buffer[2..6], this.days, .little); + buffer[6] = this.hours; + buffer[7] = this.minutes; + buffer[8] = this.seconds; + if (this.microseconds == 0) { + buffer[0] = 8; // length + return 9; + } else { + buffer[0] = 12; // length + std.mem.writeInt(u32, buffer[9..][0..4], this.microseconds, .little); + return 12; + } + }, + else => unreachable, + } + } + }; + + pub const Decimal = struct { + // MySQL DECIMAL is stored as a sequence of base-10 digits + digits: []const u8, + scale: u8, + negative: bool, + + pub fn deinit(this: *Decimal, allocator: std.mem.Allocator) void { + allocator.free(this.digits); + } + + pub fn toJS(this: Decimal, globalObject: *JSC.JSGlobalObject) JSValue { + var stack = std.heap.stackFallback(64, bun.default_allocator); + var str = std.ArrayList(u8).init(stack.get()); + defer str.deinit(); + + if (this.negative) { + str.append('-') catch return JSValue.jsNumber(0); + } + + const decimal_pos = this.digits.len - this.scale; + for (this.digits, 0..) |digit, i| { + if (i == decimal_pos and this.scale > 0) { + str.append('.') catch return JSValue.jsNumber(0); + } + str.append(digit + '0') catch return JSValue.jsNumber(0); + } + + return bun.String.createUTF8ForJS(globalObject, str.items) catch .zero; + } + + pub fn toBinary(_: Decimal, _: FieldType) !Data { + bun.todoPanic(@src(), "Decimal.toBinary not implemented", .{}); + } + + // pub fn fromData(data: *const Data) !Decimal { + // return fromBinary(data.slice()); + // } + + // pub fn fromBinary(_: []const u8) Decimal { + // bun.todoPanic(@src(), "Decimal.toBinary not implemented", .{}); + // } + }; +}; + +// Helper functions for date calculations +fn isLeapYear(year: u16) bool { + return (year % 4 == 0 and year % 100 != 0) or year % 400 == 0; +} + +fn daysInMonth(year: u16, month: u8) u8 { + const days = [_]u8{ 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31 }; + if (month == 2 and isLeapYear(year)) { + return 29; + } + return days[month - 1]; +} + +const Date = struct { + year: u16, + month: u8, + 
day: u8, +}; + +fn gregorianDate(days: i32) Date { + // Convert days since 1970-01-01 to year/month/day + var d = days; + var y: u16 = 1970; + + while (d >= 365 + @as(u16, @intFromBool(isLeapYear(y)))) : (y += 1) { + d -= 365 + @as(u16, @intFromBool(isLeapYear(y))); + } + + var m: u8 = 1; + while (d >= daysInMonth(y, m)) : (m += 1) { + d -= daysInMonth(y, m); + } + + return .{ + .year = y, + .month = m, + .day = @intCast(d + 1), + }; +} + +pub const MySQLInt8 = int1; +pub const MySQLInt16 = int2; +pub const MySQLInt24 = int3; +pub const MySQLInt32 = int4; +pub const MySQLInt64 = int8; +pub const int1 = u8; +pub const int2 = u16; +pub const int3 = u24; +pub const int4 = u32; +pub const int8 = u64; + +const AnyMySQLError = @import("./protocol/AnyMySQLError.zig"); +const std = @import("std"); +const Data = @import("../shared/Data.zig").Data; + +const bun = @import("bun"); +const String = bun.String; + +const JSC = bun.jsc; +const JSValue = JSC.JSValue; +const ZigString = JSC.ZigString; diff --git a/src/sql/mysql/SSLMode.zig b/src/sql/mysql/SSLMode.zig new file mode 100644 index 0000000000..7be330c3ea --- /dev/null +++ b/src/sql/mysql/SSLMode.zig @@ -0,0 +1,7 @@ +pub const SSLMode = enum(u8) { + disable = 0, + prefer = 1, + require = 2, + verify_ca = 3, + verify_full = 4, +}; diff --git a/src/sql/mysql/StatusFlags.zig b/src/sql/mysql/StatusFlags.zig new file mode 100644 index 0000000000..d7f5c99a21 --- /dev/null +++ b/src/sql/mysql/StatusFlags.zig @@ -0,0 +1,66 @@ +// MySQL connection status flags +pub const StatusFlag = enum(u16) { + SERVER_STATUS_IN_TRANS = 1, + /// Indicates if autocommit mode is enabled + SERVER_STATUS_AUTOCOMMIT = 2, + /// Indicates there are more result sets from this query + SERVER_MORE_RESULTS_EXISTS = 8, + /// Query used a suboptimal index + SERVER_STATUS_NO_GOOD_INDEX_USED = 16, + /// Query performed a full table scan with no index + SERVER_STATUS_NO_INDEX_USED = 32, + /// Indicates an open cursor exists + SERVER_STATUS_CURSOR_EXISTS = 64, + 
/// Last row in result set has been sent + SERVER_STATUS_LAST_ROW_SENT = 128, + /// Database was dropped + SERVER_STATUS_DB_DROPPED = 1 << 8, + /// Backslash escaping is disabled + SERVER_STATUS_NO_BACKSLASH_ESCAPES = 1 << 9, + /// Server's metadata has changed + SERVER_STATUS_METADATA_CHANGED = 1 << 10, + /// Query execution was considered slow + SERVER_QUERY_WAS_SLOW = 1 << 11, + /// Statement has output parameters + SERVER_PS_OUT_PARAMS = 1 << 12, + /// Transaction is in read-only mode + SERVER_STATUS_IN_TRANS_READONLY = 1 << 13, + /// Session state has changed + SERVER_SESSION_STATE_CHANGED = 1 << 14, +}; + +pub const StatusFlags = struct { + /// Raw bitmask of SERVER_STATUS_* flags as sent by the server; test membership via has(). + _value: u16 = 0, + + pub fn format(self: @This(), comptime _: []const u8, _: anytype, writer: anytype) !void { + var first = true; + inline for (comptime std.meta.fieldNames(StatusFlags)) |field| { + if (@TypeOf(@field(self, field)) == bool) { + if (@field(self, field)) { + if (!first) { + try writer.writeAll(", "); + } + first = false; + try writer.writeAll(field); + } + } + } + } + + pub fn has(this: @This(), flag: StatusFlag) bool { + return this._value & @as(u16, @intFromEnum(flag)) != 0; + } + + pub fn toInt(this: @This()) u16 { + return this._value; + } + + pub fn fromInt(flags: u16) @This() { + return @This(){ + ._value = flags, + }; + } +}; + +const std = @import("std"); diff --git a/src/sql/mysql/TLSStatus.zig b/src/sql/mysql/TLSStatus.zig new file mode 100644 index 0000000000..a711af013a --- /dev/null +++ b/src/sql/mysql/TLSStatus.zig @@ -0,0 +1,11 @@ +pub const TLSStatus = union(enum) { + none, + pending, + + /// Number of bytes sent of the 8-byte SSL request message. + /// Since we may send a partial message, we need to know how many bytes were sent. 
+ message_sent: u8, + + ssl_not_available, + ssl_ok, +}; diff --git a/src/sql/mysql/protocol/AnyMySQLError.zig b/src/sql/mysql/protocol/AnyMySQLError.zig new file mode 100644 index 0000000000..2bcea88279 --- /dev/null +++ b/src/sql/mysql/protocol/AnyMySQLError.zig @@ -0,0 +1,90 @@ +pub const Error = error{ + ConnectionClosed, + ConnectionTimedOut, + LifetimeTimeout, + IdleTimeout, + PasswordRequired, + MissingAuthData, + AuthenticationFailed, + FailedToEncryptPassword, + InvalidPublicKey, + UnsupportedAuthPlugin, + UnsupportedProtocolVersion, + + LocalInfileNotSupported, + JSError, + OutOfMemory, + Overflow, + + WrongNumberOfParametersProvided, + + UnsupportedColumnType, + + InvalidLocalInfileRequest, + InvalidAuthSwitchRequest, + InvalidQueryBinding, + InvalidResultRow, + InvalidBinaryValue, + InvalidEncodedInteger, + InvalidEncodedLength, + + InvalidPrepareOKPacket, + InvalidOKPacket, + InvalidErrorPacket, + UnexpectedPacket, + ShortRead, +}; + +pub fn mysqlErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8, err: Error) JSValue { + const msg = message orelse @errorName(err); + const code = switch (err) { + error.ConnectionClosed => "ERR_MYSQL_CONNECTION_CLOSED", + error.Overflow => "ERR_MYSQL_OVERFLOW", + error.AuthenticationFailed => "ERR_MYSQL_AUTHENTICATION_FAILED", + error.UnsupportedAuthPlugin => "ERR_MYSQL_UNSUPPORTED_AUTH_PLUGIN", + error.UnsupportedProtocolVersion => "ERR_MYSQL_UNSUPPORTED_PROTOCOL_VERSION", + error.LocalInfileNotSupported => "ERR_MYSQL_LOCAL_INFILE_NOT_SUPPORTED", + error.WrongNumberOfParametersProvided => "ERR_MYSQL_WRONG_NUMBER_OF_PARAMETERS_PROVIDED", + error.UnsupportedColumnType => "ERR_MYSQL_UNSUPPORTED_COLUMN_TYPE", + error.InvalidLocalInfileRequest => "ERR_MYSQL_INVALID_LOCAL_INFILE_REQUEST", + error.InvalidAuthSwitchRequest => "ERR_MYSQL_INVALID_AUTH_SWITCH_REQUEST", + error.InvalidQueryBinding => "ERR_MYSQL_INVALID_QUERY_BINDING", + error.InvalidResultRow => "ERR_MYSQL_INVALID_RESULT_ROW", + 
error.InvalidBinaryValue => "ERR_MYSQL_INVALID_BINARY_VALUE", + error.InvalidEncodedInteger => "ERR_MYSQL_INVALID_ENCODED_INTEGER", + error.InvalidEncodedLength => "ERR_MYSQL_INVALID_ENCODED_LENGTH", + error.InvalidPrepareOKPacket => "ERR_MYSQL_INVALID_PREPARE_OK_PACKET", + error.InvalidOKPacket => "ERR_MYSQL_INVALID_OK_PACKET", + error.InvalidErrorPacket => "ERR_MYSQL_INVALID_ERROR_PACKET", + error.UnexpectedPacket => "ERR_MYSQL_UNEXPECTED_PACKET", + error.ConnectionTimedOut => "ERR_MYSQL_CONNECTION_TIMEOUT", + error.IdleTimeout => "ERR_MYSQL_IDLE_TIMEOUT", + error.LifetimeTimeout => "ERR_MYSQL_LIFETIME_TIMEOUT", + error.PasswordRequired => "ERR_MYSQL_PASSWORD_REQUIRED", + error.MissingAuthData => "ERR_MYSQL_MISSING_AUTH_DATA", + error.FailedToEncryptPassword => "ERR_MYSQL_FAILED_TO_ENCRYPT_PASSWORD", + error.InvalidPublicKey => "ERR_MYSQL_INVALID_PUBLIC_KEY", + error.JSError => { + return globalObject.takeException(error.JSError); + }, + error.OutOfMemory => { + // TODO: add binding for creating an out of memory error? 
+ return globalObject.takeException(globalObject.throwOutOfMemory()); + }, + error.ShortRead => { + bun.unreachablePanic("Assertion failed: ShortRead should be handled by the caller in MySQL", .{}); + }, + }; + + return createMySQLError(globalObject, msg, .{ + .code = code, + .errno = null, + .sqlState = null, + }) catch |ex| globalObject.takeException(ex); +} + +const bun = @import("bun"); +const createMySQLError = @import("./ErrorPacket.zig").createMySQLError; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; diff --git a/src/sql/mysql/protocol/Auth.zig b/src/sql/mysql/protocol/Auth.zig new file mode 100644 index 0000000000..1d42311f7c --- /dev/null +++ b/src/sql/mysql/protocol/Auth.zig @@ -0,0 +1,208 @@ +// Authentication methods +const Auth = @This(); + +pub const mysql_native_password = struct { + pub fn scramble(password: []const u8, nonce: []const u8) ![20]u8 { + // SHA1( password ) XOR SHA1( nonce + SHA1( SHA1( password ) ) ) + var stage1 = [_]u8{0} ** 20; + var stage2 = [_]u8{0} ** 20; + var stage3 = [_]u8{0} ** 20; + var result: [20]u8 = [_]u8{0} ** 20; + + // Stage 1: SHA1(password) + bun.sha.SHA1.hash(password, &stage1, jsc.VirtualMachine.get().rareData().boringEngine()); + + // Stage 2: SHA1(SHA1(password)) + bun.sha.SHA1.hash(&stage1, &stage2, jsc.VirtualMachine.get().rareData().boringEngine()); + + // Stage 3: SHA1(nonce + SHA1(SHA1(password))) + const combined = try bun.default_allocator.alloc(u8, nonce.len + stage2.len); + defer bun.default_allocator.free(combined); + @memcpy(combined[0..nonce.len], nonce); + @memcpy(combined[nonce.len..], &stage2); + bun.sha.SHA1.hash(combined, &stage3, jsc.VirtualMachine.get().rareData().boringEngine()); + + // Final: stage1 XOR stage3 + for (&result, &stage1, &stage3) |*out, d1, d3| { + out.* = d1 ^ d3; + } + + return result; + } +}; + +pub const caching_sha2_password = struct { + pub fn scramble(password: []const u8, nonce: []const u8) ![32]u8 { + // XOR(SHA256(password), 
SHA256(SHA256(SHA256(password)), nonce)) + var digest1 = [_]u8{0} ** 32; + var digest2 = [_]u8{0} ** 32; + var digest3 = [_]u8{0} ** 32; + var result: [32]u8 = [_]u8{0} ** 32; + + // SHA256(password) + bun.sha.SHA256.hash(password, &digest1, jsc.VirtualMachine.get().rareData().boringEngine()); + + // SHA256(SHA256(password)) + bun.sha.SHA256.hash(&digest1, &digest2, jsc.VirtualMachine.get().rareData().boringEngine()); + + // SHA256(SHA256(SHA256(password)) + nonce) + const combined = try bun.default_allocator.alloc(u8, nonce.len + digest2.len); + defer bun.default_allocator.free(combined); + @memcpy(combined[0..nonce.len], nonce); + @memcpy(combined[nonce.len..], &digest2); + bun.sha.SHA256.hash(combined, &digest3, jsc.VirtualMachine.get().rareData().boringEngine()); + + // XOR(SHA256(password), digest3) + for (&result, &digest1, &digest3) |*out, d1, d3| { + out.* = d1 ^ d3; + } + + return result; + } + + pub const FastAuthStatus = enum(u8) { + success = 0x03, + continue_auth = 0x04, + _, + }; + + pub const Response = struct { + status: FastAuthStatus = .success, + data: Data = .{ .empty = {} }, + + pub fn deinit(this: *Response) void { + this.data.deinit(); + } + + pub fn decodeInternal(this: *Response, comptime Context: type, reader: NewReader(Context)) !void { + const status = try reader.int(u8); + debug("FastAuthStatus: {d}", .{status}); + this.status = @enumFromInt(status); + + // Read remaining data if any + const remaining = reader.peek(); + if (remaining.len > 0) { + this.data = try reader.read(remaining.len); + } + } + + pub const decode = decoderWrap(Response, decodeInternal).decode; + }; + pub const EncryptedPassword = struct { + password: []const u8, + public_key: []const u8, + nonce: []const u8, + sequence_id: u8, + + // https://mariadb.com/kb/en/sha256_password-plugin/#rsa-encrypted-password + // RSA encrypted value of XOR(password, seed) using server public key (RSA_PKCS1_OAEP_PADDING). 
+ + pub fn writeInternal(this: *const EncryptedPassword, comptime Context: type, writer: NewWriter(Context)) !void { + // 1024 is overkill but lets cover all cases + var password_buf: [1024]u8 = undefined; + var needs_to_free_password = false; + var plain_password = brk: { + const needed_len = this.password.len + 1; + if (needed_len > password_buf.len) { + needs_to_free_password = true; + break :brk try bun.default_allocator.alloc(u8, needed_len); + } else { + break :brk password_buf[0..needed_len]; + } + }; + @memcpy(plain_password[0..this.password.len], this.password); + plain_password[this.password.len] = 0; + defer if (needs_to_free_password) bun.default_allocator.free(plain_password); + + for (plain_password, 0..) |*c, i| { + c.* ^= this.nonce[i % this.nonce.len]; + } + BoringSSL.load(); + BoringSSL.c.ERR_clear_error(); + // Decode public key + const bio = BoringSSL.c.BIO_new_mem_buf(&this.public_key[0], @intCast(this.public_key.len)) orelse return error.InvalidPublicKey; + defer _ = BoringSSL.c.BIO_free(bio); + + const rsa = BoringSSL.c.PEM_read_bio_RSA_PUBKEY(bio, null, null, null) orelse return { + if (bun.Environment.isDebug) { + BoringSSL.c.ERR_load_ERR_strings(); + BoringSSL.c.ERR_load_crypto_strings(); + var buf: [256]u8 = undefined; + debug("Failed to read public key: {s}", .{BoringSSL.c.ERR_error_string(BoringSSL.c.ERR_get_error(), &buf)}); + } + return error.InvalidPublicKey; + }; + defer BoringSSL.c.RSA_free(rsa); + // encrypt password + + const rsa_size = BoringSSL.c.RSA_size(rsa); + var needs_to_free_encrypted_password = false; + // should never ne bigger than 4096 but lets cover all cases + var encrypted_password_buf: [4096]u8 = undefined; + var encrypted_password = brk: { + if (rsa_size > encrypted_password_buf.len) { + needs_to_free_encrypted_password = true; + break :brk try bun.default_allocator.alloc(u8, rsa_size); + } else { + break :brk encrypted_password_buf[0..rsa_size]; + } + }; + defer if (needs_to_free_encrypted_password) 
bun.default_allocator.free(encrypted_password); + + const encrypted_password_len = BoringSSL.c.RSA_public_encrypt( + @intCast(plain_password.len), + plain_password.ptr, + encrypted_password.ptr, + rsa, + BoringSSL.c.RSA_PKCS1_OAEP_PADDING, + ); + if (encrypted_password_len == -1) { + return error.FailedToEncryptPassword; + } + const encrypted_password_slice = encrypted_password[0..@intCast(encrypted_password_len)]; + + var packet = try writer.start(this.sequence_id); + try writer.write(encrypted_password_slice); + try packet.end(); + } + + pub const write = writeWrap(EncryptedPassword, writeInternal).write; + }; + pub const PublicKeyResponse = struct { + data: Data = .{ .empty = {} }, + + pub fn deinit(this: *PublicKeyResponse) void { + this.data.deinit(); + } + pub fn decodeInternal(this: *PublicKeyResponse, comptime Context: type, reader: NewReader(Context)) !void { + // get all the data + const remaining = reader.peek(); + if (remaining.len > 0) { + this.data = try reader.read(remaining.len); + } + } + pub const decode = decoderWrap(PublicKeyResponse, decodeInternal).decode; + }; + + pub const PublicKeyRequest = struct { + pub fn writeInternal(this: *const PublicKeyRequest, comptime Context: type, writer: NewWriter(Context)) !void { + _ = this; + try writer.int1(0x02); // Request public key + } + + pub const write = writeWrap(PublicKeyRequest, writeInternal).write; + }; +}; +const debug = bun.Output.scoped(.Auth, .hidden); + +const Data = @import("../../shared/Data.zig").Data; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; + +const NewWriter = @import("./NewWriter.zig").NewWriter; +const writeWrap = @import("./NewWriter.zig").writeWrap; + +const bun = @import("bun"); +const BoringSSL = bun.BoringSSL; +const jsc = bun.jsc; diff --git a/src/sql/mysql/protocol/AuthSwitchRequest.zig b/src/sql/mysql/protocol/AuthSwitchRequest.zig new file mode 100644 index 0000000000..bb5b07ad15 --- /dev/null 
+++ b/src/sql/mysql/protocol/AuthSwitchRequest.zig @@ -0,0 +1,42 @@ +const AuthSwitchRequest = @This(); +header: u8 = 0xfe, +plugin_name: Data = .{ .empty = {} }, +plugin_data: Data = .{ .empty = {} }, +packet_size: u24, + +pub fn deinit(this: *AuthSwitchRequest) void { + this.plugin_name.deinit(); + this.plugin_data.deinit(); +} + +pub fn decodeInternal(this: *AuthSwitchRequest, comptime Context: type, reader: NewReader(Context)) !void { + this.header = try reader.int(u8); + if (this.header != 0xfe) { + return error.InvalidAuthSwitchRequest; + } + + const remaining = try reader.read(this.packet_size - 1); + const remaining_slice = remaining.slice(); + bun.assert(remaining == .temporary); + + if (bun.strings.indexOfChar(remaining_slice, 0)) |zero| { + // EOF String + this.plugin_name = .{ + .temporary = remaining_slice[0..zero], + }; + // End Of The Packet String + this.plugin_data = .{ + .temporary = remaining_slice[zero + 1 ..], + }; + return; + } + return error.InvalidAuthSwitchRequest; +} + +pub const decode = decoderWrap(AuthSwitchRequest, decodeInternal).decode; + +const bun = @import("bun"); +const Data = @import("../../shared/Data.zig").Data; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git a/src/sql/mysql/protocol/AuthSwitchResponse.zig b/src/sql/mysql/protocol/AuthSwitchResponse.zig new file mode 100644 index 0000000000..751d0c21e4 --- /dev/null +++ b/src/sql/mysql/protocol/AuthSwitchResponse.zig @@ -0,0 +1,18 @@ +// Auth switch response packet +const AuthSwitchResponse = @This(); +auth_response: Data = .{ .empty = {} }, + +pub fn deinit(this: *AuthSwitchResponse) void { + this.auth_response.deinit(); +} + +pub fn writeInternal(this: *const AuthSwitchResponse, comptime Context: type, writer: NewWriter(Context)) !void { + try writer.write(this.auth_response.slice()); +} + +pub const write = writeWrap(AuthSwitchResponse, writeInternal).write; + +const Data = 
@import("../../shared/Data.zig").Data; + +const NewWriter = @import("./NewWriter.zig").NewWriter; +const writeWrap = @import("./NewWriter.zig").writeWrap; diff --git a/src/sql/mysql/protocol/CharacterSet.zig b/src/sql/mysql/protocol/CharacterSet.zig new file mode 100644 index 0000000000..3e9a8c3bca --- /dev/null +++ b/src/sql/mysql/protocol/CharacterSet.zig @@ -0,0 +1,236 @@ +pub const CharacterSet = enum(u8) { + big5_chinese_ci = 1, + latin2_czech_cs = 2, + dec8_swedish_ci = 3, + cp850_general_ci = 4, + latin1_german1_ci = 5, + hp8_english_ci = 6, + koi8r_general_ci = 7, + latin1_swedish_ci = 8, + latin2_general_ci = 9, + swe7_swedish_ci = 10, + ascii_general_ci = 11, + ujis_japanese_ci = 12, + sjis_japanese_ci = 13, + cp1251_bulgarian_ci = 14, + latin1_danish_ci = 15, + hebrew_general_ci = 16, + tis620_thai_ci = 18, + euckr_korean_ci = 19, + latin7_estonian_cs = 20, + latin2_hungarian_ci = 21, + koi8u_general_ci = 22, + cp1251_ukrainian_ci = 23, + gb2312_chinese_ci = 24, + greek_general_ci = 25, + cp1250_general_ci = 26, + latin2_croatian_ci = 27, + gbk_chinese_ci = 28, + cp1257_lithuanian_ci = 29, + latin5_turkish_ci = 30, + latin1_german2_ci = 31, + armscii8_general_ci = 32, + utf8mb3_general_ci = 33, + cp1250_czech_cs = 34, + ucs2_general_ci = 35, + cp866_general_ci = 36, + keybcs2_general_ci = 37, + macce_general_ci = 38, + macroman_general_ci = 39, + cp852_general_ci = 40, + latin7_general_ci = 41, + latin7_general_cs = 42, + macce_bin = 43, + cp1250_croatian_ci = 44, + utf8mb4_general_ci = 45, + utf8mb4_bin = 46, + latin1_bin = 47, + latin1_general_ci = 48, + latin1_general_cs = 49, + cp1251_bin = 50, + cp1251_general_ci = 51, + cp1251_general_cs = 52, + macroman_bin = 53, + utf16_general_ci = 54, + utf16_bin = 55, + utf16le_general_ci = 56, + cp1256_general_ci = 57, + cp1257_bin = 58, + cp1257_general_ci = 59, + utf32_general_ci = 60, + utf32_bin = 61, + utf16le_bin = 62, + binary = 63, + armscii8_bin = 64, + ascii_bin = 65, + cp1250_bin = 66, + cp1256_bin 
= 67, + cp866_bin = 68, + dec8_bin = 69, + greek_bin = 70, + hebrew_bin = 71, + hp8_bin = 72, + keybcs2_bin = 73, + koi8r_bin = 74, + koi8u_bin = 75, + utf8mb3_tolower_ci = 76, + latin2_bin = 77, + latin5_bin = 78, + latin7_bin = 79, + cp850_bin = 80, + cp852_bin = 81, + swe7_bin = 82, + utf8mb3_bin = 83, + big5_bin = 84, + euckr_bin = 85, + gb2312_bin = 86, + gbk_bin = 87, + sjis_bin = 88, + tis620_bin = 89, + ucs2_bin = 90, + ujis_bin = 91, + geostd8_general_ci = 92, + geostd8_bin = 93, + latin1_spanish_ci = 94, + cp932_japanese_ci = 95, + cp932_bin = 96, + eucjpms_japanese_ci = 97, + eucjpms_bin = 98, + cp1250_polish_ci = 99, + utf16_unicode_ci = 101, + utf16_icelandic_ci = 102, + utf16_latvian_ci = 103, + utf16_romanian_ci = 104, + utf16_slovenian_ci = 105, + utf16_polish_ci = 106, + utf16_estonian_ci = 107, + utf16_spanish_ci = 108, + utf16_swedish_ci = 109, + utf16_turkish_ci = 110, + utf16_czech_ci = 111, + utf16_danish_ci = 112, + utf16_lithuanian_ci = 113, + utf16_slovak_ci = 114, + utf16_spanish2_ci = 115, + utf16_roman_ci = 116, + utf16_persian_ci = 117, + utf16_esperanto_ci = 118, + utf16_hungarian_ci = 119, + utf16_sinhala_ci = 120, + utf16_german2_ci = 121, + utf16_croatian_ci = 122, + utf16_unicode_520_ci = 123, + utf16_vietnamese_ci = 124, + ucs2_unicode_ci = 128, + ucs2_icelandic_ci = 129, + ucs2_latvian_ci = 130, + ucs2_romanian_ci = 131, + ucs2_slovenian_ci = 132, + ucs2_polish_ci = 133, + ucs2_estonian_ci = 134, + ucs2_spanish_ci = 135, + ucs2_swedish_ci = 136, + ucs2_turkish_ci = 137, + ucs2_czech_ci = 138, + ucs2_danish_ci = 139, + ucs2_lithuanian_ci = 140, + ucs2_slovak_ci = 141, + ucs2_spanish2_ci = 142, + ucs2_roman_ci = 143, + ucs2_persian_ci = 144, + ucs2_esperanto_ci = 145, + ucs2_hungarian_ci = 146, + ucs2_sinhala_ci = 147, + ucs2_german2_ci = 148, + ucs2_croatian_ci = 149, + ucs2_unicode_520_ci = 150, + ucs2_vietnamese_ci = 151, + ucs2_general_mysql500_ci = 159, + utf32_unicode_ci = 160, + utf32_icelandic_ci = 161, + utf32_latvian_ci = 
162, + utf32_romanian_ci = 163, + utf32_slovenian_ci = 164, + utf32_polish_ci = 165, + utf32_estonian_ci = 166, + utf32_spanish_ci = 167, + utf32_swedish_ci = 168, + utf32_turkish_ci = 169, + utf32_czech_ci = 170, + utf32_danish_ci = 171, + utf32_lithuanian_ci = 172, + utf32_slovak_ci = 173, + utf32_spanish2_ci = 174, + utf32_roman_ci = 175, + utf32_persian_ci = 176, + utf32_esperanto_ci = 177, + utf32_hungarian_ci = 178, + utf32_sinhala_ci = 179, + utf32_german2_ci = 180, + utf32_croatian_ci = 181, + utf32_unicode_520_ci = 182, + utf32_vietnamese_ci = 183, + utf8mb3_unicode_ci = 192, + utf8mb3_icelandic_ci = 193, + utf8mb3_latvian_ci = 194, + utf8mb3_romanian_ci = 195, + utf8mb3_slovenian_ci = 196, + utf8mb3_polish_ci = 197, + utf8mb3_estonian_ci = 198, + utf8mb3_spanish_ci = 199, + utf8mb3_swedish_ci = 200, + utf8mb3_turkish_ci = 201, + utf8mb3_czech_ci = 202, + utf8mb3_danish_ci = 203, + utf8mb3_lithuanian_ci = 204, + utf8mb3_slovak_ci = 205, + utf8mb3_spanish2_ci = 206, + utf8mb3_roman_ci = 207, + utf8mb3_persian_ci = 208, + utf8mb3_esperanto_ci = 209, + utf8mb3_hungarian_ci = 210, + utf8mb3_sinhala_ci = 211, + utf8mb3_german2_ci = 212, + utf8mb3_croatian_ci = 213, + utf8mb3_unicode_520_ci = 214, + utf8mb3_vietnamese_ci = 215, + utf8mb3_general_mysql500_ci = 223, + utf8mb4_unicode_ci = 224, + utf8mb4_icelandic_ci = 225, + utf8mb4_latvian_ci = 226, + utf8mb4_romanian_ci = 227, + utf8mb4_slovenian_ci = 228, + utf8mb4_polish_ci = 229, + utf8mb4_estonian_ci = 230, + utf8mb4_spanish_ci = 231, + utf8mb4_swedish_ci = 232, + utf8mb4_turkish_ci = 233, + utf8mb4_czech_ci = 234, + utf8mb4_danish_ci = 235, + utf8mb4_lithuanian_ci = 236, + utf8mb4_slovak_ci = 237, + utf8mb4_spanish2_ci = 238, + utf8mb4_roman_ci = 239, + utf8mb4_persian_ci = 240, + utf8mb4_esperanto_ci = 241, + utf8mb4_hungarian_ci = 242, + utf8mb4_sinhala_ci = 243, + utf8mb4_german2_ci = 244, + utf8mb4_croatian_ci = 245, + utf8mb4_unicode_520_ci = 246, + utf8mb4_vietnamese_ci = 247, + gb18030_chinese_ci = 
248,
+    gb18030_bin = 249,
+    gb18030_unicode_520_ci = 250,
+    _,
+
+    pub const default = CharacterSet.utf8mb4_general_ci;
+
+    pub fn label(this: CharacterSet) []const u8 {
+        // `_` prong catches unnamed collation ids (e.g. 17, 100); @tagName on them is illegal.
+        return switch (this) {
+            _ => "(unknown)",
+            else => @tagName(this),
+        };
+    }
+};
diff --git a/src/sql/mysql/protocol/ColumnDefinition41.zig b/src/sql/mysql/protocol/ColumnDefinition41.zig
new file mode 100644
index 0000000000..6dae10d7d9
--- /dev/null
+++ b/src/sql/mysql/protocol/ColumnDefinition41.zig
@@ -0,0 +1,97 @@
+const ColumnDefinition41 = @This();
+catalog: Data = .{ .empty = {} },
+schema: Data = .{ .empty = {} },
+table: Data = .{ .empty = {} },
+org_table: Data = .{ .empty = {} },
+name: Data = .{ .empty = {} },
+org_name: Data = .{ .empty = {} },
+fixed_length_fields_length: u64 = 0,
+character_set: u16 = 0,
+column_length: u32 = 0,
+column_type: types.FieldType = .MYSQL_TYPE_NULL,
+flags: ColumnFlags = .{},
+decimals: u8 = 0,
+name_or_index: ColumnIdentifier = .{
+    .name = .{ .empty = {} },
+},
+
+pub const ColumnFlags = packed struct {
+    NOT_NULL: bool = false,
+    PRI_KEY: bool = false,
+    UNIQUE_KEY: bool = false,
+    MULTIPLE_KEY: bool = false,
+    BLOB: bool = false,
+    UNSIGNED: bool = false,
+    ZEROFILL: bool = false,
+    BINARY: bool = false,
+    ENUM: bool = false,
+    AUTO_INCREMENT: bool = false,
+    TIMESTAMP: bool = false,
+    SET: bool = false,
+    NO_DEFAULT_VALUE: bool = false,
+    ON_UPDATE_NOW: bool = false,
+    _padding: u2 = 0,
+
+    pub fn toInt(this: ColumnFlags) u16 {
+        return @bitCast(this);
+    }
+
+    pub fn fromInt(flags: u16) ColumnFlags {
+        return @bitCast(flags);
+    }
+};
+
+pub fn deinit(this: *ColumnDefinition41) void {
+    this.catalog.deinit();
+    this.schema.deinit();
+    this.table.deinit();
+    this.org_table.deinit();
+    this.name.deinit();
+    this.org_name.deinit();
+}
+
+pub fn decodeInternal(this: *ColumnDefinition41, comptime Context: type, reader: NewReader(Context)) !void {
+    // Length encoded strings
+    this.catalog = try
reader.encodeLenString(); + debug("catalog: {s}", .{this.catalog.slice()}); + + this.schema = try reader.encodeLenString(); + debug("schema: {s}", .{this.schema.slice()}); + + this.table = try reader.encodeLenString(); + debug("table: {s}", .{this.table.slice()}); + + this.org_table = try reader.encodeLenString(); + debug("org_table: {s}", .{this.org_table.slice()}); + + this.name = try reader.encodeLenString(); + debug("name: {s}", .{this.name.slice()}); + + this.org_name = try reader.encodeLenString(); + debug("org_name: {s}", .{this.org_name.slice()}); + + this.fixed_length_fields_length = try reader.encodedLenInt(); + this.character_set = try reader.int(u16); + this.column_length = try reader.int(u32); + this.column_type = @enumFromInt(try reader.int(u8)); + this.flags = ColumnFlags.fromInt(try reader.int(u16)); + this.decimals = try reader.int(u8); + + this.name_or_index = try ColumnIdentifier.init(this.name); + + // https://mariadb.com/kb/en/result-set-packets/#column-definition-packet + // According to mariadb, there seem to be extra 2 bytes at the end that is not being used + reader.skip(2); +} + +pub const decode = decoderWrap(ColumnDefinition41, decodeInternal).decode; + +const debug = bun.Output.scoped(.ColumnDefinition41, .hidden); + +const bun = @import("bun"); +const types = @import("../MySQLTypes.zig"); +const ColumnIdentifier = @import("../../shared/ColumnIdentifier.zig").ColumnIdentifier; +const Data = @import("../../shared/Data.zig").Data; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git a/src/sql/mysql/protocol/CommandType.zig b/src/sql/mysql/protocol/CommandType.zig new file mode 100644 index 0000000000..8dc861487d --- /dev/null +++ b/src/sql/mysql/protocol/CommandType.zig @@ -0,0 +1,34 @@ +// Command packet types +pub const CommandType = enum(u8) { + COM_QUIT = 0x01, + COM_INIT_DB = 0x02, + COM_QUERY = 0x03, + COM_FIELD_LIST = 0x04, + COM_CREATE_DB = 0x05, + 
COM_DROP_DB = 0x06, + COM_REFRESH = 0x07, + COM_SHUTDOWN = 0x08, + COM_STATISTICS = 0x09, + COM_PROCESS_INFO = 0x0a, + COM_CONNECT = 0x0b, + COM_PROCESS_KILL = 0x0c, + COM_DEBUG = 0x0d, + COM_PING = 0x0e, + COM_TIME = 0x0f, + COM_DELAYED_INSERT = 0x10, + COM_CHANGE_USER = 0x11, + COM_BINLOG_DUMP = 0x12, + COM_TABLE_DUMP = 0x13, + COM_CONNECT_OUT = 0x14, + COM_REGISTER_SLAVE = 0x15, + COM_STMT_PREPARE = 0x16, + COM_STMT_EXECUTE = 0x17, + COM_STMT_SEND_LONG_DATA = 0x18, + COM_STMT_CLOSE = 0x19, + COM_STMT_RESET = 0x1a, + COM_SET_OPTION = 0x1b, + COM_STMT_FETCH = 0x1c, + COM_DAEMON = 0x1d, + COM_BINLOG_DUMP_GTID = 0x1e, + COM_RESET_CONNECTION = 0x1f, +}; diff --git a/src/sql/mysql/protocol/DecodeBinaryValue.zig b/src/sql/mysql/protocol/DecodeBinaryValue.zig new file mode 100644 index 0000000000..2fd083873f --- /dev/null +++ b/src/sql/mysql/protocol/DecodeBinaryValue.zig @@ -0,0 +1,153 @@ +pub fn decodeBinaryValue(globalObject: *jsc.JSGlobalObject, field_type: types.FieldType, raw: bool, bigint: bool, unsigned: bool, comptime Context: type, reader: NewReader(Context)) !SQLDataCell { + debug("decodeBinaryValue: {s}", .{@tagName(field_type)}); + return switch (field_type) { + .MYSQL_TYPE_TINY => { + if (raw) { + var data = try reader.read(1); + defer data.deinit(); + return SQLDataCell.raw(&data); + } + const val = try reader.byte(); + return SQLDataCell{ .tag = .bool, .value = .{ .bool = val } }; + }, + .MYSQL_TYPE_SHORT => { + if (raw) { + var data = try reader.read(2); + defer data.deinit(); + return SQLDataCell.raw(&data); + } + if (unsigned) { + return SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = try reader.int(u16) } }; + } + return SQLDataCell{ .tag = .int4, .value = .{ .int4 = try reader.int(i16) } }; + }, + .MYSQL_TYPE_LONG => { + if (raw) { + var data = try reader.read(4); + defer data.deinit(); + return SQLDataCell.raw(&data); + } + if (unsigned) { + return SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = try reader.int(u32) } }; + } + return SQLDataCell{ 
.tag = .int4, .value = .{ .int4 = try reader.int(i32) } };
+        },
+        .MYSQL_TYPE_LONGLONG => {
+            if (raw) {
+                var data = try reader.read(8); defer data.deinit(); return SQLDataCell.raw(&data); // match other raw branches: release the Data after copying
+            }
+            if (unsigned) {
+                const val = try reader.int(u64);
+                if (val <= std.math.maxInt(u32)) {
+                    return SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = @intCast(val) } };
+                }
+                if (bigint) {
+                    return SQLDataCell{ .tag = .uint8, .value = .{ .uint8 = val } };
+                }
+                var buffer: [22]u8 = undefined;
+                const slice = std.fmt.bufPrint(&buffer, "{d}", .{val}) catch unreachable;
+                return SQLDataCell{ .tag = .string, .value = .{ .string = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 };
+            }
+            const val = try reader.int(i64);
+            if (val >= std.math.minInt(i32) and val <= std.math.maxInt(i32)) {
+                return SQLDataCell{ .tag = .int4, .value = .{ .int4 = @intCast(val) } };
+            }
+            if (bigint) {
+                return SQLDataCell{ .tag = .int8, .value = .{ .int8 = val } };
+            }
+            var buffer: [22]u8 = undefined;
+            const slice = std.fmt.bufPrint(&buffer, "{d}", .{val}) catch unreachable;
+            return SQLDataCell{ .tag = .string, .value = .{ .string = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 };
+        },
+        .MYSQL_TYPE_FLOAT => {
+            if (raw) {
+                var data = try reader.read(4);
+                defer data.deinit();
+                return SQLDataCell.raw(&data);
+            }
+            return SQLDataCell{ .tag = .float8, .value = .{ .float8 = @as(f32, @bitCast(try reader.int(u32))) } };
+        },
+        .MYSQL_TYPE_DOUBLE => {
+            if (raw) {
+                var data = try reader.read(8);
+                defer data.deinit();
+                return SQLDataCell.raw(&data);
+            }
+            return SQLDataCell{ .tag = .float8, .value = .{ .float8 = @bitCast(try reader.int(u64)) } };
+        },
+        .MYSQL_TYPE_TIME => {
+            return switch (try reader.byte()) {
+                0 => SQLDataCell{ .tag = .null, .value = .{ .null = 0 } },
+                8, 12 => |l| {
+                    var data = try reader.read(l);
+                    defer data.deinit();
+                    const time = try Time.fromData(&data);
+                    return SQLDataCell{ .tag = .date, .value = .{ .date =
time.toJSTimestamp() } }; + }, + else => return error.InvalidBinaryValue, + }; + }, + .MYSQL_TYPE_DATE, .MYSQL_TYPE_TIMESTAMP, .MYSQL_TYPE_DATETIME => switch (try reader.byte()) { + 0 => SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }, + 11, 7, 4 => |l| { + var data = try reader.read(l); + defer data.deinit(); + const time = try DateTime.fromData(&data); + return SQLDataCell{ .tag = .date, .value = .{ .date = try time.toJSTimestamp(globalObject) } }; + }, + else => error.InvalidBinaryValue, + }, + + .MYSQL_TYPE_ENUM, + .MYSQL_TYPE_SET, + .MYSQL_TYPE_GEOMETRY, + .MYSQL_TYPE_NEWDECIMAL, + .MYSQL_TYPE_STRING, + .MYSQL_TYPE_VARCHAR, + .MYSQL_TYPE_VAR_STRING, + // We could return Buffer here BUT TEXT, LONGTEXT, MEDIUMTEXT, TINYTEXT, etc. are BLOB and the user expects a string + .MYSQL_TYPE_TINY_BLOB, + .MYSQL_TYPE_MEDIUM_BLOB, + .MYSQL_TYPE_LONG_BLOB, + .MYSQL_TYPE_BLOB, + => { + if (raw) { + var data = try reader.rawEncodeLenData(); + defer data.deinit(); + return SQLDataCell.raw(&data); + } + var string_data = try reader.encodeLenString(); + defer string_data.deinit(); + + const slice = string_data.slice(); + return SQLDataCell{ .tag = .string, .value = .{ .string = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 }; + }, + + .MYSQL_TYPE_JSON => { + if (raw) { + var data = try reader.rawEncodeLenData(); + defer data.deinit(); + return SQLDataCell.raw(&data); + } + var string_data = try reader.encodeLenString(); + defer string_data.deinit(); + const slice = string_data.slice(); + return SQLDataCell{ .tag = .json, .value = .{ .json = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 }; + }, + else => return error.UnsupportedColumnType, + }; +} + +const debug = bun.Output.scoped(.MySQLDecodeBinaryValue, .visible); + +const std = @import("std"); +const types = @import("../MySQLTypes.zig"); +const NewReader = @import("./NewReader.zig").NewReader; +const SQLDataCell = 
@import("../../shared/SQLDataCell.zig").SQLDataCell;
+
+const Value = @import("../MySQLTypes.zig").Value;
+const DateTime = Value.DateTime;
+const Time = Value.Time;
+
+const bun = @import("bun");
+const jsc = bun.jsc;
diff --git a/src/sql/mysql/protocol/EOFPacket.zig b/src/sql/mysql/protocol/EOFPacket.zig
new file mode 100644
index 0000000000..02da929d83
--- /dev/null
+++ b/src/sql/mysql/protocol/EOFPacket.zig
@@ -0,0 +1,21 @@
+const EOFPacket = @This();
+header: u8 = 0xfe,
+warnings: u16 = 0,
+status_flags: StatusFlags = .{},
+
+pub fn decodeInternal(this: *EOFPacket, comptime Context: type, reader: NewReader(Context)) !void {
+    this.header = try reader.int(u8);
+    if (this.header != 0xfe) {
+        return error.InvalidEOFPacket;
+    }
+
+    this.warnings = try reader.int(u16);
+    this.status_flags = StatusFlags.fromInt(try reader.int(u16));
+}
+
+pub const decode = decoderWrap(EOFPacket, decodeInternal).decode;
+
+const StatusFlags = @import("../StatusFlags.zig").StatusFlags;
+
+const NewReader = @import("./NewReader.zig").NewReader;
+const decoderWrap = @import("./NewReader.zig").decoderWrap;
diff --git a/src/sql/mysql/protocol/EncodeInt.zig b/src/sql/mysql/protocol/EncodeInt.zig
new file mode 100644
index 0000000000..b42c7d795d
--- /dev/null
+++ b/src/sql/mysql/protocol/EncodeInt.zig
@@ -0,0 +1,73 @@
+// Length-encoded integer encoding/decoding
+pub fn encodeLengthInt(value: u64) std.BoundedArray(u8, 9) {
+    var array: std.BoundedArray(u8, 9) = .{};
+    if (value < 0xfb) {
+        array.len = 1;
+        array.buffer[0] = @intCast(value);
+    } else if (value <= 0xffff) { // 251..65535 -> 0xfc + 2 bytes (was `< 0xffff`: off-by-one pushed 0xffff to the 3-byte form)
+        array.len = 3;
+        array.buffer[0] = 0xfc;
+        array.buffer[1] = @intCast(value & 0xff);
+        array.buffer[2] = @intCast((value >> 8) & 0xff);
+    } else if (value <= 0xffffff) { // 65536..16777215 -> 0xfd + 3 bytes (was `< 0xffffff`)
+        array.len = 4;
+        array.buffer[0] = 0xfd;
+        array.buffer[1] = @intCast(value & 0xff);
+        array.buffer[2] = @intCast((value >> 8) & 0xff);
+        array.buffer[3] = @intCast((value >> 16) & 0xff);
+    } else {
+        array.len = 9;
+        array.buffer[0] = 0xfe;
+
array.buffer[1] = @intCast(value & 0xff);
+        array.buffer[2] = @intCast((value >> 8) & 0xff);
+        array.buffer[3] = @intCast((value >> 16) & 0xff);
+        array.buffer[4] = @intCast((value >> 24) & 0xff);
+        array.buffer[5] = @intCast((value >> 32) & 0xff);
+        array.buffer[6] = @intCast((value >> 40) & 0xff);
+        array.buffer[7] = @intCast((value >> 48) & 0xff);
+        array.buffer[8] = @intCast((value >> 56) & 0xff);
+    }
+    return array;
+}
+
+pub fn decodeLengthInt(bytes: []const u8) ?struct { value: u64, bytes_read: usize } {
+    if (bytes.len == 0) return null;
+
+    const first_byte = bytes[0];
+
+    switch (first_byte) {
+        0xfc => {
+            if (bytes.len < 3) return null;
+            return .{
+                .value = @as(u64, bytes[1]) | (@as(u64, bytes[2]) << 8),
+                .bytes_read = 3,
+            };
+        },
+        0xfd => {
+            if (bytes.len < 4) return null;
+            return .{
+                .value = @as(u64, bytes[1]) |
+                    (@as(u64, bytes[2]) << 8) |
+                    (@as(u64, bytes[3]) << 16),
+                .bytes_read = 4,
+            };
+        },
+        0xfe => {
+            if (bytes.len < 9) return null;
+            return .{
+                .value = @as(u64, bytes[1]) |
+                    (@as(u64, bytes[2]) << 8) |
+                    (@as(u64, bytes[3]) << 16) |
+                    (@as(u64, bytes[4]) << 24) |
+                    (@as(u64, bytes[5]) << 32) |
+                    (@as(u64, bytes[6]) << 40) |
+                    (@as(u64, bytes[7]) << 48) |
+                    (@as(u64, bytes[8]) << 56),
+                .bytes_read = 9,
+            };
+        },
+        else => return .{ .value = first_byte, .bytes_read = 1 }, // 1-byte int; `@byteSwap` on a u8 was a misleading no-op. NOTE(review): 0xfb (NULL) / 0xff (ERR) markers also land here — confirm callers screen them first.
+    }
+}
+
+const std = @import("std");
diff --git a/src/sql/mysql/protocol/ErrorPacket.zig b/src/sql/mysql/protocol/ErrorPacket.zig
new file mode 100644
index 0000000000..5e16c7c97f
--- /dev/null
+++ b/src/sql/mysql/protocol/ErrorPacket.zig
@@ -0,0 +1,82 @@
+const ErrorPacket = @This();
+header: u8 = 0xff,
+error_code: u16 = 0,
+sql_state_marker: ?u8 = null,
+sql_state: ?[5]u8 = null,
+error_message: Data = .{ .empty = {} },
+
+pub fn deinit(this: *ErrorPacket) void {
+    this.error_message.deinit();
+}
+pub const MySQLErrorOptions = struct {
+    code: []const u8,
+    errno: ?u16 = null,
+    sqlState: ?[5]u8 = null,
+};
+
+pub fn createMySQLError(
+
globalObject: *JSC.JSGlobalObject, + message: []const u8, + options: MySQLErrorOptions, +) bun.JSError!JSValue { + const opts_obj = JSValue.createEmptyObject(globalObject, 18); + opts_obj.ensureStillAlive(); + opts_obj.put(globalObject, JSC.ZigString.static("code"), try bun.String.createUTF8ForJS(globalObject, options.code)); + if (options.errno) |errno| { + opts_obj.put(globalObject, JSC.ZigString.static("errno"), JSC.JSValue.jsNumber(errno)); + } + if (options.sqlState) |state| { + opts_obj.put(globalObject, JSC.ZigString.static("sqlState"), try bun.String.createUTF8ForJS(globalObject, state[0..])); + } + opts_obj.put(globalObject, JSC.ZigString.static("message"), try bun.String.createUTF8ForJS(globalObject, message)); + + return opts_obj; +} + +pub fn decodeInternal(this: *ErrorPacket, comptime Context: type, reader: NewReader(Context)) !void { + this.header = try reader.int(u8); + if (this.header != 0xff) { + return error.InvalidErrorPacket; + } + + this.error_code = try reader.int(u16); + + // Check if we have a SQL state marker + const next_byte = try reader.int(u8); + if (next_byte == '#') { + this.sql_state_marker = '#'; + var sql_state_data = try reader.read(5); + defer sql_state_data.deinit(); + this.sql_state = sql_state_data.slice()[0..5].*; + } else { + // No SQL state, rewind one byte + reader.skip(-1); + } + + // Read the error message (rest of packet) + this.error_message = try reader.read(reader.peek().len); +} + +pub const decode = decoderWrap(ErrorPacket, decodeInternal).decode; + +pub fn toJS(this: ErrorPacket, globalObject: *JSC.JSGlobalObject) JSValue { + var msg = this.error_message.slice(); + if (msg.len == 0) { + msg = "MySQL error occurred"; + } + + return createMySQLError(globalObject, msg, .{ + .code = if (this.error_code == 1064) "ERR_MYSQL_SYNTAX_ERROR" else "ERR_MYSQL_SERVER_ERROR", + .errno = this.error_code, + .sqlState = this.sql_state, + }) catch |err| globalObject.takeException(err); +} + +const bun = @import("bun"); +const Data 
= @import("../../shared/Data.zig").Data; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; + +const JSC = bun.jsc; +const JSValue = JSC.JSValue; diff --git a/src/sql/mysql/protocol/HandshakeResponse41.zig b/src/sql/mysql/protocol/HandshakeResponse41.zig new file mode 100644 index 0000000000..5d56b3942e --- /dev/null +++ b/src/sql/mysql/protocol/HandshakeResponse41.zig @@ -0,0 +1,108 @@ +// Client authentication response +const HandshakeResponse41 = @This(); +capability_flags: Capabilities, +max_packet_size: u32 = 0xFFFFFF, // 16MB default +character_set: CharacterSet = CharacterSet.default, +username: Data, +auth_response: Data, +database: Data, +auth_plugin_name: Data, +connect_attrs: bun.StringHashMapUnmanaged([]const u8) = .{}, + +pub fn deinit(this: *HandshakeResponse41) void { + this.username.deinit(); + this.auth_response.deinit(); + this.database.deinit(); + this.auth_plugin_name.deinit(); + + var it = this.connect_attrs.iterator(); + while (it.next()) |entry| { + bun.default_allocator.free(entry.key_ptr.*); + bun.default_allocator.free(entry.value_ptr.*); + } + this.connect_attrs.deinit(bun.default_allocator); +} + +pub fn writeInternal(this: *HandshakeResponse41, comptime Context: type, writer: NewWriter(Context)) !void { + var packet = try writer.start(1); + + this.capability_flags.CLIENT_CONNECT_ATTRS = this.connect_attrs.count() > 0; + + // Write client capabilities flags (4 bytes) + const caps = this.capability_flags.toInt(); + try writer.int4(caps); + debug("Client capabilities: [{}] 0x{x:0>8}", .{ this.capability_flags, caps }); + + // Write max packet size (4 bytes) + try writer.int4(this.max_packet_size); + + // Write character set (1 byte) + try writer.int1(@intFromEnum(this.character_set)); + + // Write 23 bytes of padding + try writer.write(&[_]u8{0} ** 23); + + // Write username (null terminated) + try writer.writeZ(this.username.slice()); + + // Write auth response based on 
capabilities + const auth_data = this.auth_response.slice(); + if (this.capability_flags.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) { + try writer.writeLengthEncodedString(auth_data); + } else if (this.capability_flags.CLIENT_SECURE_CONNECTION) { + try writer.int1(@intCast(auth_data.len)); + try writer.write(auth_data); + } else { + try writer.writeZ(auth_data); + } + + // Write database name if requested + if (this.capability_flags.CLIENT_CONNECT_WITH_DB and this.database.slice().len > 0) { + try writer.writeZ(this.database.slice()); + } + + // Write auth plugin name if supported + if (this.capability_flags.CLIENT_PLUGIN_AUTH) { + try writer.writeZ(this.auth_plugin_name.slice()); + } + + // Write connect attributes if enabled + if (this.capability_flags.CLIENT_CONNECT_ATTRS) { + var total_length: usize = 0; + var it = this.connect_attrs.iterator(); + while (it.next()) |entry| { + total_length += encodeLengthInt(entry.key_ptr.len).len; + total_length += entry.key_ptr.len; + total_length += encodeLengthInt(entry.value_ptr.len).len; + total_length += entry.value_ptr.len; + } + + try writer.writeLengthEncodedInt(total_length); + + it = this.connect_attrs.iterator(); + while (it.next()) |entry| { + try writer.writeLengthEncodedString(entry.key_ptr.*); + try writer.writeLengthEncodedString(entry.value_ptr.*); + } + } + + if (this.capability_flags.CLIENT_ZSTD_COMPRESSION_ALGORITHM) { + // try writer.writeInt(u8, this.zstd_compression_algorithm); + bun.assertf(false, "zstd compression algorithm is not supported", .{}); + } + + try packet.end(); +} + +pub const write = writeWrap(HandshakeResponse41, writeInternal).write; + +const debug = bun.Output.scoped(.MySQLConnection, .hidden); + +const Capabilities = @import("../Capabilities.zig"); +const bun = @import("bun"); +const CharacterSet = @import("./CharacterSet.zig").CharacterSet; +const Data = @import("../../shared/Data.zig").Data; +const encodeLengthInt = @import("./EncodeInt.zig").encodeLengthInt; + +const NewWriter = 
@import("./NewWriter.zig").NewWriter; +const writeWrap = @import("./NewWriter.zig").writeWrap; diff --git a/src/sql/mysql/protocol/HandshakeV10.zig b/src/sql/mysql/protocol/HandshakeV10.zig new file mode 100644 index 0000000000..dcb8df3ea6 --- /dev/null +++ b/src/sql/mysql/protocol/HandshakeV10.zig @@ -0,0 +1,82 @@ +// Initial handshake packet from server +const HandshakeV10 = @This(); +protocol_version: u8 = 10, +server_version: Data = .{ .empty = {} }, +connection_id: u32 = 0, +auth_plugin_data_part_1: [8]u8 = undefined, +auth_plugin_data_part_2: []const u8 = &[_]u8{}, +capability_flags: Capabilities = .{}, +character_set: CharacterSet = CharacterSet.default, +status_flags: StatusFlags = .{}, +auth_plugin_name: Data = .{ .empty = {} }, + +pub fn deinit(this: *HandshakeV10) void { + this.server_version.deinit(); + this.auth_plugin_name.deinit(); +} + +pub fn decodeInternal(this: *HandshakeV10, comptime Context: type, reader: NewReader(Context)) !void { + // Protocol version + this.protocol_version = try reader.int(u8); + if (this.protocol_version != 10) { + return error.UnsupportedProtocolVersion; + } + + // Server version (null-terminated string) + this.server_version = try reader.readZ(); + + // Connection ID (4 bytes) + this.connection_id = try reader.int(u32); + + // Auth plugin data part 1 (8 bytes) + var auth_data = try reader.read(8); + defer auth_data.deinit(); + @memcpy(&this.auth_plugin_data_part_1, auth_data.slice()); + + // Skip filler byte + _ = try reader.int(u8); + + // Capability flags (lower 2 bytes) + const capabilities_lower = try reader.int(u16); + + // Character set + this.character_set = @enumFromInt(try reader.int(u8)); + + // Status flags + this.status_flags = StatusFlags.fromInt(try reader.int(u16)); + + // Capability flags (upper 2 bytes) + const capabilities_upper = try reader.int(u16); + this.capability_flags = Capabilities.fromInt(@as(u32, capabilities_upper) << 16 | capabilities_lower); + + // Length of auth plugin data + var 
auth_plugin_data_len = try reader.int(u8); + if (auth_plugin_data_len < 21) { + auth_plugin_data_len = 21; + } + + // Skip reserved bytes + reader.skip(10); + + // Auth plugin data part 2 + const remaining_auth_len = @max(13, auth_plugin_data_len - 8); + var auth_data_2 = try reader.read(remaining_auth_len); + defer auth_data_2.deinit(); + this.auth_plugin_data_part_2 = try bun.default_allocator.dupe(u8, auth_data_2.slice()); + + // Auth plugin name + if (this.capability_flags.CLIENT_PLUGIN_AUTH) { + this.auth_plugin_name = try reader.readZ(); + } +} + +pub const decode = decoderWrap(HandshakeV10, decodeInternal).decode; + +const Capabilities = @import("../Capabilities.zig"); +const bun = @import("bun"); +const CharacterSet = @import("./CharacterSet.zig").CharacterSet; +const Data = @import("../../shared/Data.zig").Data; +const StatusFlags = @import("../StatusFlags.zig").StatusFlags; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git a/src/sql/mysql/protocol/LocalInfileRequest.zig b/src/sql/mysql/protocol/LocalInfileRequest.zig new file mode 100644 index 0000000000..eb00320171 --- /dev/null +++ b/src/sql/mysql/protocol/LocalInfileRequest.zig @@ -0,0 +1,22 @@ +const LocalInfileRequest = @This(); +filename: Data = .{ .empty = {} }, +packet_size: u24, +pub fn deinit(this: *LocalInfileRequest) void { + this.filename.deinit(); +} + +pub fn decodeInternal(this: *LocalInfileRequest, comptime Context: type, reader: NewReader(Context)) !void { + const header = try reader.int(u8); + if (header != 0xFB) { + return error.InvalidLocalInfileRequest; + } + + this.filename = try reader.read(this.packet_size - 1); +} + +pub const decode = decoderWrap(LocalInfileRequest, decodeInternal).decode; + +const Data = @import("../../shared/Data.zig").Data; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git 
a/src/sql/mysql/protocol/NewReader.zig b/src/sql/mysql/protocol/NewReader.zig new file mode 100644 index 0000000000..ba7e4a3405 --- /dev/null +++ b/src/sql/mysql/protocol/NewReader.zig @@ -0,0 +1,136 @@ +pub fn NewReaderWrap( + comptime Context: type, + comptime markMessageStartFn_: (fn (ctx: Context) void), + comptime peekFn_: (fn (ctx: Context) []const u8), + comptime skipFn_: (fn (ctx: Context, count: isize) void), + comptime ensureCapacityFn_: (fn (ctx: Context, count: usize) bool), + comptime readFunction_: (fn (ctx: Context, count: usize) AnyMySQLError.Error!Data), + comptime readZ_: (fn (ctx: Context) AnyMySQLError.Error!Data), + comptime setOffsetFromStart_: (fn (ctx: Context, offset: usize) void), +) type { + return struct { + wrapped: Context, + const readFn = readFunction_; + const readZFn = readZ_; + const ensureCapacityFn = ensureCapacityFn_; + const skipFn = skipFn_; + const peekFn = peekFn_; + const markMessageStartFn = markMessageStartFn_; + const setOffsetFromStartFn = setOffsetFromStart_; + pub const Ctx = Context; + + pub const is_wrapped = true; + + pub fn markMessageStart(this: @This()) void { + markMessageStartFn(this.wrapped); + } + + pub fn setOffsetFromStart(this: @This(), offset: usize) void { + return setOffsetFromStartFn(this.wrapped, offset); + } + + pub fn read(this: @This(), count: usize) AnyMySQLError.Error!Data { + return readFn(this.wrapped, count); + } + + pub fn skip(this: @This(), count: anytype) void { + skipFn(this.wrapped, @as(isize, @intCast(count))); + } + + pub fn peek(this: @This()) []const u8 { + return peekFn(this.wrapped); + } + + pub fn readZ(this: @This()) AnyMySQLError.Error!Data { + return readZFn(this.wrapped); + } + + pub fn byte(this: @This()) AnyMySQLError.Error!u8 { + const data = try this.read(1); + return data.slice()[0]; + } + + pub fn ensureCapacity(this: @This(), count: usize) AnyMySQLError.Error!void { + if (!ensureCapacityFn(this.wrapped, count)) { + return AnyMySQLError.Error.ShortRead; + } + } + + pub 
fn int(this: @This(), comptime Int: type) AnyMySQLError.Error!Int { + var data = try this.read(@sizeOf(Int)); + defer data.deinit(); + if (comptime Int == u8) { + return @as(Int, data.slice()[0]); + } + const size = @divExact(@typeInfo(Int).int.bits, 8); + return @as(Int, @bitCast(data.slice()[0..size].*)); + } + + pub fn encodeLenString(this: @This()) AnyMySQLError.Error!Data { + if (decodeLengthInt(this.peek())) |result| { + this.skip(result.bytes_read); + return try this.read(@intCast(result.value)); + } + return AnyMySQLError.Error.InvalidEncodedLength; + } + + pub fn rawEncodeLenData(this: @This()) AnyMySQLError.Error!Data { + if (decodeLengthInt(this.peek())) |result| { + return try this.read(@intCast(result.value + result.bytes_read)); + } + return AnyMySQLError.Error.InvalidEncodedLength; + } + + pub fn encodedLenInt(this: @This()) AnyMySQLError.Error!u64 { + if (decodeLengthInt(this.peek())) |result| { + this.skip(result.bytes_read); + return result.value; + } + return AnyMySQLError.Error.InvalidEncodedInteger; + } + + pub fn encodedLenIntWithSize(this: @This(), size: *usize) !u64 { + if (decodeLengthInt(this.peek())) |result| { + this.skip(result.bytes_read); + size.* += result.bytes_read; + return result.value; + } + return error.InvalidEncodedInteger; + } + }; +} + +pub fn NewReader(comptime Context: type) type { + if (@hasDecl(Context, "is_wrapped")) { + return Context; + } + + return NewReaderWrap(Context, Context.markMessageStart, Context.peek, Context.skip, Context.ensureCapacity, Context.read, Context.readZ, Context.setOffsetFromStart); +} + +pub fn decoderWrap(comptime Container: type, comptime decodeFn: anytype) type { + return struct { + pub fn decode(this: *Container, context: anytype) AnyMySQLError.Error!void { + const Context = @TypeOf(context); + if (@hasDecl(Context, "is_wrapped")) { + try decodeFn(this, Context, context); + } else { + try decodeFn(this, Context, .{ .wrapped = context }); + } + } + + pub fn decodeAllocator(this: *Container, 
allocator: std.mem.Allocator, context: anytype) AnyMySQLError.Error!void { + const Context = @TypeOf(context); + if (@hasDecl(Context, "is_wrapped")) { + try decodeFn(this, allocator, Context, context); + } else { + try decodeFn(this, allocator, Context, .{ .wrapped = context }); + } + } + }; +} + +const AnyMySQLError = @import("./AnyMySQLError.zig"); +const std = @import("std"); +const Data = @import("../../shared/Data.zig").Data; +const decodeLengthInt = @import("./EncodeInt.zig").decodeLengthInt; diff --git a/src/sql/mysql/protocol/NewWriter.zig b/src/sql/mysql/protocol/NewWriter.zig new file mode 100644 index 0000000000..8dc35dd525 --- /dev/null +++ b/src/sql/mysql/protocol/NewWriter.zig @@ -0,0 +1,132 @@ +pub fn NewWriterWrap( + comptime Context: type, + comptime offsetFn_: (fn (ctx: Context) usize), + comptime writeFunction_: (fn (ctx: Context, bytes: []const u8) AnyMySQLError.Error!void), + comptime pwriteFunction_: (fn (ctx: Context, bytes: []const u8, offset: usize) AnyMySQLError.Error!void), +) type { + return struct { + wrapped: Context, + + const writeFn = writeFunction_; + const pwriteFn = pwriteFunction_; + const offsetFn = offsetFn_; + pub const Ctx = Context; + + pub const is_wrapped = true; + + pub const WrappedWriter = @This(); + + pub inline fn writeLengthEncodedInt(this: @This(), data: u64) AnyMySQLError.Error!void { + try writeFn(this.wrapped, encodeLengthInt(data).slice()); + } + + pub inline fn writeLengthEncodedString(this: @This(), data: []const u8) AnyMySQLError.Error!void { + try this.writeLengthEncodedInt(data.len); + try writeFn(this.wrapped, data); + } + + pub fn write(this: @This(), data: []const u8) AnyMySQLError.Error!void { + try writeFn(this.wrapped, data); + } + + const Packet = struct { + header: PacketHeader, + offset: usize, + ctx: WrappedWriter, + + pub fn end(this: *@This()) AnyMySQLError.Error!void { + const new_offset = offsetFn(this.ctx.wrapped); + // fix position for packet header + const length = new_offset - 
this.offset - PacketHeader.size; + this.header.length = @intCast(length); + debug("writing packet header: {d}", .{this.header.length}); + try pwrite(this.ctx, &this.header.encode(), this.offset); + } + }; + + pub fn start(this: @This(), sequence_id: u8) AnyMySQLError.Error!Packet { + const o = offsetFn(this.wrapped); + debug("starting packet: {d}", .{o}); + try this.write(&[_]u8{0} ** PacketHeader.size); + return .{ + .header = .{ .sequence_id = sequence_id, .length = 0 }, + .offset = o, + .ctx = this, + }; + } + + pub fn offset(this: @This()) usize { + return offsetFn(this.wrapped); + } + + pub fn pwrite(this: @This(), data: []const u8, i: usize) AnyMySQLError.Error!void { + try pwriteFn(this.wrapped, data, i); + } + + pub fn int4(this: @This(), value: MySQLInt32) AnyMySQLError.Error!void { + try this.write(&std.mem.toBytes(value)); + } + + pub fn int8(this: @This(), value: MySQLInt64) AnyMySQLError.Error!void { + try this.write(&std.mem.toBytes(value)); + } + + pub fn int1(this: @This(), value: u8) AnyMySQLError.Error!void { + try this.write(&[_]u8{value}); + } + + pub fn writeZ(this: @This(), value: []const u8) AnyMySQLError.Error!void { + try this.write(value); + if (value.len == 0 or value[value.len - 1] != 0) + try this.write(&[_]u8{0}); + } + + pub fn String(this: @This(), value: bun.String) AnyMySQLError.Error!void { + if (value.isEmpty()) { + try this.write(&[_]u8{0}); + return; + } + + var sliced = value.toUTF8(bun.default_allocator); + defer sliced.deinit(); + const slice = sliced.slice(); + + try this.write(slice); + if (slice.len == 0 or slice[slice.len - 1] != 0) + try this.write(&[_]u8{0}); + } + }; +} + +pub fn NewWriter(comptime Context: type) type { + if (@hasDecl(Context, "is_wrapped")) { + return Context; + } + + return NewWriterWrap(Context, Context.offset, Context.write, Context.pwrite); +} + +pub fn writeWrap(comptime Container: type, comptime writeFn: anytype) type { + return struct { + pub fn write(this: *Container, context: anytype) 
AnyMySQLError.Error!void { + const Context = @TypeOf(context); + if (@hasDecl(Context, "is_wrapped")) { + try writeFn(this, Context, context); + } else { + try writeFn(this, Context, .{ .wrapped = context }); + } + } + }; +} + +const debug = bun.Output.scoped(.NewWriter, .hidden); + +const AnyMySQLError = @import("./AnyMySQLError.zig"); +const PacketHeader = @import("./PacketHeader.zig"); +const bun = @import("bun"); +const std = @import("std"); +const encodeLengthInt = @import("./EncodeInt.zig").encodeLengthInt; + +const types = @import("../MySQLTypes.zig"); +const MySQLInt32 = types.MySQLInt32; +const MySQLInt64 = types.MySQLInt64; diff --git a/src/sql/mysql/protocol/OKPacket.zig b/src/sql/mysql/protocol/OKPacket.zig new file mode 100644 index 0000000000..d9483d6b8b --- /dev/null +++ b/src/sql/mysql/protocol/OKPacket.zig @@ -0,0 +1,49 @@ +// OK Packet +const OKPacket = @This(); +header: u8 = 0x00, +affected_rows: u64 = 0, +last_insert_id: u64 = 0, +status_flags: StatusFlags = .{}, +warnings: u16 = 0, +info: Data = .{ .empty = {} }, +session_state_changes: Data = .{ .empty = {} }, +packet_size: u24, + +pub fn deinit(this: *OKPacket) void { + this.info.deinit(); + this.session_state_changes.deinit(); +} + +pub fn decodeInternal(this: *OKPacket, comptime Context: type, reader: NewReader(Context)) !void { + var read_size: usize = 5; // header + status flags + warnings + this.header = try reader.int(u8); + if (this.header != 0x00 and this.header != 0xfe) { + return error.InvalidOKPacket; + } + + // Affected rows (length encoded integer) + this.affected_rows = try reader.encodedLenIntWithSize(&read_size); + + // Last insert ID (length encoded integer) + this.last_insert_id = try reader.encodedLenIntWithSize(&read_size); + + // Status flags + this.status_flags = StatusFlags.fromInt(try reader.int(u16)); + // Warnings + this.warnings = try reader.int(u16); + + // Info (EOF-terminated string) + if (reader.peek().len > 0) { + // everything else is info + this.info = try 
reader.read(@truncate(this.packet_size - read_size)); + } +} + +pub const decode = decoderWrap(OKPacket, decodeInternal).decode; + +const Data = @import("../../shared/Data.zig").Data; + +const StatusFlags = @import("../StatusFlags.zig").StatusFlags; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git a/src/sql/mysql/protocol/PacketHeader.zig b/src/sql/mysql/protocol/PacketHeader.zig new file mode 100644 index 0000000000..f7a6d9be22 --- /dev/null +++ b/src/sql/mysql/protocol/PacketHeader.zig @@ -0,0 +1,25 @@ +const PacketHeader = @This(); +length: u24, +sequence_id: u8, + +pub const size = 4; + +pub fn decode(bytes: []const u8) ?PacketHeader { + if (bytes.len < 4) return null; + + return PacketHeader{ + .length = @as(u24, bytes[0]) | + (@as(u24, bytes[1]) << 8) | + (@as(u24, bytes[2]) << 16), + .sequence_id = bytes[3], + }; +} + +pub fn encode(self: PacketHeader) [4]u8 { + return [4]u8{ + @intCast(self.length & 0xff), + @intCast((self.length >> 8) & 0xff), + @intCast((self.length >> 16) & 0xff), + self.sequence_id, + }; +} diff --git a/src/sql/mysql/protocol/PacketType.zig b/src/sql/mysql/protocol/PacketType.zig new file mode 100644 index 0000000000..e51f9746a8 --- /dev/null +++ b/src/sql/mysql/protocol/PacketType.zig @@ -0,0 +1,14 @@ +pub const PacketType = enum(u8) { + // Server packets + OK = 0x00, + EOF = 0xfe, + ERROR = 0xff, + LOCAL_INFILE = 0xfb, + + // Client/server packets + HANDSHAKE = 0x0a, + MORE_DATA = 0x01, + + _, + pub const AUTH_SWITCH = 0xfe; +}; diff --git a/src/sql/mysql/protocol/PreparedStatement.zig b/src/sql/mysql/protocol/PreparedStatement.zig new file mode 100644 index 0000000000..0ca0810f61 --- /dev/null +++ b/src/sql/mysql/protocol/PreparedStatement.zig @@ -0,0 +1,115 @@ +const PreparedStatement = @This(); + +pub const PrepareOK = struct { + status: u8 = 0, + statement_id: u32, + num_columns: u16, + num_params: u16, + warning_count: u16, + + pub fn 
decodeInternal(this: *PrepareOK, comptime Context: type, reader: NewReader(Context)) !void { + this.status = try reader.int(u8); + if (this.status != 0) { + return error.InvalidPrepareOKPacket; + } + + this.statement_id = try reader.int(u32); + this.num_columns = try reader.int(u16); + this.num_params = try reader.int(u16); + _ = try reader.int(u8); // reserved_1 + this.warning_count = try reader.int(u16); + } + + pub const decode = decoderWrap(PrepareOK, decodeInternal).decode; +}; + +pub const Execute = struct { + /// ID of the prepared statement to execute, returned from COM_STMT_PREPARE + statement_id: u32, + /// Execution flags. Currently only CURSOR_TYPE_READ_ONLY (0x01) is supported + flags: u8 = 0, + /// Number of times to execute the statement (usually 1) + iteration_count: u32 = 1, + /// Parameter values to bind to the prepared statement + params: []Value = &[_]Value{}, + /// Types of each parameter in the prepared statement + param_types: []const Param, + /// Whether to send parameter types. Set to true for first execution, false for subsequent executions + new_params_bind_flag: bool, + + pub fn deinit(this: *Execute) void { + for (this.params) |*param| { + param.deinit(bun.default_allocator); + } + } + + fn writeNullBitmap(this: *const Execute, comptime Context: type, writer: NewWriter(Context)) AnyMySQLError.Error!void { + const MYSQL_MAX_PARAMS = (std.math.maxInt(u16) / 8) + 1; + + var null_bitmap_buf: [MYSQL_MAX_PARAMS]u8 = undefined; + const bitmap_bytes = (this.params.len + 7) / 8; + const null_bitmap = null_bitmap_buf[0..bitmap_bytes]; + @memset(null_bitmap, 0); + + for (this.params, 0..) 
|param, i| { + if (param == .null) { + null_bitmap[i >> 3] |= @as(u8, 1) << @as(u3, @truncate(i & 7)); + } + } + + try writer.write(null_bitmap); + } + + pub fn writeInternal(this: *const Execute, comptime Context: type, writer: NewWriter(Context)) AnyMySQLError.Error!void { + try writer.int1(@intFromEnum(CommandType.COM_STMT_EXECUTE)); + try writer.int4(this.statement_id); + try writer.int1(this.flags); + try writer.int4(this.iteration_count); + + if (this.params.len > 0) { + try this.writeNullBitmap(Context, writer); + + // Write new params bind flag + try writer.int1(@intFromBool(this.new_params_bind_flag)); + + if (this.new_params_bind_flag) { + // Write parameter types + for (this.param_types) |param_type| { + debug("New params bind flag {s} unsigned? {}", .{ @tagName(param_type.type), param_type.flags.UNSIGNED }); + try writer.int1(@intFromEnum(param_type.type)); + try writer.int1(if (param_type.flags.UNSIGNED) 0x80 else 0); + } + } + + // Write parameter values + for (this.params, this.param_types) |*param, param_type| { + if (param.* == .null or param_type.type == .MYSQL_TYPE_NULL) continue; + + var value = try param.toData(param_type.type); + defer value.deinit(); + if (param_type.type.isBinaryFormatSupported()) { + try writer.write(value.slice()); + } else { + try writer.writeLengthEncodedString(value.slice()); + } + } + } + } + + pub const write = writeWrap(Execute, writeInternal).write; +}; + +const debug = bun.Output.scoped(.PreparedStatement, .hidden); + +const AnyMySQLError = @import("./AnyMySQLError.zig"); +const bun = @import("bun"); +const std = @import("std"); +const CommandType = @import("./CommandType.zig").CommandType; +const Param = @import("../MySQLStatement.zig").Param; +const Value = @import("../MySQLTypes.zig").Value; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; + +const NewWriter = @import("./NewWriter.zig").NewWriter; +const writeWrap = 
@import("./NewWriter.zig").writeWrap; diff --git a/src/sql/mysql/protocol/Query.zig b/src/sql/mysql/protocol/Query.zig new file mode 100644 index 0000000000..e6a5cc23eb --- /dev/null +++ b/src/sql/mysql/protocol/Query.zig @@ -0,0 +1,70 @@ +pub const Execute = struct { + query: []const u8, + /// Parameter values to bind to the prepared statement + params: []Data = &[_]Data{}, + /// Types of each parameter in the prepared statement + param_types: []const Param, + + pub fn deinit(this: *Execute) void { + for (this.params) |*param| { + param.deinit(); + } + } + + pub fn writeInternal(this: *const Execute, comptime Context: type, writer: NewWriter(Context)) !void { + var packet = try writer.start(0); + try writer.int1(@intFromEnum(CommandType.COM_QUERY)); + try writer.write(this.query); + + if (this.params.len > 0) { + try writer.writeNullBitmap(this.params); + + // Always 1. Malformed packet error if not 1 + try writer.int1(1); + // if 22 chars = u64 + 2 for :p and this should be more than enough + var param_name_buf: [22]u8 = undefined; + // Write parameter types + for (this.param_types, 1..) |param_type, i| { + debug("New params bind flag {s} unsigned? 
{}", .{ @tagName(param_type.type), param_type.flags.UNSIGNED }); + try writer.int1(@intFromEnum(param_type.type)); + try writer.int1(if (param_type.flags.UNSIGNED) 0x80 else 0); + const param_name = std.fmt.bufPrint(¶m_name_buf, ":p{d}", .{i}) catch return error.TooManyParameters; + try writer.writeLengthEncodedString(param_name); + } + + // Write parameter values + for (this.params, this.param_types) |*param, param_type| { + if (param.* == .empty or param_type.type == .MYSQL_TYPE_NULL) continue; + + const value = param.slice(); + debug("Write param type {s} len {d} hex {s}", .{ @tagName(param_type.type), value.len, std.fmt.fmtSliceHexLower(value) }); + if (param_type.type.isBinaryFormatSupported()) { + try writer.write(value); + } else { + try writer.writeLengthEncodedString(value); + } + } + } + try packet.end(); + } + + pub const write = writeWrap(Execute, writeInternal).write; +}; + +pub fn execute(query: []const u8, writer: anytype) !void { + var packet = try writer.start(0); + try writer.int1(@intFromEnum(CommandType.COM_QUERY)); + try writer.write(query); + try packet.end(); +} + +const debug = bun.Output.scoped(.MySQLQuery, .visible); + +const bun = @import("bun"); +const std = @import("std"); +const CommandType = @import("./CommandType.zig").CommandType; +const Data = @import("../../shared/Data.zig").Data; +const Param = @import("../MySQLStatement.zig").Param; + +const NewWriter = @import("./NewWriter.zig").NewWriter; +const writeWrap = @import("./NewWriter.zig").writeWrap; diff --git a/src/sql/mysql/protocol/ResultSet.zig b/src/sql/mysql/protocol/ResultSet.zig new file mode 100644 index 0000000000..8e02c95141 --- /dev/null +++ b/src/sql/mysql/protocol/ResultSet.zig @@ -0,0 +1,247 @@ +pub const Header = @import("./ResultSetHeader.zig"); + +pub const Row = struct { + values: []SQLDataCell = &[_]SQLDataCell{}, + columns: []const ColumnDefinition41, + binary: bool = false, + raw: bool = false, + bigint: bool = false, + globalObject: *jsc.JSGlobalObject, + + 
pub fn toJS(this: *Row, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: SQLDataCell.Flags, result_mode: SQLQueryResultMode, cached_structure: ?CachedStructure) JSValue { + var names: ?[*]jsc.JSObject.ExternColumnIdentifier = null; + var names_count: u32 = 0; + if (cached_structure) |c| { + if (c.fields) |f| { + names = f.ptr; + names_count = @truncate(f.len); + } + } + + return SQLDataCell.JSC__constructObjectFromDataCell( + globalObject, + array, + structure, + this.values.ptr, + @truncate(this.values.len), + flags, + @intFromEnum(result_mode), + names, + names_count, + ); + } + + pub fn deinit(this: *Row, allocator: std.mem.Allocator) void { + for (this.values) |*value| { + value.deinit(); + } + allocator.free(this.values); + + // this.columns is intentionally left out. + } + + pub fn decodeInternal(this: *Row, allocator: std.mem.Allocator, comptime Context: type, reader: NewReader(Context)) AnyMySQLError.Error!void { + if (this.binary) { + try this.decodeBinary(allocator, Context, reader); + } else { + try this.decodeText(allocator, Context, reader); + } + } + + fn parseValueAndSetCell(this: *Row, cell: *SQLDataCell, column: *const ColumnDefinition41, value: *const Data) void { + debug("parseValueAndSetCell: {s} {s}", .{ @tagName(column.column_type), value.slice() }); + return switch (column.column_type) { + .MYSQL_TYPE_FLOAT, .MYSQL_TYPE_DOUBLE => { + const val: f64 = bun.parseDouble(value.slice()) catch std.math.nan(f64); + cell.* = SQLDataCell{ .tag = .float8, .value = .{ .float8 = val } }; + }, + .MYSQL_TYPE_TINY => { + const str = value.slice(); + const val: u8 = if (str.len > 0 and (str[0] == '1' or str[0] == 't' or str[0] == 'T')) 1 else 0; + cell.* = SQLDataCell{ .tag = .bool, .value = .{ .bool = val } }; + }, + .MYSQL_TYPE_SHORT => { + if (column.flags.UNSIGNED) { + const val: u16 = std.fmt.parseInt(u16, value.slice(), 10) catch 0; + cell.* = SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = val } }; + } else { + const val: 
i16 = std.fmt.parseInt(i16, value.slice(), 10) catch 0; + cell.* = SQLDataCell{ .tag = .int4, .value = .{ .int4 = val } }; + } + }, + .MYSQL_TYPE_LONG => { + if (column.flags.UNSIGNED) { + const val: u32 = std.fmt.parseInt(u32, value.slice(), 10) catch 0; + cell.* = SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = val } }; + } else { + const val: i32 = std.fmt.parseInt(i32, value.slice(), 10) catch std.math.minInt(i32); + cell.* = SQLDataCell{ .tag = .int4, .value = .{ .int4 = val } }; + } + }, + .MYSQL_TYPE_LONGLONG => { + if (column.flags.UNSIGNED) { + const val: u64 = std.fmt.parseInt(u64, value.slice(), 10) catch 0; + if (val <= std.math.maxInt(u32)) { + cell.* = SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = @intCast(val) } }; + return; + } + if (this.bigint) { + cell.* = SQLDataCell{ .tag = .uint8, .value = .{ .uint8 = val } }; + return; + } + } else { + const val: i64 = std.fmt.parseInt(i64, value.slice(), 10) catch 0; + if (val >= std.math.minInt(i32) and val <= std.math.maxInt(i32)) { + cell.* = SQLDataCell{ .tag = .int4, .value = .{ .int4 = @intCast(val) } }; + return; + } + if (this.bigint) { + cell.* = SQLDataCell{ .tag = .int8, .value = .{ .int8 = val } }; + return; + } + } + + const slice = value.slice(); + cell.* = SQLDataCell{ .tag = .string, .value = .{ .string = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 }; + }, + .MYSQL_TYPE_JSON => { + const slice = value.slice(); + cell.* = SQLDataCell{ .tag = .json, .value = .{ .json = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 }; + }, + + .MYSQL_TYPE_DATE, .MYSQL_TYPE_TIME, .MYSQL_TYPE_DATETIME, .MYSQL_TYPE_TIMESTAMP => { + var str = bun.String.init(value.slice()); + defer str.deref(); + const date = brk: { + break :brk str.parseDate(this.globalObject) catch |err| { + _ = this.globalObject.takeException(err); + break :brk std.math.nan(f64); + }; + }; + cell.* = SQLDataCell{ .tag = .date, .value = .{ 
.date = date } }; + }, + else => { + const slice = value.slice(); + cell.* = SQLDataCell{ .tag = .string, .value = .{ .string = if (slice.len > 0) bun.String.cloneUTF8(slice).value.WTFStringImpl else null }, .free_value = 1 }; + }, + }; + } + + fn decodeText(this: *Row, allocator: std.mem.Allocator, comptime Context: type, reader: NewReader(Context)) AnyMySQLError.Error!void { + const cells = try allocator.alloc(SQLDataCell, this.columns.len); + @memset(cells, SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }); + errdefer { + for (cells) |*value| { + value.deinit(); + } + allocator.free(cells); + } + + for (cells, 0..) |*value, index| { + if (decodeLengthInt(reader.peek())) |result| { + const column = this.columns[index]; + if (result.value == 0xfb) { + // NULL value + reader.skip(result.bytes_read); + // this dont matter if is raw because we will sent as null too like in postgres + value.* = SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }; + } else { + if (this.raw) { + var data = try reader.rawEncodeLenData(); + defer data.deinit(); + value.* = SQLDataCell.raw(&data); + } else { + reader.skip(result.bytes_read); + var string_data = try reader.read(@intCast(result.value)); + defer string_data.deinit(); + this.parseValueAndSetCell(value, &column, &string_data); + } + } + value.index = switch (column.name_or_index) { + // The indexed columns can be out of order. 
+ .index => |i| i, + + else => @intCast(index), + }; + value.isIndexedColumn = switch (column.name_or_index) { + .duplicate => 2, + .index => 1, + .name => 0, + }; + } else { + return error.InvalidResultRow; + } + } + + this.values = cells; + } + + fn decodeBinary(this: *Row, allocator: std.mem.Allocator, comptime Context: type, reader: NewReader(Context)) AnyMySQLError.Error!void { + // Header + _ = try reader.int(u8); + + // Null bitmap + const bitmap_bytes = (this.columns.len + 7 + 2) / 8; + var null_bitmap = try reader.read(bitmap_bytes); + defer null_bitmap.deinit(); + + const cells = try allocator.alloc(SQLDataCell, this.columns.len); + @memset(cells, SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }); + errdefer { + for (cells) |*value| { + value.deinit(); + } + allocator.free(cells); + } + // Skip first 2 bits of null bitmap (reserved) + const bitmap_offset: usize = 2; + + for (cells, 0..) |*value, i| { + const byte_pos = (bitmap_offset + i) >> 3; + const bit_pos = @as(u3, @truncate((bitmap_offset + i) & 7)); + const is_null = (null_bitmap.slice()[byte_pos] & (@as(u8, 1) << bit_pos)) != 0; + + if (is_null) { + value.* = SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }; + continue; + } + + const column = this.columns[i]; + value.* = try decodeBinaryValue(this.globalObject, column.column_type, this.raw, this.bigint, column.flags.UNSIGNED, Context, reader); + value.index = switch (column.name_or_index) { + // The indexed columns can be out of order. 
+ .index => |idx| idx, + + else => @intCast(i), + }; + value.isIndexedColumn = switch (column.name_or_index) { + .duplicate => 2, + .index => 1, + .name => 0, + }; + } + + this.values = cells; + } + + pub const decode = decoderWrap(Row, decodeInternal).decodeAllocator; +}; + +const debug = bun.Output.scoped(.MySQLResultSet, .visible); + +const AnyMySQLError = @import("./AnyMySQLError.zig"); +const CachedStructure = @import("../../shared/CachedStructure.zig"); +const ColumnDefinition41 = @import("./ColumnDefinition41.zig"); +const bun = @import("bun"); +const std = @import("std"); +const Data = @import("../../shared/Data.zig").Data; +const SQLDataCell = @import("../../shared/SQLDataCell.zig").SQLDataCell; +const SQLQueryResultMode = @import("../../shared/SQLQueryResultMode.zig").SQLQueryResultMode; +const decodeBinaryValue = @import("./DecodeBinaryValue.zig").decodeBinaryValue; +const decodeLengthInt = @import("./EncodeInt.zig").decodeLengthInt; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; diff --git a/src/sql/mysql/protocol/ResultSetHeader.zig b/src/sql/mysql/protocol/ResultSetHeader.zig new file mode 100644 index 0000000000..6a8c99b688 --- /dev/null +++ b/src/sql/mysql/protocol/ResultSetHeader.zig @@ -0,0 +1,12 @@ +const ResultSetHeader = @This(); +field_count: u64 = 0, + +pub fn decodeInternal(this: *ResultSetHeader, comptime Context: type, reader: NewReader(Context)) !void { + // Field count (length encoded integer) + this.field_count = try reader.encodedLenInt(); +} + +pub const decode = decoderWrap(ResultSetHeader, decodeInternal).decode; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git a/src/sql/mysql/protocol/Signature.zig b/src/sql/mysql/protocol/Signature.zig new file mode 100644 index 0000000000..9bb6c0915d --- /dev/null +++ 
b/src/sql/mysql/protocol/Signature.zig @@ -0,0 +1,86 @@ +const Signature = @This(); +fields: []Param = &.{}, +name: []const u8 = "", +query: []const u8 = "", + +pub fn empty() Signature { + return Signature{ + .fields = &.{}, + .name = "", + .query = "", + }; +} + +pub fn deinit(this: *Signature) void { + if (this.fields.len > 0) { + bun.default_allocator.free(this.fields); + } + if (this.name.len > 0) { + bun.default_allocator.free(this.name); + } + if (this.query.len > 0) { + bun.default_allocator.free(this.query); + } +} + +pub fn hash(this: *const Signature) u64 { + var hasher = std.hash.Wyhash.init(0); + hasher.update(this.name); + hasher.update(std.mem.sliceAsBytes(this.fields)); + return hasher.final(); +} + +pub fn generate(globalObject: *jsc.JSGlobalObject, query: []const u8, array_value: JSValue, columns: JSValue) !Signature { + var fields = std.ArrayList(Param).init(bun.default_allocator); + var name = try std.ArrayList(u8).initCapacity(bun.default_allocator, query.len); + + name.appendSliceAssumeCapacity(query); + + errdefer { + fields.deinit(); + name.deinit(); + } + + var iter = try QueryBindingIterator.init(array_value, columns, globalObject); + + while (try iter.next()) |value| { + if (value.isEmptyOrUndefinedOrNull()) { + // Allow MySQL to decide the type + try fields.append(.{ .type = .MYSQL_TYPE_NULL, .flags = .{} }); + try name.appendSlice(".null"); + continue; + } + var unsigned = false; + const tag = try types.FieldType.fromJS(globalObject, value, &unsigned); + if (unsigned) { + // 128 is more than enought right now + var tag_name_buf = [_]u8{0} ** 128; + try name.appendSlice(std.fmt.bufPrint(tag_name_buf[0..], "U{s}", .{@tagName(tag)}) catch @tagName(tag)); + } else { + try name.appendSlice(@tagName(tag)); + } + // TODO: add flags if necessary right now the only relevant would be unsigned but is JS and is never unsigned + try fields.append(.{ .type = tag, .flags = .{ .UNSIGNED = unsigned } }); + } + + if (iter.anyFailed()) { + return 
error.InvalidQueryBinding; + } + + return Signature{ + .name = name.items, + .fields = fields.items, + .query = try bun.default_allocator.dupe(u8, query), + }; +} + +const bun = @import("bun"); +const std = @import("std"); +const Param = @import("../MySQLStatement.zig").Param; +const QueryBindingIterator = @import("../../shared/QueryBindingIterator.zig").QueryBindingIterator; + +const types = @import("../MySQLTypes.zig"); +const FieldType = types.FieldType; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; diff --git a/src/sql/mysql/protocol/StackReader.zig b/src/sql/mysql/protocol/StackReader.zig new file mode 100644 index 0000000000..ed242270bc --- /dev/null +++ b/src/sql/mysql/protocol/StackReader.zig @@ -0,0 +1,78 @@ +const StackReader = @This(); +buffer: []const u8 = "", +offset: *usize, +message_start: *usize, + +pub fn markMessageStart(this: @This()) void { + this.message_start.* = this.offset.*; +} +pub fn setOffsetFromStart(this: @This(), offset: usize) void { + this.offset.* = this.message_start.* + offset; +} + +pub fn ensureCapacity(this: @This(), length: usize) bool { + return this.buffer.len >= (this.offset.* + length); +} + +pub fn init(buffer: []const u8, offset: *usize, message_start: *usize) NewReader(StackReader) { + return .{ + .wrapped = .{ + .buffer = buffer, + .offset = offset, + .message_start = message_start, + }, + }; +} + +pub fn peek(this: StackReader) []const u8 { + return this.buffer[this.offset.*..]; +} + +pub fn skip(this: StackReader, count: isize) void { + if (count < 0) { + const abs_count = @abs(count); + if (abs_count > this.offset.*) { + this.offset.* = 0; + return; + } + this.offset.* -= @intCast(abs_count); + return; + } + + const ucount: usize = @intCast(count); + if (this.offset.* + ucount > this.buffer.len) { + this.offset.* = this.buffer.len; + return; + } + + this.offset.* += ucount; +} + +pub fn read(this: StackReader, count: usize) AnyMySQLError.Error!Data { + const offset = this.offset.*; + if 
(!this.ensureCapacity(count)) { + return AnyMySQLError.Error.ShortRead; + } + + this.skip(@intCast(count)); + return Data{ + .temporary = this.buffer[offset..this.offset.*], + }; +} + +pub fn readZ(this: StackReader) AnyMySQLError.Error!Data { + const remaining = this.peek(); + if (bun.strings.indexOfChar(remaining, 0)) |zero| { + this.skip(@intCast(zero + 1)); + return Data{ + .temporary = remaining[0..zero], + }; + } + + return error.ShortRead; +} + +const AnyMySQLError = @import("./AnyMySQLError.zig"); +const bun = @import("bun"); +const Data = @import("../../shared/Data.zig").Data; +const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/mysql/protocol/StmtPrepareOKPacket.zig b/src/sql/mysql/protocol/StmtPrepareOKPacket.zig new file mode 100644 index 0000000000..0238021ce1 --- /dev/null +++ b/src/sql/mysql/protocol/StmtPrepareOKPacket.zig @@ -0,0 +1,26 @@ +const StmtPrepareOKPacket = @This(); +status: u8 = 0, +statement_id: u32 = 0, +num_columns: u16 = 0, +num_params: u16 = 0, +warning_count: u16 = 0, +packet_length: u24, +pub fn decodeInternal(this: *StmtPrepareOKPacket, comptime Context: type, reader: NewReader(Context)) !void { + this.status = try reader.int(u8); + if (this.status != 0) { + return error.InvalidPrepareOKPacket; + } + + this.statement_id = try reader.int(u32); + this.num_columns = try reader.int(u16); + this.num_params = try reader.int(u16); + _ = try reader.int(u8); // reserved_1 + if (this.packet_length >= 12) { + this.warning_count = try reader.int(u16); + } +} + +pub const decode = decoderWrap(StmtPrepareOKPacket, decodeInternal).decode; + +const NewReader = @import("./NewReader.zig").NewReader; +const decoderWrap = @import("./NewReader.zig").decoderWrap; diff --git a/src/sql/postgres/AnyPostgresError.zig b/src/sql/postgres/AnyPostgresError.zig index 7f79945cea..f2044b732e 100644 --- a/src/sql/postgres/AnyPostgresError.zig +++ b/src/sql/postgres/AnyPostgresError.zig @@ -59,44 +59,20 @@ pub fn createPostgresError( 
message: []const u8, options: PostgresErrorOptions, ) bun.JSError!JSValue { - const bun_ns = (try globalObject.toJSValue().get(globalObject, "Bun")).?; - const sql_constructor = (try bun_ns.get(globalObject, "SQL")).?; - const pg_error_constructor = (try sql_constructor.get(globalObject, "PostgresError")).?; - - const opts_obj = JSValue.createEmptyObject(globalObject, 0); - opts_obj.put(globalObject, jsc.ZigString.static("code"), jsc.ZigString.init(options.code).toJS(globalObject)); - - if (options.errno) |errno| opts_obj.put(globalObject, jsc.ZigString.static("errno"), jsc.ZigString.init(errno).toJS(globalObject)); - if (options.detail) |detail| opts_obj.put(globalObject, jsc.ZigString.static("detail"), jsc.ZigString.init(detail).toJS(globalObject)); - if (options.hint) |hint| opts_obj.put(globalObject, jsc.ZigString.static("hint"), jsc.ZigString.init(hint).toJS(globalObject)); - if (options.severity) |severity| opts_obj.put(globalObject, jsc.ZigString.static("severity"), jsc.ZigString.init(severity).toJS(globalObject)); - if (options.position) |pos| opts_obj.put(globalObject, jsc.ZigString.static("position"), jsc.ZigString.init(pos).toJS(globalObject)); - if (options.internalPosition) |pos| opts_obj.put(globalObject, jsc.ZigString.static("internalPosition"), jsc.ZigString.init(pos).toJS(globalObject)); - if (options.internalQuery) |query| opts_obj.put(globalObject, jsc.ZigString.static("internalQuery"), jsc.ZigString.init(query).toJS(globalObject)); - if (options.where) |w| opts_obj.put(globalObject, jsc.ZigString.static("where"), jsc.ZigString.init(w).toJS(globalObject)); - if (options.schema) |s| opts_obj.put(globalObject, jsc.ZigString.static("schema"), jsc.ZigString.init(s).toJS(globalObject)); - if (options.table) |t| opts_obj.put(globalObject, jsc.ZigString.static("table"), jsc.ZigString.init(t).toJS(globalObject)); - if (options.column) |c| opts_obj.put(globalObject, jsc.ZigString.static("column"), jsc.ZigString.init(c).toJS(globalObject)); - if 
(options.dataType) |dt| opts_obj.put(globalObject, jsc.ZigString.static("dataType"), jsc.ZigString.init(dt).toJS(globalObject)); - if (options.constraint) |c| opts_obj.put(globalObject, jsc.ZigString.static("constraint"), jsc.ZigString.init(c).toJS(globalObject)); - if (options.file) |f| opts_obj.put(globalObject, jsc.ZigString.static("file"), jsc.ZigString.init(f).toJS(globalObject)); - if (options.line) |l| opts_obj.put(globalObject, jsc.ZigString.static("line"), jsc.ZigString.init(l).toJS(globalObject)); - if (options.routine) |r| opts_obj.put(globalObject, jsc.ZigString.static("routine"), jsc.ZigString.init(r).toJS(globalObject)); - - const args = [_]JSValue{ - jsc.ZigString.init(message).toJS(globalObject), - opts_obj, - }; - - const JSC = @import("../../bun.js/javascript_core_c_api.zig"); - var exception: JSC.JSValueRef = null; - const result = JSC.JSObjectCallAsConstructor(globalObject, pg_error_constructor.asObjectRef(), args.len, @ptrCast(&args), &exception); - - if (exception != null) { - return bun.JSError.JSError; + const opts_obj = JSValue.createEmptyObject(globalObject, 18); + opts_obj.ensureStillAlive(); + opts_obj.put(globalObject, jsc.ZigString.static("code"), try bun.String.createUTF8ForJS(globalObject, options.code)); + inline for (std.meta.fields(PostgresErrorOptions)) |field| { + const FieldType = @typeInfo(@TypeOf(@field(options, field.name))); + if (FieldType == .optional) { + if (@field(options, field.name)) |value| { + opts_obj.put(globalObject, jsc.ZigString.static(field.name), try bun.String.createUTF8ForJS(globalObject, value)); + } + } } + opts_obj.put(globalObject, jsc.ZigString.static("message"), try bun.String.createUTF8ForJS(globalObject, message)); - return JSValue.fromRef(result); + return opts_obj; } pub fn postgresErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8, err: AnyPostgresError) JSValue { @@ -142,10 +118,8 @@ pub fn postgresErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8 }, }; - const 
msg = message orelse std.fmt.allocPrint(bun.default_allocator, "Failed to bind query: {s}", .{@errorName(err)}) catch unreachable; - defer { - if (message == null) bun.default_allocator.free(msg); - } + var buffer_message = [_]u8{0} ** 256; + const msg = message orelse std.fmt.bufPrint(buffer_message[0..], "Failed to bind query: {s}", .{@errorName(err)}) catch "Failed to bind query"; return createPostgresError(globalObject, msg, .{ .code = code }) catch |e| globalObject.takeError(e); } diff --git a/src/sql/postgres/DataCell.zig b/src/sql/postgres/DataCell.zig index e7e219e942..e4d51ddacf 100644 --- a/src/sql/postgres/DataCell.zig +++ b/src/sql/postgres/DataCell.zig @@ -1,1113 +1,961 @@ -pub const DataCell = extern struct { - tag: Tag, +pub const SQLDataCell = @import("../shared/SQLDataCell.zig").SQLDataCell; - value: Value, - free_value: u8 = 0, - isIndexedColumn: u8 = 0, - index: u32 = 0, +fn parseBytea(hex: []const u8) !SQLDataCell { + const len = hex.len / 2; + const buf = try bun.default_allocator.alloc(u8, len); + errdefer bun.default_allocator.free(buf); - pub const Tag = enum(u8) { - null = 0, - string = 1, - float8 = 2, - int4 = 3, - int8 = 4, - bool = 5, - date = 6, - date_with_time_zone = 7, - bytea = 8, - json = 9, - array = 10, - typed_array = 11, - raw = 12, - uint4 = 13, - }; - - pub const Value = extern union { - null: u8, - string: ?bun.WTF.StringImpl, - float8: f64, - int4: i32, - int8: i64, - bool: u8, - date: f64, - date_with_time_zone: f64, - bytea: [2]usize, - json: ?bun.WTF.StringImpl, - array: Array, - typed_array: TypedArray, - raw: Raw, - uint4: u32, - }; - - pub const Array = extern struct { - ptr: ?[*]DataCell = null, - len: u32, - cap: u32, - pub fn slice(this: *Array) []DataCell { - const ptr = this.ptr orelse return &.{}; - return ptr[0..this.len]; - } - - pub fn allocatedSlice(this: *Array) []DataCell { - const ptr = this.ptr orelse return &.{}; - return ptr[0..this.cap]; - } - - pub fn deinit(this: *Array) void { - const allocated = 
this.allocatedSlice(); - this.ptr = null; - this.len = 0; - this.cap = 0; - bun.default_allocator.free(allocated); - } - }; - pub const Raw = extern struct { - ptr: ?[*]const u8 = null, - len: u64, - }; - pub const TypedArray = extern struct { - head_ptr: ?[*]u8 = null, - ptr: ?[*]u8 = null, - len: u32, - byte_len: u32, - type: JSValue.JSType, - - pub fn slice(this: *TypedArray) []u8 { - const ptr = this.ptr orelse return &.{}; - return ptr[0..this.len]; - } - - pub fn byteSlice(this: *TypedArray) []u8 { - const ptr = this.head_ptr orelse return &.{}; - return ptr[0..this.len]; - } - }; - - pub fn deinit(this: *DataCell) void { - if (this.free_value == 0) return; - - switch (this.tag) { - .string => { - if (this.value.string) |str| { - str.deref(); - } - }, - .json => { - if (this.value.json) |str| { - str.deref(); - } - }, - .bytea => { - if (this.value.bytea[1] == 0) return; - const slice = @as([*]u8, @ptrFromInt(this.value.bytea[0]))[0..this.value.bytea[1]]; - bun.default_allocator.free(slice); - }, - .array => { - for (this.value.array.slice()) |*cell| { - cell.deinit(); - } - this.value.array.deinit(); - }, - .typed_array => { - bun.default_allocator.free(this.value.typed_array.byteSlice()); + return SQLDataCell{ + .tag = .bytea, + .value = .{ + .bytea = .{ + @intFromPtr(buf.ptr), + try bun.strings.decodeHexToBytes(buf, u8, hex), }, + }, + .free_value = 1, + }; +} - else => {}, - } - } - pub fn raw(optional_bytes: ?*Data) DataCell { - if (optional_bytes) |bytes| { - const bytes_slice = bytes.slice(); - return DataCell{ - .tag = .raw, - .value = .{ .raw = .{ .ptr = @ptrCast(bytes_slice.ptr), .len = bytes_slice.len } }, - }; - } - // TODO: check empty and null fields - return DataCell{ - .tag = .null, - .value = .{ .null = 0 }, - }; - } +fn unescapePostgresString(input: []const u8, buffer: []u8) ![]u8 { + var out_index: usize = 0; + var i: usize = 0; - fn parseBytea(hex: []const u8) !DataCell { - const len = hex.len / 2; - const buf = try 
bun.default_allocator.alloc(u8, len); - errdefer bun.default_allocator.free(buf); + while (i < input.len) : (i += 1) { + if (out_index >= buffer.len) return error.BufferTooSmall; - return DataCell{ - .tag = .bytea, - .value = .{ - .bytea = .{ - @intFromPtr(buf.ptr), - try bun.strings.decodeHexToBytes(buf, u8, hex), + if (input[i] == '\\' and i + 1 < input.len) { + i += 1; + switch (input[i]) { + // Common escapes + 'b' => buffer[out_index] = '\x08', // Backspace + 'f' => buffer[out_index] = '\x0C', // Form feed + 'n' => buffer[out_index] = '\n', // Line feed + 'r' => buffer[out_index] = '\r', // Carriage return + 't' => buffer[out_index] = '\t', // Tab + '"' => buffer[out_index] = '"', // Double quote + '\\' => buffer[out_index] = '\\', // Backslash + '\'' => buffer[out_index] = '\'', // Single quote + + // JSON allows forward slash escaping + '/' => buffer[out_index] = '/', + + // PostgreSQL hex escapes (used for unicode too) + 'x' => { + if (i + 2 >= input.len) return error.InvalidEscapeSequence; + const hex_value = try std.fmt.parseInt(u8, input[i + 1 .. 
i + 3], 16); + buffer[out_index] = hex_value; + i += 2; }, - }, - .free_value = 1, - }; + + else => return error.UnknownEscapeSequence, + } + } else { + buffer[out_index] = input[i]; + } + out_index += 1; } - fn unescapePostgresString(input: []const u8, buffer: []u8) ![]u8 { - var out_index: usize = 0; - var i: usize = 0; + return buffer[0..out_index]; +} +fn trySlice(slice: []const u8, count: usize) []const u8 { + if (slice.len <= count) return ""; + return slice[count..]; +} +fn parseArray(bytes: []const u8, bigint: bool, comptime arrayType: types.Tag, globalObject: *jsc.JSGlobalObject, offset: ?*usize, comptime is_json_sub_array: bool) !SQLDataCell { + const closing_brace = if (is_json_sub_array) ']' else '}'; + const opening_brace = if (is_json_sub_array) '[' else '{'; + if (bytes.len < 2 or bytes[0] != opening_brace) { + return error.UnsupportedArrayFormat; + } + // empty array + if (bytes.len == 2 and bytes[1] == closing_brace) { + if (offset) |offset_ptr| { + offset_ptr.* = 2; + } + return SQLDataCell{ .tag = .array, .value = .{ .array = .{ .ptr = null, .len = 0, .cap = 0 } } }; + } - while (i < input.len) : (i += 1) { - if (out_index >= buffer.len) return error.BufferTooSmall; + var array = std.ArrayListUnmanaged(SQLDataCell){}; + var stack_buffer: [16 * 1024]u8 = undefined; - if (input[i] == '\\' and i + 1 < input.len) { - i += 1; - switch (input[i]) { - // Common escapes - 'b' => buffer[out_index] = '\x08', // Backspace - 'f' => buffer[out_index] = '\x0C', // Form feed - 'n' => buffer[out_index] = '\n', // Line feed - 'r' => buffer[out_index] = '\r', // Carriage return - 't' => buffer[out_index] = '\t', // Tab - '"' => buffer[out_index] = '"', // Double quote - '\\' => buffer[out_index] = '\\', // Backslash - '\'' => buffer[out_index] = '\'', // Single quote - - // JSON allows forward slash escaping - '/' => buffer[out_index] = '/', - - // PostgreSQL hex escapes (used for unicode too) - 'x' => { - if (i + 2 >= input.len) return 
error.InvalidEscapeSequence; - const hex_value = try std.fmt.parseInt(u8, input[i + 1 .. i + 3], 16); - buffer[out_index] = hex_value; - i += 2; - }, - - else => return error.UnknownEscapeSequence, + errdefer { + if (array.capacity > 0) array.deinit(bun.default_allocator); + } + var slice = bytes[1..]; + var reached_end = false; + const separator = switch (arrayType) { + .box_array => ';', + else => ',', + }; + while (slice.len > 0) { + switch (slice[0]) { + closing_brace => { + if (reached_end) { + // cannot reach end twice + return error.UnsupportedArrayFormat; } - } else { - buffer[out_index] = input[i]; - } - out_index += 1; - } - - return buffer[0..out_index]; - } - fn trySlice(slice: []const u8, count: usize) []const u8 { - if (slice.len <= count) return ""; - return slice[count..]; - } - fn parseArray(bytes: []const u8, bigint: bool, comptime arrayType: types.Tag, globalObject: *jsc.JSGlobalObject, offset: ?*usize, comptime is_json_sub_array: bool) !DataCell { - const closing_brace = if (is_json_sub_array) ']' else '}'; - const opening_brace = if (is_json_sub_array) '[' else '{'; - if (bytes.len < 2 or bytes[0] != opening_brace) { - return error.UnsupportedArrayFormat; - } - // empty array - if (bytes.len == 2 and bytes[1] == closing_brace) { - if (offset) |offset_ptr| { - offset_ptr.* = 2; - } - return DataCell{ .tag = .array, .value = .{ .array = .{ .ptr = null, .len = 0, .cap = 0 } } }; - } - - var array = std.ArrayListUnmanaged(DataCell){}; - var stack_buffer: [16 * 1024]u8 = undefined; - - errdefer { - if (array.capacity > 0) array.deinit(bun.default_allocator); - } - var slice = bytes[1..]; - var reached_end = false; - const separator = switch (arrayType) { - .box_array => ';', - else => ',', - }; - while (slice.len > 0) { - switch (slice[0]) { - closing_brace => { - if (reached_end) { - // cannot reach end twice - return error.UnsupportedArrayFormat; + // end of array + reached_end = true; + slice = trySlice(slice, 1); + break; + }, + opening_brace => 
{ + var sub_array_offset: usize = 0; + const sub_array = try parseArray(slice, bigint, arrayType, globalObject, &sub_array_offset, is_json_sub_array); + try array.append(bun.default_allocator, sub_array); + slice = trySlice(slice, sub_array_offset); + continue; + }, + '"' => { + // parse string + var current_idx: usize = 0; + const source = slice[1..]; + // simple escape check to avoid something like "\\\\" and "\"" + var is_escaped = false; + for (source, 0..source.len) |byte, index| { + if (byte == '"' and !is_escaped) { + current_idx = index + 1; + break; } - // end of array - reached_end = true; - slice = trySlice(slice, 1); - break; - }, - opening_brace => { - var sub_array_offset: usize = 0; - const sub_array = try parseArray(slice, bigint, arrayType, globalObject, &sub_array_offset, is_json_sub_array); - try array.append(bun.default_allocator, sub_array); - slice = trySlice(slice, sub_array_offset); - continue; - }, - '"' => { - // parse string - var current_idx: usize = 0; - const source = slice[1..]; - // simple escape check to avoid something like "\\\\" and "\"" - var is_escaped = false; - for (source, 0..source.len) |byte, index| { - if (byte == '"' and !is_escaped) { - current_idx = index + 1; - break; + is_escaped = !is_escaped and byte == '\\'; + } + // did not find a closing quote + if (current_idx == 0) return error.UnsupportedArrayFormat; + switch (arrayType) { + .bytea_array => { + // this is a bytea array so we need to parse the bytea strings + const bytea_bytes = slice[1..current_idx]; + if (bun.strings.startsWith(bytea_bytes, "\\\\x")) { + // its a bytea string lets parse it as a bytea + try array.append(bun.default_allocator, try parseBytea(bytea_bytes[3..][0 .. 
bytea_bytes.len - 3])); + slice = trySlice(slice, current_idx + 1); + continue; } - is_escaped = !is_escaped and byte == '\\'; - } - // did not find a closing quote - if (current_idx == 0) return error.UnsupportedArrayFormat; - switch (arrayType) { - .bytea_array => { - // this is a bytea array so we need to parse the bytea strings - const bytea_bytes = slice[1..current_idx]; - if (bun.strings.startsWith(bytea_bytes, "\\\\x")) { - // its a bytea string lets parse it as a bytea - try array.append(bun.default_allocator, try parseBytea(bytea_bytes[3..][0 .. bytea_bytes.len - 3])); - slice = trySlice(slice, current_idx + 1); - continue; - } - // invalid bytea array - return error.UnsupportedByteaFormat; - }, - .timestamptz_array, - .timestamp_array, - .date_array, - => { - const date_str = slice[1..current_idx]; - var str = bun.String.init(date_str); - defer str.deref(); - try array.append(bun.default_allocator, DataCell{ .tag = .date, .value = .{ .date = try str.parseDate(globalObject) } }); + // invalid bytea array + return error.UnsupportedByteaFormat; + }, + .timestamptz_array, + .timestamp_array, + .date_array, + => { + const date_str = slice[1..current_idx]; + var str = bun.String.init(date_str); + defer str.deref(); + try array.append(bun.default_allocator, SQLDataCell{ .tag = .date, .value = .{ .date = try str.parseDate(globalObject) } }); - slice = trySlice(slice, current_idx + 1); - continue; - }, - .json_array, - .jsonb_array, - => { - const str_bytes = slice[1..current_idx]; - const needs_dynamic_buffer = str_bytes.len < stack_buffer.len; - const buffer = if (needs_dynamic_buffer) try bun.default_allocator.alloc(u8, str_bytes.len) else stack_buffer[0..]; - defer if (needs_dynamic_buffer) bun.default_allocator.free(buffer); - const unescaped = unescapePostgresString(str_bytes, buffer) catch return error.InvalidByteSequence; - try array.append(bun.default_allocator, DataCell{ .tag = .json, .value = .{ .json = if (unescaped.len > 0) 
String.cloneUTF8(unescaped).value.WTFStringImpl else null }, .free_value = 1 }); - slice = trySlice(slice, current_idx + 1); - continue; - }, - else => {}, - } - const str_bytes = slice[1..current_idx]; - if (str_bytes.len == 0) { - // empty string - try array.append(bun.default_allocator, DataCell{ .tag = .string, .value = .{ .string = null }, .free_value = 1 }); slice = trySlice(slice, current_idx + 1); continue; - } - const needs_dynamic_buffer = str_bytes.len < stack_buffer.len; - const buffer = if (needs_dynamic_buffer) try bun.default_allocator.alloc(u8, str_bytes.len) else stack_buffer[0..]; - defer if (needs_dynamic_buffer) bun.default_allocator.free(buffer); - const string_bytes = unescapePostgresString(str_bytes, buffer) catch return error.InvalidByteSequence; - try array.append(bun.default_allocator, DataCell{ .tag = .string, .value = .{ .string = if (string_bytes.len > 0) String.cloneUTF8(string_bytes).value.WTFStringImpl else null }, .free_value = 1 }); - + }, + .json_array, + .jsonb_array, + => { + const str_bytes = slice[1..current_idx]; + const needs_dynamic_buffer = str_bytes.len < stack_buffer.len; + const buffer = if (needs_dynamic_buffer) try bun.default_allocator.alloc(u8, str_bytes.len) else stack_buffer[0..]; + defer if (needs_dynamic_buffer) bun.default_allocator.free(buffer); + const unescaped = unescapePostgresString(str_bytes, buffer) catch return error.InvalidByteSequence; + try array.append(bun.default_allocator, SQLDataCell{ .tag = .json, .value = .{ .json = if (unescaped.len > 0) String.cloneUTF8(unescaped).value.WTFStringImpl else null }, .free_value = 1 }); + slice = trySlice(slice, current_idx + 1); + continue; + }, + else => {}, + } + const str_bytes = slice[1..current_idx]; + if (str_bytes.len == 0) { + // empty string + try array.append(bun.default_allocator, SQLDataCell{ .tag = .string, .value = .{ .string = null }, .free_value = 1 }); slice = trySlice(slice, current_idx + 1); continue; - }, - separator => { - // next element 
or positive number, just advance - slice = trySlice(slice, 1); - continue; - }, - else => { - switch (arrayType) { - // timez, date, time, interval are handled like single string cases - .timetz_array, - .date_array, - .time_array, - .interval_array, - // text array types - .bpchar_array, - .varchar_array, - .char_array, - .text_array, - .name_array, - .numeric_array, - .money_array, - .varbit_array, - .int2vector_array, - .bit_array, - .path_array, - .xml_array, - .point_array, - .lseg_array, - .box_array, - .polygon_array, - .line_array, - .cidr_array, - .circle_array, - .macaddr8_array, - .macaddr_array, - .inet_array, - .aclitem_array, - .pg_database_array, - .pg_database_array2, - => { - // this is also a string until we reach "," or "}" but a single word string like Bun - var current_idx: usize = 0; + } + const needs_dynamic_buffer = str_bytes.len < stack_buffer.len; + const buffer = if (needs_dynamic_buffer) try bun.default_allocator.alloc(u8, str_bytes.len) else stack_buffer[0..]; + defer if (needs_dynamic_buffer) bun.default_allocator.free(buffer); + const string_bytes = unescapePostgresString(str_bytes, buffer) catch return error.InvalidByteSequence; + try array.append(bun.default_allocator, SQLDataCell{ .tag = .string, .value = .{ .string = if (string_bytes.len > 0) String.cloneUTF8(string_bytes).value.WTFStringImpl else null }, .free_value = 1 }); - for (slice, 0..slice.len) |byte, index| { - switch (byte) { - '}', separator => { - current_idx = index; - break; - }, - else => {}, - } - } - if (current_idx == 0) return error.UnsupportedArrayFormat; - const element = slice[0..current_idx]; - // lets handle NULL case here, if is a string "NULL" it will have quotes, if its a NULL it will be just NULL - if (bun.strings.eqlComptime(element, "NULL")) { - try array.append(bun.default_allocator, DataCell{ .tag = .null, .value = .{ .null = 0 } }); - slice = trySlice(slice, current_idx); - continue; - } - if (arrayType == .date_array) { - var str = 
bun.String.init(element); - defer str.deref(); - try array.append(bun.default_allocator, DataCell{ .tag = .date, .value = .{ .date = try str.parseDate(globalObject) } }); - } else { - // the only escape sequency possible here is \b - if (bun.strings.eqlComptime(element, "\\b")) { - try array.append(bun.default_allocator, DataCell{ .tag = .string, .value = .{ .string = bun.String.cloneUTF8("\x08").value.WTFStringImpl }, .free_value = 1 }); - } else { - try array.append(bun.default_allocator, DataCell{ .tag = .string, .value = .{ .string = if (element.len > 0) bun.String.cloneUTF8(element).value.WTFStringImpl else null }, .free_value = 0 }); - } + slice = trySlice(slice, current_idx + 1); + continue; + }, + separator => { + // next element or positive number, just advance + slice = trySlice(slice, 1); + continue; + }, + else => { + switch (arrayType) { + // timez, date, time, interval are handled like single string cases + .timetz_array, + .date_array, + .time_array, + .interval_array, + // text array types + .bpchar_array, + .varchar_array, + .char_array, + .text_array, + .name_array, + .numeric_array, + .money_array, + .varbit_array, + .int2vector_array, + .bit_array, + .path_array, + .xml_array, + .point_array, + .lseg_array, + .box_array, + .polygon_array, + .line_array, + .cidr_array, + .circle_array, + .macaddr8_array, + .macaddr_array, + .inet_array, + .aclitem_array, + .pg_database_array, + .pg_database_array2, + => { + // this is also a string until we reach "," or "}" but a single word string like Bun + var current_idx: usize = 0; + + for (slice, 0..slice.len) |byte, index| { + switch (byte) { + '}', separator => { + current_idx = index; + break; + }, + else => {}, } + } + if (current_idx == 0) return error.UnsupportedArrayFormat; + const element = slice[0..current_idx]; + // lets handle NULL case here, if is a string "NULL" it will have quotes, if its a NULL it will be just NULL + if (bun.strings.eqlComptime(element, "NULL")) { + try 
array.append(bun.default_allocator, SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }); slice = trySlice(slice, current_idx); continue; - }, - else => { - // non text array, NaN, Null, False, True etc are special cases here - switch (slice[0]) { - 'N' => { - // null or nan - if (slice.len < 3) return error.UnsupportedArrayFormat; - if (slice.len >= 4) { - if (bun.strings.eqlComptime(slice[0..4], "NULL")) { - try array.append(bun.default_allocator, DataCell{ .tag = .null, .value = .{ .null = 0 } }); - slice = trySlice(slice, 4); - continue; - } - } - if (bun.strings.eqlComptime(slice[0..3], "NaN")) { - try array.append(bun.default_allocator, DataCell{ .tag = .float8, .value = .{ .float8 = std.math.nan(f64) } }); - slice = trySlice(slice, 3); + } + if (arrayType == .date_array) { + var str = bun.String.init(element); + defer str.deref(); + try array.append(bun.default_allocator, SQLDataCell{ .tag = .date, .value = .{ .date = try str.parseDate(globalObject) } }); + } else { + // the only escape sequency possible here is \b + if (bun.strings.eqlComptime(element, "\\b")) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .string, .value = .{ .string = bun.String.cloneUTF8("\x08").value.WTFStringImpl }, .free_value = 1 }); + } else { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .string, .value = .{ .string = if (element.len > 0) bun.String.cloneUTF8(element).value.WTFStringImpl else null }, .free_value = 0 }); + } + } + slice = trySlice(slice, current_idx); + continue; + }, + else => { + // non text array, NaN, Null, False, True etc are special cases here + switch (slice[0]) { + 'N' => { + // null or nan + if (slice.len < 3) return error.UnsupportedArrayFormat; + if (slice.len >= 4) { + if (bun.strings.eqlComptime(slice[0..4], "NULL")) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }); + slice = trySlice(slice, 4); continue; } - return error.UnsupportedArrayFormat; - }, - 'f' => { - // false 
- if (arrayType == .json_array or arrayType == .jsonb_array) { - if (slice.len < 5) return error.UnsupportedArrayFormat; - if (bun.strings.eqlComptime(slice[0..5], "false")) { - try array.append(bun.default_allocator, DataCell{ .tag = .bool, .value = .{ .bool = 0 } }); - slice = trySlice(slice, 5); - continue; - } - } else { - try array.append(bun.default_allocator, DataCell{ .tag = .bool, .value = .{ .bool = 0 } }); - slice = trySlice(slice, 1); + } + if (bun.strings.eqlComptime(slice[0..3], "NaN")) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .float8, .value = .{ .float8 = std.math.nan(f64) } }); + slice = trySlice(slice, 3); + continue; + } + return error.UnsupportedArrayFormat; + }, + 'f' => { + // false + if (arrayType == .json_array or arrayType == .jsonb_array) { + if (slice.len < 5) return error.UnsupportedArrayFormat; + if (bun.strings.eqlComptime(slice[0..5], "false")) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .bool, .value = .{ .bool = 0 } }); + slice = trySlice(slice, 5); continue; } - }, - 't' => { - // true - if (arrayType == .json_array or arrayType == .jsonb_array) { - if (slice.len < 4) return error.UnsupportedArrayFormat; - if (bun.strings.eqlComptime(slice[0..4], "true")) { - try array.append(bun.default_allocator, DataCell{ .tag = .bool, .value = .{ .bool = 1 } }); - slice = trySlice(slice, 4); - continue; - } - } else { - try array.append(bun.default_allocator, DataCell{ .tag = .bool, .value = .{ .bool = 1 } }); - slice = trySlice(slice, 1); - continue; - } - }, - 'I', - 'i', - => { - // infinity - if (slice.len < 8) return error.UnsupportedArrayFormat; - - if (bun.strings.eqlCaseInsensitiveASCII(slice[0..8], "Infinity", false)) { - if (arrayType == .date_array or arrayType == .timestamp_array or arrayType == .timestamptz_array) { - try array.append(bun.default_allocator, DataCell{ .tag = .date, .value = .{ .date = std.math.inf(f64) } }); - } else { - try array.append(bun.default_allocator, DataCell{ 
.tag = .float8, .value = .{ .float8 = std.math.inf(f64) } }); - } - slice = trySlice(slice, 8); - continue; - } - - return error.UnsupportedArrayFormat; - }, - '+' => { + } else { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .bool, .value = .{ .bool = 0 } }); slice = trySlice(slice, 1); continue; - }, - '-', '0'...'9' => { - // parse number, detect float, int, if starts with - it can be -Infinity or -Infinity - var is_negative = false; - var is_float = false; - var current_idx: usize = 0; - var is_infinity = false; - // track exponent stuff (1.1e-12, 1.1e+12) - var has_exponent = false; - var has_negative_sign = false; - var has_positive_sign = false; - for (slice, 0..slice.len) |byte, index| { - switch (byte) { - '0'...'9' => {}, - closing_brace, separator => { - current_idx = index; - // end of element - break; - }, - 'e' => { - if (!is_float) return error.UnsupportedArrayFormat; - if (has_exponent) return error.UnsupportedArrayFormat; - has_exponent = true; - continue; - }, - '+' => { - if (!has_exponent) return error.UnsupportedArrayFormat; - if (has_positive_sign) return error.UnsupportedArrayFormat; - has_positive_sign = true; - continue; - }, - '-' => { - if (index == 0) { - is_negative = true; - continue; - } - if (!has_exponent) return error.UnsupportedArrayFormat; - if (has_negative_sign) return error.UnsupportedArrayFormat; - has_negative_sign = true; - continue; - }, - '.' => { - // we can only have one dot and the dot must be before the exponent - if (is_float) return error.UnsupportedArrayFormat; - is_float = true; - }, - 'I', 'i' => { - // infinity - is_infinity = true; - const element = if (is_negative) slice[1..] 
else slice; - if (element.len < 8) return error.UnsupportedArrayFormat; - if (bun.strings.eqlCaseInsensitiveASCII(element[0..8], "Infinity", false)) { - if (arrayType == .date_array or arrayType == .timestamp_array or arrayType == .timestamptz_array) { - try array.append(bun.default_allocator, DataCell{ .tag = .date, .value = .{ .date = if (is_negative) -std.math.inf(f64) else std.math.inf(f64) } }); - } else { - try array.append(bun.default_allocator, DataCell{ .tag = .float8, .value = .{ .float8 = if (is_negative) -std.math.inf(f64) else std.math.inf(f64) } }); - } - slice = trySlice(slice, 8 + @as(usize, @intFromBool(is_negative))); - break; - } + } + }, + 't' => { + // true + if (arrayType == .json_array or arrayType == .jsonb_array) { + if (slice.len < 4) return error.UnsupportedArrayFormat; + if (bun.strings.eqlComptime(slice[0..4], "true")) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .bool, .value = .{ .bool = 1 } }); + slice = trySlice(slice, 4); + continue; + } + } else { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .bool, .value = .{ .bool = 1 } }); + slice = trySlice(slice, 1); + continue; + } + }, + 'I', + 'i', + => { + // infinity + if (slice.len < 8) return error.UnsupportedArrayFormat; - return error.UnsupportedArrayFormat; - }, - else => { - return error.UnsupportedArrayFormat; - }, - } + if (bun.strings.eqlCaseInsensitiveASCII(slice[0..8], "Infinity", false)) { + if (arrayType == .date_array or arrayType == .timestamp_array or arrayType == .timestamptz_array) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .date, .value = .{ .date = std.math.inf(f64) } }); + } else { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .float8, .value = .{ .float8 = std.math.inf(f64) } }); } - if (is_infinity) { - continue; - } - if (current_idx == 0) return error.UnsupportedArrayFormat; - const element = slice[0..current_idx]; - if (is_float or arrayType == .float8_array) { - try 
array.append(bun.default_allocator, DataCell{ .tag = .float8, .value = .{ .float8 = bun.parseDouble(element) catch std.math.nan(f64) } }); - slice = trySlice(slice, current_idx); - continue; - } - switch (arrayType) { - .int8_array => { - if (bigint) { - try array.append(bun.default_allocator, DataCell{ .tag = .int8, .value = .{ .int8 = std.fmt.parseInt(i64, element, 0) catch return error.UnsupportedArrayFormat } }); - } else { - try array.append(bun.default_allocator, DataCell{ .tag = .string, .value = .{ .string = if (element.len > 0) bun.String.cloneUTF8(element).value.WTFStringImpl else null }, .free_value = 1 }); - } - slice = trySlice(slice, current_idx); + slice = trySlice(slice, 8); + continue; + } + + return error.UnsupportedArrayFormat; + }, + '+' => { + slice = trySlice(slice, 1); + continue; + }, + '-', '0'...'9' => { + // parse number, detect float, int, if starts with - it can be -Infinity or -Infinity + var is_negative = false; + var is_float = false; + var current_idx: usize = 0; + var is_infinity = false; + // track exponent stuff (1.1e-12, 1.1e+12) + var has_exponent = false; + var has_negative_sign = false; + var has_positive_sign = false; + for (slice, 0..slice.len) |byte, index| { + switch (byte) { + '0'...'9' => {}, + closing_brace, separator => { + current_idx = index; + // end of element + break; + }, + 'e' => { + if (!is_float) return error.UnsupportedArrayFormat; + if (has_exponent) return error.UnsupportedArrayFormat; + has_exponent = true; continue; }, - .cid_array, .xid_array, .oid_array => { - try array.append(bun.default_allocator, DataCell{ .tag = .uint4, .value = .{ .uint4 = std.fmt.parseInt(u32, element, 0) catch 0 } }); - slice = trySlice(slice, current_idx); + '+' => { + if (!has_exponent) return error.UnsupportedArrayFormat; + if (has_positive_sign) return error.UnsupportedArrayFormat; + has_positive_sign = true; continue; }, + '-' => { + if (index == 0) { + is_negative = true; + continue; + } + if (!has_exponent) return 
error.UnsupportedArrayFormat; + if (has_negative_sign) return error.UnsupportedArrayFormat; + has_negative_sign = true; + continue; + }, + '.' => { + // we can only have one dot and the dot must be before the exponent + if (is_float) return error.UnsupportedArrayFormat; + is_float = true; + }, + 'I', 'i' => { + // infinity + is_infinity = true; + const element = if (is_negative) slice[1..] else slice; + if (element.len < 8) return error.UnsupportedArrayFormat; + if (bun.strings.eqlCaseInsensitiveASCII(element[0..8], "Infinity", false)) { + if (arrayType == .date_array or arrayType == .timestamp_array or arrayType == .timestamptz_array) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .date, .value = .{ .date = if (is_negative) -std.math.inf(f64) else std.math.inf(f64) } }); + } else { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .float8, .value = .{ .float8 = if (is_negative) -std.math.inf(f64) else std.math.inf(f64) } }); + } + slice = trySlice(slice, 8 + @as(usize, @intFromBool(is_negative))); + break; + } + + return error.UnsupportedArrayFormat; + }, else => { - const value = std.fmt.parseInt(i32, element, 0) catch return error.UnsupportedArrayFormat; - - try array.append(bun.default_allocator, DataCell{ .tag = .int4, .value = .{ .int4 = @intCast(value) } }); - slice = trySlice(slice, current_idx); - continue; + return error.UnsupportedArrayFormat; }, } - }, - else => { - if (arrayType == .json_array or arrayType == .jsonb_array) { - if (slice[0] == '[') { - var sub_array_offset: usize = 0; - const sub_array = try parseArray(slice, bigint, arrayType, globalObject, &sub_array_offset, true); - try array.append(bun.default_allocator, sub_array); - slice = trySlice(slice, sub_array_offset); - continue; + } + if (is_infinity) { + continue; + } + if (current_idx == 0) return error.UnsupportedArrayFormat; + const element = slice[0..current_idx]; + if (is_float or arrayType == .float8_array) { + try array.append(bun.default_allocator, 
SQLDataCell{ .tag = .float8, .value = .{ .float8 = bun.parseDouble(element) catch std.math.nan(f64) } }); + slice = trySlice(slice, current_idx); + continue; + } + switch (arrayType) { + .int8_array => { + if (bigint) { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .int8, .value = .{ .int8 = std.fmt.parseInt(i64, element, 0) catch return error.UnsupportedArrayFormat } }); + } else { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .string, .value = .{ .string = if (element.len > 0) bun.String.cloneUTF8(element).value.WTFStringImpl else null }, .free_value = 1 }); } + slice = trySlice(slice, current_idx); + continue; + }, + .cid_array, .xid_array, .oid_array => { + try array.append(bun.default_allocator, SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = std.fmt.parseInt(u32, element, 0) catch 0 } }); + slice = trySlice(slice, current_idx); + continue; + }, + else => { + const value = std.fmt.parseInt(i32, element, 0) catch return error.UnsupportedArrayFormat; + + try array.append(bun.default_allocator, SQLDataCell{ .tag = .int4, .value = .{ .int4 = @intCast(value) } }); + slice = trySlice(slice, current_idx); + continue; + }, + } + }, + else => { + if (arrayType == .json_array or arrayType == .jsonb_array) { + if (slice[0] == '[') { + var sub_array_offset: usize = 0; + const sub_array = try parseArray(slice, bigint, arrayType, globalObject, &sub_array_offset, true); + try array.append(bun.default_allocator, sub_array); + slice = trySlice(slice, sub_array_offset); + continue; } - return error.UnsupportedArrayFormat; - }, - } - }, - } - }, - } + } + return error.UnsupportedArrayFormat; + }, + } + }, + } + }, } - - if (offset) |offset_ptr| { - offset_ptr.* = bytes.len - slice.len; - } - - // postgres dont really support arrays with more than 2^31 elements, 2ˆ32 is the max we support, but users should never reach this branch - if (!reached_end or array.items.len > std.math.maxInt(u32)) { - @branchHint(.unlikely); - - return 
error.UnsupportedArrayFormat; - } - return DataCell{ .tag = .array, .value = .{ .array = .{ .ptr = array.items.ptr, .len = @truncate(array.items.len), .cap = @truncate(array.capacity) } } }; } - pub fn fromBytes(binary: bool, bigint: bool, oid: types.Tag, bytes: []const u8, globalObject: *jsc.JSGlobalObject) !DataCell { - switch (oid) { - // TODO: .int2_array, .float8_array - inline .int4_array, .float4_array => |tag| { - if (binary) { - if (bytes.len < 16) { - return error.InvalidBinaryData; - } - // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c#L1549-L1645 - const dimensions_raw: int4 = @bitCast(bytes[0..4].*); - const contains_nulls: int4 = @bitCast(bytes[4..8].*); + if (offset) |offset_ptr| { + offset_ptr.* = bytes.len - slice.len; + } - const dimensions = @byteSwap(dimensions_raw); - if (dimensions > 1) { - return error.MultidimensionalArrayNotSupportedYet; - } + // postgres dont really support arrays with more than 2^31 elements, 2ˆ32 is the max we support, but users should never reach this branch + if (!reached_end or array.items.len > std.math.maxInt(u32)) { + @branchHint(.unlikely); - if (contains_nulls != 0) { - return error.NullsInArrayNotSupportedYet; - } + return error.UnsupportedArrayFormat; + } + return SQLDataCell{ .tag = .array, .value = .{ .array = .{ .ptr = array.items.ptr, .len = @truncate(array.items.len), .cap = @truncate(array.capacity) } } }; +} - if (dimensions == 0) { - return DataCell{ - .tag = .typed_array, - .value = .{ - .typed_array = .{ - .ptr = null, - .len = 0, - .byte_len = 0, - .type = try tag.toJSTypedArrayType(), - }, - }, - }; - } +pub fn fromBytes(binary: bool, bigint: bool, oid: types.Tag, bytes: []const u8, globalObject: *jsc.JSGlobalObject) !SQLDataCell { + switch (oid) { + // TODO: .int2_array, .float8_array + inline .int4_array, .float4_array => |tag| { + if (binary) { + if (bytes.len < 16) { + return error.InvalidBinaryData; + } + // 
https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c#L1549-L1645 + const dimensions_raw: int4 = @bitCast(bytes[0..4].*); + const contains_nulls: int4 = @bitCast(bytes[4..8].*); - const elements = (try tag.pgArrayType()).init(bytes).slice(); + const dimensions = @byteSwap(dimensions_raw); + if (dimensions > 1) { + return error.MultidimensionalArrayNotSupportedYet; + } - return DataCell{ + if (contains_nulls != 0) { + return error.NullsInArrayNotSupportedYet; + } + + if (dimensions == 0) { + return SQLDataCell{ .tag = .typed_array, .value = .{ .typed_array = .{ - .head_ptr = if (bytes.len > 0) @constCast(bytes.ptr) else null, - .ptr = if (elements.len > 0) @ptrCast(elements.ptr) else null, - .len = @truncate(elements.len), - .byte_len = @truncate(bytes.len), + .ptr = null, + .len = 0, + .byte_len = 0, .type = try tag.toJSTypedArrayType(), }, }, }; - } else { - return try parseArray(bytes, bigint, tag, globalObject, null, false); } - }, - .int2 => { - if (binary) { - return DataCell{ .tag = .int4, .value = .{ .int4 = try parseBinary(.int2, i16, bytes) } }; - } else { - return DataCell{ .tag = .int4, .value = .{ .int4 = std.fmt.parseInt(i32, bytes, 0) catch 0 } }; - } - }, - .cid, .xid, .oid => { - if (binary) { - return DataCell{ .tag = .uint4, .value = .{ .uint4 = try parseBinary(.oid, u32, bytes) } }; - } else { - return DataCell{ .tag = .uint4, .value = .{ .uint4 = std.fmt.parseInt(u32, bytes, 0) catch 0 } }; - } - }, - .int4 => { - if (binary) { - return DataCell{ .tag = .int4, .value = .{ .int4 = try parseBinary(.int4, i32, bytes) } }; - } else { - return DataCell{ .tag = .int4, .value = .{ .int4 = std.fmt.parseInt(i32, bytes, 0) catch 0 } }; - } - }, - // postgres when reading bigint as int8 it returns a string unless type: { bigint: postgres.BigInt is set - .int8 => { - if (bigint) { - // .int8 is a 64-bit integer always string - return DataCell{ .tag = .int8, .value = .{ .int8 = std.fmt.parseInt(i64, bytes, 0) catch 0 } }; - } 
else { - return DataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) bun.String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; - } - }, - .float8 => { - if (binary and bytes.len == 8) { - return DataCell{ .tag = .float8, .value = .{ .float8 = try parseBinary(.float8, f64, bytes) } }; - } else { - const float8: f64 = bun.parseDouble(bytes) catch std.math.nan(f64); - return DataCell{ .tag = .float8, .value = .{ .float8 = float8 } }; - } - }, - .float4 => { - if (binary and bytes.len == 4) { - return DataCell{ .tag = .float8, .value = .{ .float8 = try parseBinary(.float4, f32, bytes) } }; - } else { - const float4: f64 = bun.parseDouble(bytes) catch std.math.nan(f64); - return DataCell{ .tag = .float8, .value = .{ .float8 = float4 } }; - } - }, - .numeric => { - if (binary) { - // this is probrably good enough for most cases - var stack_buffer = std.heap.stackFallback(1024, bun.default_allocator); - const allocator = stack_buffer.get(); - var numeric_buffer = std.ArrayList(u8).fromOwnedSlice(allocator, &stack_buffer.buffer); - numeric_buffer.items.len = 0; - defer numeric_buffer.deinit(); - // if is binary format lets display as a string because JS cant handle it in a safe way - const result = parseBinaryNumeric(bytes, &numeric_buffer) catch return error.UnsupportedNumericFormat; - return DataCell{ .tag = .string, .value = .{ .string = bun.String.cloneUTF8(result.slice()).value.WTFStringImpl }, .free_value = 1 }; - } else { - // nice text is actually what we want here - return DataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; - } - }, - .jsonb, .json => { - return DataCell{ .tag = .json, .value = .{ .json = if (bytes.len > 0) String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; - }, - .bool => { - if (binary) { - return DataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 1) } }; - } 
else { - return DataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 't') } }; - } - }, - .date, .timestamp, .timestamptz => |tag| { - if (bytes.len == 0) { - return DataCell{ .tag = .null, .value = .{ .null = 0 } }; - } - if (binary and bytes.len == 8) { - switch (tag) { - .timestamptz => return DataCell{ .tag = .date_with_time_zone, .value = .{ .date_with_time_zone = types.date.fromBinary(bytes) } }, - .timestamp => return DataCell{ .tag = .date, .value = .{ .date = types.date.fromBinary(bytes) } }, - else => unreachable, - } - } else { - if (bun.strings.eqlCaseInsensitiveASCII(bytes, "NULL", true)) { - return DataCell{ .tag = .null, .value = .{ .null = 0 } }; - } - var str = bun.String.init(bytes); - defer str.deref(); - return DataCell{ .tag = .date, .value = .{ .date = try str.parseDate(globalObject) } }; - } - }, + const elements = (try tag.pgArrayType()).init(bytes).slice(); - .bytea => { - if (binary) { - return DataCell{ .tag = .bytea, .value = .{ .bytea = .{ @intFromPtr(bytes.ptr), bytes.len } } }; - } else { - if (bun.strings.hasPrefixComptime(bytes, "\\x")) { - return try parseBytea(bytes[2..]); - } - return error.UnsupportedByteaFormat; - } - }, - // text array types - inline .bpchar_array, - .varchar_array, - .char_array, - .text_array, - .name_array, - .json_array, - .jsonb_array, - // special types handled as text array - .path_array, - .xml_array, - .point_array, - .lseg_array, - .box_array, - .polygon_array, - .line_array, - .cidr_array, - .numeric_array, - .money_array, - .varbit_array, - .bit_array, - .int2vector_array, - .circle_array, - .macaddr8_array, - .macaddr_array, - .inet_array, - .aclitem_array, - .tid_array, - .pg_database_array, - .pg_database_array2, - // numeric array types - .int8_array, - .int2_array, - .float8_array, - .oid_array, - .xid_array, - .cid_array, - - // special types - .bool_array, - .bytea_array, - - //time types - .time_array, - .date_array, - .timetz_array, - .timestamp_array, - 
.timestamptz_array, - .interval_array, - => |tag| { + return SQLDataCell{ + .tag = .typed_array, + .value = .{ + .typed_array = .{ + .head_ptr = if (bytes.len > 0) @constCast(bytes.ptr) else null, + .ptr = if (elements.len > 0) @ptrCast(elements.ptr) else null, + .len = @truncate(elements.len), + .byte_len = @truncate(bytes.len), + .type = try tag.toJSTypedArrayType(), + }, + }, + }; + } else { return try parseArray(bytes, bigint, tag, globalObject, null, false); - }, - else => { - return DataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) bun.String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; - }, - } + } + }, + .int2 => { + if (binary) { + return SQLDataCell{ .tag = .int4, .value = .{ .int4 = try parseBinary(.int2, i16, bytes) } }; + } else { + return SQLDataCell{ .tag = .int4, .value = .{ .int4 = std.fmt.parseInt(i32, bytes, 0) catch 0 } }; + } + }, + .cid, .xid, .oid => { + if (binary) { + return SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = try parseBinary(.oid, u32, bytes) } }; + } else { + return SQLDataCell{ .tag = .uint4, .value = .{ .uint4 = std.fmt.parseInt(u32, bytes, 0) catch 0 } }; + } + }, + .int4 => { + if (binary) { + return SQLDataCell{ .tag = .int4, .value = .{ .int4 = try parseBinary(.int4, i32, bytes) } }; + } else { + return SQLDataCell{ .tag = .int4, .value = .{ .int4 = std.fmt.parseInt(i32, bytes, 0) catch 0 } }; + } + }, + // postgres when reading bigint as int8 it returns a string unless type: { bigint: postgres.BigInt is set + .int8 => { + if (bigint) { + // .int8 is a 64-bit integer always string + return SQLDataCell{ .tag = .int8, .value = .{ .int8 = std.fmt.parseInt(i64, bytes, 0) catch 0 } }; + } else { + return SQLDataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) bun.String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; + } + }, + .float8 => { + if (binary and bytes.len == 8) { + return SQLDataCell{ .tag = .float8, .value = .{ .float8 = try 
parseBinary(.float8, f64, bytes) } }; + } else { + const float8: f64 = bun.parseDouble(bytes) catch std.math.nan(f64); + return SQLDataCell{ .tag = .float8, .value = .{ .float8 = float8 } }; + } + }, + .float4 => { + if (binary and bytes.len == 4) { + return SQLDataCell{ .tag = .float8, .value = .{ .float8 = try parseBinary(.float4, f32, bytes) } }; + } else { + const float4: f64 = bun.parseDouble(bytes) catch std.math.nan(f64); + return SQLDataCell{ .tag = .float8, .value = .{ .float8 = float4 } }; + } + }, + .numeric => { + if (binary) { + // this is probrably good enough for most cases + var stack_buffer = std.heap.stackFallback(1024, bun.default_allocator); + const allocator = stack_buffer.get(); + var numeric_buffer = std.ArrayList(u8).fromOwnedSlice(allocator, &stack_buffer.buffer); + numeric_buffer.items.len = 0; + defer numeric_buffer.deinit(); + + // if is binary format lets display as a string because JS cant handle it in a safe way + const result = parseBinaryNumeric(bytes, &numeric_buffer) catch return error.UnsupportedNumericFormat; + return SQLDataCell{ .tag = .string, .value = .{ .string = bun.String.cloneUTF8(result.slice()).value.WTFStringImpl }, .free_value = 1 }; + } else { + // nice text is actually what we want here + return SQLDataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; + } + }, + .jsonb, .json => { + return SQLDataCell{ .tag = .json, .value = .{ .json = if (bytes.len > 0) String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; + }, + .bool => { + if (binary) { + return SQLDataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 1) } }; + } else { + return SQLDataCell{ .tag = .bool, .value = .{ .bool = @intFromBool(bytes.len > 0 and bytes[0] == 't') } }; + } + }, + .date, .timestamp, .timestamptz => |tag| { + if (bytes.len == 0) { + return SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }; + } 
+ if (binary and bytes.len == 8) { + switch (tag) { + .timestamptz => return SQLDataCell{ .tag = .date_with_time_zone, .value = .{ .date_with_time_zone = types.date.fromBinary(bytes) } }, + .timestamp => return SQLDataCell{ .tag = .date, .value = .{ .date = types.date.fromBinary(bytes) } }, + else => unreachable, + } + } else { + if (bun.strings.eqlCaseInsensitiveASCII(bytes, "NULL", true)) { + return SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }; + } + var str = bun.String.init(bytes); + defer str.deref(); + return SQLDataCell{ .tag = .date, .value = .{ .date = try str.parseDate(globalObject) } }; + } + }, + + .bytea => { + if (binary) { + return SQLDataCell{ .tag = .bytea, .value = .{ .bytea = .{ @intFromPtr(bytes.ptr), bytes.len } } }; + } else { + if (bun.strings.hasPrefixComptime(bytes, "\\x")) { + return try parseBytea(bytes[2..]); + } + return error.UnsupportedByteaFormat; + } + }, + // text array types + inline .bpchar_array, + .varchar_array, + .char_array, + .text_array, + .name_array, + .json_array, + .jsonb_array, + // special types handled as text array + .path_array, + .xml_array, + .point_array, + .lseg_array, + .box_array, + .polygon_array, + .line_array, + .cidr_array, + .numeric_array, + .money_array, + .varbit_array, + .bit_array, + .int2vector_array, + .circle_array, + .macaddr8_array, + .macaddr_array, + .inet_array, + .aclitem_array, + .tid_array, + .pg_database_array, + .pg_database_array2, + // numeric array types + .int8_array, + .int2_array, + .float8_array, + .oid_array, + .xid_array, + .cid_array, + + // special types + .bool_array, + .bytea_array, + + //time types + .time_array, + .date_array, + .timetz_array, + .timestamp_array, + .timestamptz_array, + .interval_array, + => |tag| { + return try parseArray(bytes, bigint, tag, globalObject, null, false); + }, + else => { + return SQLDataCell{ .tag = .string, .value = .{ .string = if (bytes.len > 0) bun.String.cloneUTF8(bytes).value.WTFStringImpl else null }, .free_value = 1 }; + 
}, + } +} + +// #define pg_hton16(x) (x) +// #define pg_hton32(x) (x) +// #define pg_hton64(x) (x) + +// #define pg_ntoh16(x) (x) +// #define pg_ntoh32(x) (x) +// #define pg_ntoh64(x) (x) + +fn pg_ntoT(comptime IntSize: usize, i: anytype) std.meta.Int(.unsigned, IntSize) { + @setRuntimeSafety(false); + const T = @TypeOf(i); + if (@typeInfo(T) == .array) { + return pg_ntoT(IntSize, @as(std.meta.Int(.unsigned, IntSize), @bitCast(i))); } - // #define pg_hton16(x) (x) - // #define pg_hton32(x) (x) - // #define pg_hton64(x) (x) + const casted: std.meta.Int(.unsigned, IntSize) = @intCast(i); + return @byteSwap(casted); +} +fn pg_ntoh16(x: anytype) u16 { + return pg_ntoT(16, x); +} - // #define pg_ntoh16(x) (x) - // #define pg_ntoh32(x) (x) - // #define pg_ntoh64(x) (x) +fn pg_ntoh32(x: anytype) u32 { + return pg_ntoT(32, x); +} +const PGNummericString = union(enum) { + static: [:0]const u8, + dynamic: []const u8, - fn pg_ntoT(comptime IntSize: usize, i: anytype) std.meta.Int(.unsigned, IntSize) { - @setRuntimeSafety(false); - const T = @TypeOf(i); - if (@typeInfo(T) == .array) { - return pg_ntoT(IntSize, @as(std.meta.Int(.unsigned, IntSize), @bitCast(i))); - } - - const casted: std.meta.Int(.unsigned, IntSize) = @intCast(i); - return @byteSwap(casted); + pub fn slice(this: PGNummericString) []const u8 { + return switch (this) { + .static => |value| value, + .dynamic => |value| value, + }; } - fn pg_ntoh16(x: anytype) u16 { - return pg_ntoT(16, x); +}; + +fn parseBinaryNumeric(input: []const u8, result: *std.ArrayList(u8)) !PGNummericString { + // Reference: https://github.com/postgres/postgres/blob/50e6eb731d98ab6d0e625a0b87fb327b172bbebd/src/backend/utils/adt/numeric.c#L7612-L7740 + if (input.len < 8) return error.InvalidBuffer; + var fixed_buffer = std.io.fixedBufferStream(input); + var reader = fixed_buffer.reader(); + + // Read header values using big-endian + const ndigits = try reader.readInt(i16, .big); + const weight = try reader.readInt(i16, .big); + const sign 
= try reader.readInt(u16, .big); + const dscale = try reader.readInt(i16, .big); + + // Handle special cases + switch (sign) { + 0xC000 => return PGNummericString{ .static = "NaN" }, + 0xD000 => return PGNummericString{ .static = "Infinity" }, + 0xF000 => return PGNummericString{ .static = "-Infinity" }, + 0x4000, 0x0000 => {}, + else => return error.InvalidSign, } - fn pg_ntoh32(x: anytype) u32 { - return pg_ntoT(32, x); + if (ndigits == 0) { + return PGNummericString{ .static = "0" }; } - const PGNummericString = union(enum) { - static: [:0]const u8, - dynamic: []const u8, - pub fn slice(this: PGNummericString) []const u8 { - return switch (this) { - .static => |value| value, - .dynamic => |value| value, - }; + // Add negative sign if needed + if (sign == 0x4000) { + try result.append('-'); + } + + // Calculate decimal point position + var decimal_pos: i32 = @as(i32, weight + 1) * 4; + if (decimal_pos <= 0) { + decimal_pos = 1; + } + // Output all digits before the decimal point + + var scale_start: i32 = 0; + if (weight < 0) { + try result.append('0'); + scale_start = @as(i32, @intCast(weight)) + 1; + } else { + var idx: usize = 0; + var first_non_zero = false; + + while (idx <= weight) : (idx += 1) { + const digit = if (idx < ndigits) try reader.readInt(u16, .big) else 0; + var digit_str: [4]u8 = undefined; + const digit_len = std.fmt.formatIntBuf(&digit_str, digit, 10, .lower, .{ .width = 4, .fill = '0' }); + if (!first_non_zero) { + //In the first digit, suppress extra leading decimal zeroes + var start_idx: usize = 0; + while (start_idx < digit_len and digit_str[start_idx] == '0') : (start_idx += 1) {} + if (start_idx == digit_len) continue; + const digit_slice = digit_str[start_idx..digit_len]; + try result.appendSlice(digit_slice); + first_non_zero = true; + } else { + try result.appendSlice(digit_str[0..digit_len]); + } } - }; - - fn parseBinaryNumeric(input: []const u8, result: *std.ArrayList(u8)) !PGNummericString { - // Reference: 
https://github.com/postgres/postgres/blob/50e6eb731d98ab6d0e625a0b87fb327b172bbebd/src/backend/utils/adt/numeric.c#L7612-L7740 - if (input.len < 8) return error.InvalidBuffer; - var fixed_buffer = std.io.fixedBufferStream(input); - var reader = fixed_buffer.reader(); - - // Read header values using big-endian - const ndigits = try reader.readInt(i16, .big); - const weight = try reader.readInt(i16, .big); - const sign = try reader.readInt(u16, .big); - const dscale = try reader.readInt(i16, .big); - - // Handle special cases - switch (sign) { - 0xC000 => return PGNummericString{ .static = "NaN" }, - 0xD000 => return PGNummericString{ .static = "Infinity" }, - 0xF000 => return PGNummericString{ .static = "-Infinity" }, - 0x4000, 0x0000 => {}, - else => return error.InvalidSign, - } - - if (ndigits == 0) { - return PGNummericString{ .static = "0" }; - } - - // Add negative sign if needed - if (sign == 0x4000) { - try result.append('-'); - } - - // Calculate decimal point position - var decimal_pos: i32 = @as(i32, weight + 1) * 4; - if (decimal_pos <= 0) { - decimal_pos = 1; - } - // Output all digits before the decimal point - - var scale_start: i32 = 0; - if (weight < 0) { - try result.append('0'); - scale_start = @as(i32, @intCast(weight)) + 1; - } else { - var idx: usize = 0; - var first_non_zero = false; - - while (idx <= weight) : (idx += 1) { - const digit = if (idx < ndigits) try reader.readInt(u16, .big) else 0; + } + // If requested, output a decimal point and all the digits that follow it. + // We initially put out a multiple of 4 digits, then truncate if needed. 
+ if (dscale > 0) { + try result.append('.'); + // negative scale means we need to add zeros before the decimal point + // greater than ndigits means we need to add zeros after the decimal point + var idx: isize = scale_start; + const end: usize = result.items.len + @as(usize, @intCast(dscale)); + while (idx < dscale) : (idx += 4) { + if (idx >= 0 and idx < ndigits) { + const digit = reader.readInt(u16, .big) catch 0; var digit_str: [4]u8 = undefined; const digit_len = std.fmt.formatIntBuf(&digit_str, digit, 10, .lower, .{ .width = 4, .fill = '0' }); - if (!first_non_zero) { - //In the first digit, suppress extra leading decimal zeroes - var start_idx: usize = 0; - while (start_idx < digit_len and digit_str[start_idx] == '0') : (start_idx += 1) {} - if (start_idx == digit_len) continue; - const digit_slice = digit_str[start_idx..digit_len]; - try result.appendSlice(digit_slice); - first_non_zero = true; - } else { - try result.appendSlice(digit_str[0..digit_len]); - } - } - } - // If requested, output a decimal point and all the digits that follow it. - // We initially put out a multiple of 4 digits, then truncate if needed. 
- if (dscale > 0) { - try result.append('.'); - // negative scale means we need to add zeros before the decimal point - // greater than ndigits means we need to add zeros after the decimal point - var idx: isize = scale_start; - const end: usize = result.items.len + @as(usize, @intCast(dscale)); - while (idx < dscale) : (idx += 4) { - if (idx >= 0 and idx < ndigits) { - const digit = reader.readInt(u16, .big) catch 0; - var digit_str: [4]u8 = undefined; - const digit_len = std.fmt.formatIntBuf(&digit_str, digit, 10, .lower, .{ .width = 4, .fill = '0' }); - try result.appendSlice(digit_str[0..digit_len]); - } else { - try result.appendSlice("0000"); - } - } - if (result.items.len > end) { - result.items.len = end; - } - } - return PGNummericString{ .dynamic = result.items }; - } - - pub fn parseBinary(comptime tag: types.Tag, comptime ReturnType: type, bytes: []const u8) AnyPostgresError!ReturnType { - switch (comptime tag) { - .float8 => { - return @as(f64, @bitCast(try parseBinary(.int8, i64, bytes))); - }, - .int8 => { - // pq_getmsgfloat8 - if (bytes.len != 8) return error.InvalidBinaryData; - return @byteSwap(@as(i64, @bitCast(bytes[0..8].*))); - }, - .int4 => { - // pq_getmsgint - switch (bytes.len) { - 1 => { - return bytes[0]; - }, - 2 => { - return pg_ntoh16(@as(u16, @bitCast(bytes[0..2].*))); - }, - 4 => { - return @bitCast(pg_ntoh32(@as(u32, @bitCast(bytes[0..4].*)))); - }, - else => { - return error.UnsupportedIntegerSize; - }, - } - }, - .oid => { - switch (bytes.len) { - 1 => { - return bytes[0]; - }, - 2 => { - return pg_ntoh16(@as(u16, @bitCast(bytes[0..2].*))); - }, - 4 => { - return pg_ntoh32(@as(u32, @bitCast(bytes[0..4].*))); - }, - else => { - return error.UnsupportedIntegerSize; - }, - } - }, - .int2 => { - // pq_getmsgint - switch (bytes.len) { - 1 => { - return bytes[0]; - }, - 2 => { - // PostgreSQL stores numbers in big-endian format, so we must read as big-endian - // Read as raw 16-bit unsigned integer - const value: u16 = 
@bitCast(bytes[0..2].*); - // Convert from big-endian to native-endian (we always use little endian) - return @bitCast(@byteSwap(value)); // Cast to signed 16-bit integer (i16) - }, - else => { - return error.UnsupportedIntegerSize; - }, - } - }, - .float4 => { - // pq_getmsgfloat4 - return @as(f32, @bitCast(try parseBinary(.int4, i32, bytes))); - }, - else => @compileError("TODO"), - } - } - - pub const Flags = packed struct(u32) { - has_indexed_columns: bool = false, - has_named_columns: bool = false, - has_duplicate_columns: bool = false, - _: u29 = 0, - }; - - pub const Putter = struct { - list: []DataCell, - fields: []const protocol.FieldDescription, - binary: bool = false, - bigint: bool = false, - count: usize = 0, - globalObject: *jsc.JSGlobalObject, - - extern fn JSC__constructObjectFromDataCell( - *jsc.JSGlobalObject, - JSValue, - JSValue, - [*]DataCell, - u32, - Flags, - u8, // result_mode - ?[*]jsc.JSObject.ExternColumnIdentifier, // names - u32, // names count - ) JSValue; - - pub fn toJS(this: *Putter, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: Flags, result_mode: PostgresSQLQueryResultMode, cached_structure: ?PostgresCachedStructure) JSValue { - var names: ?[*]jsc.JSObject.ExternColumnIdentifier = null; - var names_count: u32 = 0; - if (cached_structure) |c| { - if (c.fields) |f| { - names = f.ptr; - names_count = @truncate(f.len); - } - } - - return JSC__constructObjectFromDataCell( - globalObject, - array, - structure, - this.list.ptr, - @truncate(this.fields.len), - flags, - @intFromEnum(result_mode), - names, - names_count, - ); - } - - fn putImpl(this: *Putter, index: u32, optional_bytes: ?*Data, comptime is_raw: bool) !bool { - // Bounds check to prevent crash when fields/list arrays are empty - if (index >= this.fields.len) { - debug("putImpl: index {d} >= fields.len {d}, ignoring extra field", .{ index, this.fields.len }); - return false; - } - if (index >= this.list.len) { - debug("putImpl: index {d} >= 
list.len {d}, ignoring extra field", .{ index, this.list.len }); - return false; - } - - const field = &this.fields[index]; - const oid = field.type_oid; - debug("index: {d}, oid: {d}", .{ index, oid }); - const cell: *DataCell = &this.list[index]; - if (is_raw) { - cell.* = DataCell.raw(optional_bytes); + try result.appendSlice(digit_str[0..digit_len]); } else { - const tag = if (std.math.maxInt(short) < oid) .text else @as(types.Tag, @enumFromInt(@as(short, @intCast(oid)))); - cell.* = if (optional_bytes) |data| - try DataCell.fromBytes((field.binary or this.binary) and tag.isBinaryFormatSupported(), this.bigint, tag, data.slice(), this.globalObject) - else - DataCell{ - .tag = .null, - .value = .{ - .null = 0, - }, - }; + try result.appendSlice("0000"); } - this.count += 1; - cell.index = switch (field.name_or_index) { - // The indexed columns can be out of order. - .index => |i| i, + } + if (result.items.len > end) { + result.items.len = end; + } + } + return PGNummericString{ .dynamic = result.items }; +} - else => @intCast(index), - }; +pub fn parseBinary(comptime tag: types.Tag, comptime ReturnType: type, bytes: []const u8) AnyPostgresError!ReturnType { + switch (comptime tag) { + .float8 => { + return @as(f64, @bitCast(try parseBinary(.int8, i64, bytes))); + }, + .int8 => { + // pq_getmsgfloat8 + if (bytes.len != 8) return error.InvalidBinaryData; + return @byteSwap(@as(i64, @bitCast(bytes[0..8].*))); + }, + .int4 => { + // pq_getmsgint + switch (bytes.len) { + 1 => { + return bytes[0]; + }, + 2 => { + return pg_ntoh16(@as(u16, @bitCast(bytes[0..2].*))); + }, + 4 => { + return @bitCast(pg_ntoh32(@as(u32, @bitCast(bytes[0..4].*)))); + }, + else => { + return error.UnsupportedIntegerSize; + }, + } + }, + .oid => { + switch (bytes.len) { + 1 => { + return bytes[0]; + }, + 2 => { + return pg_ntoh16(@as(u16, @bitCast(bytes[0..2].*))); + }, + 4 => { + return pg_ntoh32(@as(u32, @bitCast(bytes[0..4].*))); + }, + else => { + return error.UnsupportedIntegerSize; + }, 
+ } + }, + .int2 => { + // pq_getmsgint + switch (bytes.len) { + 1 => { + return bytes[0]; + }, + 2 => { + // PostgreSQL stores numbers in big-endian format, so we must read as big-endian + // Read as raw 16-bit unsigned integer + const value: u16 = @bitCast(bytes[0..2].*); + // Convert from big-endian to native-endian (we always use little endian) + return @bitCast(@byteSwap(value)); // Cast to signed 16-bit integer (i16) + }, + else => { + return error.UnsupportedIntegerSize; + }, + } + }, + .float4 => { + // pq_getmsgfloat4 + return @as(f32, @bitCast(try parseBinary(.int4, i32, bytes))); + }, + else => @compileError("TODO"), + } +} +pub const Putter = struct { + list: []SQLDataCell, + fields: []const protocol.FieldDescription, + binary: bool = false, + bigint: bool = false, + count: usize = 0, + globalObject: *jsc.JSGlobalObject, - // TODO: when duplicate and we know the result will be an object - // and not a .values() array, we can discard the data - // immediately. - cell.isIndexedColumn = switch (field.name_or_index) { - .duplicate => 2, - .index => 1, - .name => 0, - }; - return true; + pub fn toJS(this: *Putter, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: SQLDataCell.Flags, result_mode: PostgresSQLQueryResultMode, cached_structure: ?PostgresCachedStructure) JSValue { + var names: ?[*]jsc.JSObject.ExternColumnIdentifier = null; + var names_count: u32 = 0; + if (cached_structure) |c| { + if (c.fields) |f| { + names = f.ptr; + names_count = @truncate(f.len); + } } - pub fn putRaw(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { - return this.putImpl(index, optional_bytes, true); + return SQLDataCell.JSC__constructObjectFromDataCell( + globalObject, + array, + structure, + this.list.ptr, + @truncate(this.fields.len), + flags, + @intFromEnum(result_mode), + names, + names_count, + ); + } + + fn putImpl(this: *Putter, index: u32, optional_bytes: ?*Data, comptime is_raw: bool) !bool { + // Bounds check to prevent crash 
when fields/list arrays are empty + if (index >= this.fields.len) { + debug("putImpl: index {d} >= fields.len {d}, ignoring extra field", .{ index, this.fields.len }); + return false; } - pub fn put(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { - return this.putImpl(index, optional_bytes, false); + if (index >= this.list.len) { + debug("putImpl: index {d} >= list.len {d}, ignoring extra field", .{ index, this.list.len }); + return false; } - }; + + const field = &this.fields[index]; + const oid = field.type_oid; + debug("index: {d}, oid: {d}", .{ index, oid }); + const cell: *SQLDataCell = &this.list[index]; + if (is_raw) { + cell.* = SQLDataCell.raw(optional_bytes); + } else { + const tag = if (std.math.maxInt(short) < oid) .text else @as(types.Tag, @enumFromInt(@as(short, @intCast(oid)))); + cell.* = if (optional_bytes) |data| + try fromBytes((field.binary or this.binary) and tag.isBinaryFormatSupported(), this.bigint, tag, data.slice(), this.globalObject) + else + SQLDataCell{ + .tag = .null, + .value = .{ + .null = 0, + }, + }; + } + this.count += 1; + cell.index = switch (field.name_or_index) { + // The indexed columns can be out of order. + .index => |i| i, + + else => @intCast(index), + }; + + // TODO: when duplicate and we know the result will be an object + // and not a .values() array, we can discard the data + // immediately. 
+ cell.isIndexedColumn = switch (field.name_or_index) { + .duplicate => 2, + .index => 1, + .name => 0, + }; + return true; + } + + pub fn putRaw(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { + return this.putImpl(index, optional_bytes, true); + } + pub fn put(this: *Putter, index: u32, optional_bytes: ?*Data) !bool { + return this.putImpl(index, optional_bytes, false); + } }; const debug = bun.Output.scoped(.Postgres, .visible); -const PostgresCachedStructure = @import("./PostgresCachedStructure.zig"); +const PostgresCachedStructure = @import("../shared/CachedStructure.zig"); const protocol = @import("./PostgresProtocol.zig"); const std = @import("std"); -const Data = @import("./Data.zig").Data; -const PostgresSQLQueryResultMode = @import("./PostgresSQLQueryResultMode.zig").PostgresSQLQueryResultMode; +const Data = @import("../shared/Data.zig").Data; +const PostgresSQLQueryResultMode = @import("../shared/SQLQueryResultMode.zig").SQLQueryResultMode; const types = @import("./PostgresTypes.zig"); const AnyPostgresError = types.AnyPostgresError; diff --git a/src/sql/postgres/PostgresProtocol.zig b/src/sql/postgres/PostgresProtocol.zig index 49427252ad..20e6cd2190 100644 --- a/src/sql/postgres/PostgresProtocol.zig +++ b/src/sql/postgres/PostgresProtocol.zig @@ -45,7 +45,7 @@ pub const SASLResponse = @import("./protocol/SASLResponse.zig"); pub const StackReader = @import("./protocol/StackReader.zig"); pub const StartupMessage = @import("./protocol/StartupMessage.zig"); pub const Authentication = @import("./protocol/Authentication.zig").Authentication; -pub const ColumnIdentifier = @import("./protocol/ColumnIdentifier.zig").ColumnIdentifier; +pub const ColumnIdentifier = @import("../shared/ColumnIdentifier.zig").ColumnIdentifier; pub const DecoderWrap = @import("./protocol/DecoderWrap.zig").DecoderWrap; pub const FieldMessage = @import("./protocol/FieldMessage.zig").FieldMessage; pub const FieldType = @import("./protocol/FieldType.zig").FieldType; diff --git 
a/src/sql/postgres/PostgresRequest.zig b/src/sql/postgres/PostgresRequest.zig index c302874e28..7f7800c49a 100644 --- a/src/sql/postgres/PostgresRequest.zig +++ b/src/sql/postgres/PostgresRequest.zig @@ -332,7 +332,7 @@ const PostgresSQLStatement = @import("./PostgresSQLStatement.zig"); const Signature = @import("./Signature.zig"); const protocol = @import("./PostgresProtocol.zig"); const std = @import("std"); -const QueryBindingIterator = @import("./QueryBindingIterator.zig").QueryBindingIterator; +const QueryBindingIterator = @import("../shared/QueryBindingIterator.zig").QueryBindingIterator; const types = @import("./PostgresTypes.zig"); const AnyPostgresError = @import("./PostgresTypes.zig").AnyPostgresError; diff --git a/src/sql/postgres/PostgresSQLConnection.zig b/src/sql/postgres/PostgresSQLConnection.zig index 483945ceba..5c394074d5 100644 --- a/src/sql/postgres/PostgresSQLConnection.zig +++ b/src/sql/postgres/PostgresSQLConnection.zig @@ -311,7 +311,7 @@ pub fn failWithJSValue(this: *PostgresSQLConnection, value: JSValue) void { this.stopTimers(); if (this.status == .failed) return; - this.status = .failed; + this.setStatus(.failed); this.ref(); defer this.deref(); @@ -584,7 +584,7 @@ comptime { pub fn call(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { var vm = globalObject.bunVM(); - const arguments = callframe.arguments_old(15).slice(); + const arguments = callframe.arguments(); const hostname_str = try arguments[0].toBunString(globalObject); defer hostname_str.deref(); const port = try arguments[1].coerce(i32, globalObject); @@ -700,7 +700,7 @@ pub fn call(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JS ptr.* = PostgresSQLConnection{ .globalObject = globalObject, - .vm = globalObject.bunVM(), + .vm = vm, .database = database, .user = username, .password = password, @@ -1157,7 +1157,9 @@ fn advance(this: *PostgresSQLConnection) void { } else { // deinit later req.status = .fail; + offset += 1; 
} + continue; }, .prepared => { @@ -1185,9 +1187,9 @@ fn advance(this: *PostgresSQLConnection) void { } else { // deinit later req.status = .fail; + offset += 1; } debug("bind and execute failed: {s}", .{@errorName(err)}); - continue; }; @@ -1356,8 +1358,8 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera .globalObject = this.globalObject, }; - var stack_buf: [70]DataCell = undefined; - var cells: []DataCell = stack_buf[0..@min(statement.fields.len, jsc.JSObject.maxInlineCapacity())]; + var stack_buf: [70]DataCell.SQLDataCell = undefined; + var cells: []DataCell.SQLDataCell = stack_buf[0..@min(statement.fields.len, jsc.JSObject.maxInlineCapacity())]; var free_cells = false; defer { for (cells[0..putter.count]) |*cell| { @@ -1367,11 +1369,11 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera } if (statement.fields.len >= jsc.JSObject.maxInlineCapacity()) { - cells = try bun.default_allocator.alloc(DataCell, statement.fields.len); + cells = try bun.default_allocator.alloc(DataCell.SQLDataCell, statement.fields.len); free_cells = true; } // make sure all cells are reset if reader short breaks the fields will just be null with is better than undefined behavior - @memset(cells, DataCell{ .tag = .null, .value = .{ .null = 0 } }); + @memset(cells, DataCell.SQLDataCell{ .tag = .null, .value = .{ .null = 0 } }); putter.list = cells; if (request.flags.result_mode == .raw) { @@ -1395,7 +1397,14 @@ pub fn on(this: *PostgresSQLConnection, comptime MessageType: @Type(.enum_litera }; const pending_value = PostgresSQLQuery.js.pendingValueGetCached(thisValue) orelse .zero; pending_value.ensureStillAlive(); - const result = putter.toJS(this.globalObject, pending_value, structure, statement.fields_flags, request.flags.result_mode, cached_structure); + const result = putter.toJS( + this.globalObject, + pending_value, + structure, + statement.fields_flags, + request.flags.result_mode, + cached_structure, + ); if 
(pending_value == .zero) { PostgresSQLQuery.js.pendingValueSetCached(thisValue, this.globalObject, result); @@ -1814,7 +1823,8 @@ pub const fromJS = js.fromJS; pub const fromJSDirect = js.fromJSDirect; pub const toJS = js.toJS; -const PostgresCachedStructure = @import("./PostgresCachedStructure.zig"); +const DataCell = @import("./DataCell.zig"); +const PostgresCachedStructure = @import("../shared/CachedStructure.zig"); const PostgresRequest = @import("./PostgresRequest.zig"); const PostgresSQLQuery = @import("./PostgresSQLQuery.zig"); const PostgresSQLStatement = @import("./PostgresSQLStatement.zig"); @@ -1822,9 +1832,8 @@ const SocketMonitor = @import("./SocketMonitor.zig"); const protocol = @import("./PostgresProtocol.zig"); const std = @import("std"); const AuthenticationState = @import("./AuthenticationState.zig").AuthenticationState; -const ConnectionFlags = @import("./ConnectionFlags.zig").ConnectionFlags; -const Data = @import("./Data.zig").Data; -const DataCell = @import("./DataCell.zig").DataCell; +const ConnectionFlags = @import("../shared/ConnectionFlags.zig").ConnectionFlags; +const Data = @import("../shared/Data.zig").Data; const SSLMode = @import("./SSLMode.zig").SSLMode; const Status = @import("./Status.zig").Status; const TLSStatus = @import("./TLSStatus.zig").TLSStatus; diff --git a/src/sql/postgres/PostgresSQLQuery.zig b/src/sql/postgres/PostgresSQLQuery.zig index c1b3cedbc0..35b1af4906 100644 --- a/src/sql/postgres/PostgresSQLQuery.zig +++ b/src/sql/postgres/PostgresSQLQuery.zig @@ -186,7 +186,7 @@ pub fn estimatedSize(this: *PostgresSQLQuery) usize { } pub fn call(globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { - const arguments = callframe.arguments_old(6).slice(); + const arguments = callframe.arguments(); var args = jsc.CallFrame.ArgumentsSlice.init(globalThis.bunVM(), arguments); defer args.deinit(); const query = args.nextEat() orelse { @@ -276,8 +276,7 @@ pub fn setMode(this: *PostgresSQLQuery, 
globalObject: *jsc.JSGlobalObject, callf } pub fn doRun(this: *PostgresSQLQuery, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { - var arguments_ = callframe.arguments_old(2); - const arguments = arguments_.slice(); + var arguments = callframe.arguments(); const connection: *PostgresSQLConnection = arguments[0].as(PostgresSQLConnection) orelse { return globalObject.throw("connection must be a PostgresSQLConnection", .{}); }; @@ -375,11 +374,10 @@ pub fn doRun(this: *PostgresSQLQuery, globalObject: *jsc.JSGlobalObject, callfra switch (stmt.status) { .failed => { this.statement = null; + const error_response = try stmt.error_response.?.toJS(globalObject); stmt.deref(); this.deref(); - // If the statement failed, we need to throw the error - const e = try this.statement.?.error_response.?.toJS(globalObject); - return globalObject.throwValue(e); + return globalObject.throwValue(error_response); }, .prepared => { if (!connection.hasQueryRunning() or connection.canPipeline()) { @@ -524,7 +522,7 @@ const bun = @import("bun"); const protocol = @import("./PostgresProtocol.zig"); const std = @import("std"); const CommandTag = @import("./CommandTag.zig").CommandTag; -const PostgresSQLQueryResultMode = @import("./PostgresSQLQueryResultMode.zig").PostgresSQLQueryResultMode; +const PostgresSQLQueryResultMode = @import("../shared/SQLQueryResultMode.zig").SQLQueryResultMode; const AnyPostgresError = @import("./AnyPostgresError.zig").AnyPostgresError; const postgresErrorToJS = @import("./AnyPostgresError.zig").postgresErrorToJS; diff --git a/src/sql/postgres/PostgresSQLQueryResultMode.zig b/src/sql/postgres/PostgresSQLQueryResultMode.zig deleted file mode 100644 index 2744cb61e2..0000000000 --- a/src/sql/postgres/PostgresSQLQueryResultMode.zig +++ /dev/null @@ -1,5 +0,0 @@ -pub const PostgresSQLQueryResultMode = enum(u2) { - objects = 0, - values = 1, - raw = 2, -}; diff --git a/src/sql/postgres/PostgresSQLStatement.zig 
b/src/sql/postgres/PostgresSQLStatement.zig index 1026d86b22..5604cf3106 100644 --- a/src/sql/postgres/PostgresSQLStatement.zig +++ b/src/sql/postgres/PostgresSQLStatement.zig @@ -162,11 +162,11 @@ pub fn structure(this: *PostgresSQLStatement, owner: JSValue, globalObject: *jsc const debug = bun.Output.scoped(.Postgres, .visible); -const PostgresCachedStructure = @import("./PostgresCachedStructure.zig"); +const PostgresCachedStructure = @import("../shared/CachedStructure.zig"); const Signature = @import("./Signature.zig"); const protocol = @import("./PostgresProtocol.zig"); const std = @import("std"); -const DataCell = @import("./DataCell.zig").DataCell; +const DataCell = @import("./DataCell.zig").SQLDataCell; const AnyPostgresError = @import("./AnyPostgresError.zig").AnyPostgresError; const postgresErrorToJS = @import("./AnyPostgresError.zig").postgresErrorToJS; diff --git a/src/sql/postgres/Signature.zig b/src/sql/postgres/Signature.zig index 53e74a3677..0918996f7b 100644 --- a/src/sql/postgres/Signature.zig +++ b/src/sql/postgres/Signature.zig @@ -103,7 +103,7 @@ pub fn generate(globalObject: *jsc.JSGlobalObject, query: []const u8, array_valu const bun = @import("bun"); const std = @import("std"); -const QueryBindingIterator = @import("./QueryBindingIterator.zig").QueryBindingIterator; +const QueryBindingIterator = @import("../shared/QueryBindingIterator.zig").QueryBindingIterator; const types = @import("./PostgresTypes.zig"); const int4 = types.int4; diff --git a/src/sql/postgres/SocketMonitor.zig b/src/sql/postgres/SocketMonitor.zig index 988b334fe9..c9db858509 100644 --- a/src/sql/postgres/SocketMonitor.zig +++ b/src/sql/postgres/SocketMonitor.zig @@ -1,4 +1,5 @@ pub fn write(data: []const u8) void { + debug("SocketMonitor: write {s}", .{std.fmt.fmtSliceHexLower(data)}); if (comptime bun.Environment.isDebug) { DebugSocketMonitorWriter.check.call(); if (DebugSocketMonitorWriter.enabled) { @@ -8,6 +9,7 @@ pub fn write(data: []const u8) void { } pub fn 
read(data: []const u8) void { + debug("SocketMonitor: read {s}", .{std.fmt.fmtSliceHexLower(data)}); if (comptime bun.Environment.isDebug) { DebugSocketMonitorReader.check.call(); if (DebugSocketMonitorReader.enabled) { @@ -16,6 +18,9 @@ pub fn read(data: []const u8) void { } } +const debug = bun.Output.scoped(.SocketMonitor, .visible); + const DebugSocketMonitorReader = @import("./DebugSocketMonitorReader.zig"); const DebugSocketMonitorWriter = @import("./DebugSocketMonitorWriter.zig"); const bun = @import("bun"); +const std = @import("std"); diff --git a/src/sql/postgres/protocol/Authentication.zig b/src/sql/postgres/protocol/Authentication.zig index 306e08b14d..f567a5bbbd 100644 --- a/src/sql/postgres/protocol/Authentication.zig +++ b/src/sql/postgres/protocol/Authentication.zig @@ -175,6 +175,6 @@ const debug = bun.Output.scoped(.Postgres, .hidden); const bun = @import("bun"); const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/CommandComplete.zig b/src/sql/postgres/protocol/CommandComplete.zig index 36ab1b2f81..fa554f7666 100644 --- a/src/sql/postgres/protocol/CommandComplete.zig +++ b/src/sql/postgres/protocol/CommandComplete.zig @@ -19,6 +19,6 @@ pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReade pub const decode = DecoderWrap(CommandComplete, decodeInternal).decode; const bun = @import("bun"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/CopyData.zig b/src/sql/postgres/protocol/CopyData.zig index 938889266b..ca26782a8d 100644 --- a/src/sql/postgres/protocol/CopyData.zig +++ 
b/src/sql/postgres/protocol/CopyData.zig @@ -30,7 +30,7 @@ pub fn writeInternal( pub const write = WriteWrap(@This(), writeInternal).write; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; const Int32 = @import("../types/int_types.zig").Int32; const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/CopyFail.zig b/src/sql/postgres/protocol/CopyFail.zig index 1a08cc6340..4904346662 100644 --- a/src/sql/postgres/protocol/CopyFail.zig +++ b/src/sql/postgres/protocol/CopyFail.zig @@ -30,7 +30,7 @@ pub fn writeInternal( pub const write = WriteWrap(@This(), writeInternal).write; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; const NewReader = @import("./NewReader.zig").NewReader; const NewWriter = @import("./NewWriter.zig").NewWriter; diff --git a/src/sql/postgres/protocol/DataRow.zig b/src/sql/postgres/protocol/DataRow.zig index e1744246d8..bbb71ce5c9 100644 --- a/src/sql/postgres/protocol/DataRow.zig +++ b/src/sql/postgres/protocol/DataRow.zig @@ -24,8 +24,8 @@ pub fn decode(context: anytype, comptime ContextType: type, reader: NewReader(Co pub const null_int4 = 4294967295; +const Data = @import("../../shared/Data.zig").Data; + const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; -const Data = @import("../Data.zig").Data; - const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/FieldDescription.zig b/src/sql/postgres/protocol/FieldDescription.zig index eb159c981c..ccedc65fb4 100644 --- a/src/sql/postgres/protocol/FieldDescription.zig +++ b/src/sql/postgres/protocol/FieldDescription.zig @@ -60,7 +60,7 @@ pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReade pub const 
decode = DecoderWrap(FieldDescription, decodeInternal).decode; const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; -const ColumnIdentifier = @import("./ColumnIdentifier.zig").ColumnIdentifier; +const ColumnIdentifier = @import("../../shared/ColumnIdentifier.zig").ColumnIdentifier; const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/NewReader.zig b/src/sql/postgres/protocol/NewReader.zig index 5832f65953..8fc1e22c68 100644 --- a/src/sql/postgres/protocol/NewReader.zig +++ b/src/sql/postgres/protocol/NewReader.zig @@ -113,7 +113,7 @@ pub fn NewReader(comptime Context: type) type { const bun = @import("bun"); const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const int_types = @import("../types/int_types.zig"); const PostgresInt32 = int_types.PostgresInt32; diff --git a/src/sql/postgres/protocol/ParameterStatus.zig b/src/sql/postgres/protocol/ParameterStatus.zig index adb4b9d131..a74c0e89f8 100644 --- a/src/sql/postgres/protocol/ParameterStatus.zig +++ b/src/sql/postgres/protocol/ParameterStatus.zig @@ -21,6 +21,6 @@ pub fn decodeInternal(this: *@This(), comptime Container: type, reader: NewReade pub const decode = DecoderWrap(ParameterStatus, decodeInternal).decode; const bun = @import("bun"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const DecoderWrap = @import("./DecoderWrap.zig").DecoderWrap; const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/PasswordMessage.zig b/src/sql/postgres/protocol/PasswordMessage.zig index 1a4c141856..c9c71194c6 100644 --- a/src/sql/postgres/protocol/PasswordMessage.zig +++ b/src/sql/postgres/protocol/PasswordMessage.zig @@ -23,7 +23,7 @@ pub fn writeInternal( pub const write = 
WriteWrap(@This(), writeInternal).write; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const Int32 = @import("../types/int_types.zig").Int32; const NewWriter = @import("./NewWriter.zig").NewWriter; const WriteWrap = @import("./WriteWrap.zig").WriteWrap; diff --git a/src/sql/postgres/protocol/SASLInitialResponse.zig b/src/sql/postgres/protocol/SASLInitialResponse.zig index 2558c211f8..ce9ca6fe53 100644 --- a/src/sql/postgres/protocol/SASLInitialResponse.zig +++ b/src/sql/postgres/protocol/SASLInitialResponse.zig @@ -28,7 +28,7 @@ pub fn writeInternal( pub const write = WriteWrap(@This(), writeInternal).write; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const Int32 = @import("../types/int_types.zig").Int32; const NewWriter = @import("./NewWriter.zig").NewWriter; const WriteWrap = @import("./WriteWrap.zig").WriteWrap; diff --git a/src/sql/postgres/protocol/SASLResponse.zig b/src/sql/postgres/protocol/SASLResponse.zig index 04c2d33afd..3a1b0d88ce 100644 --- a/src/sql/postgres/protocol/SASLResponse.zig +++ b/src/sql/postgres/protocol/SASLResponse.zig @@ -23,7 +23,7 @@ pub fn writeInternal( pub const write = WriteWrap(@This(), writeInternal).write; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const Int32 = @import("../types/int_types.zig").Int32; const NewWriter = @import("./NewWriter.zig").NewWriter; const WriteWrap = @import("./WriteWrap.zig").WriteWrap; diff --git a/src/sql/postgres/protocol/StackReader.zig b/src/sql/postgres/protocol/StackReader.zig index 06ca3a7cd4..a540c5da4b 100644 --- a/src/sql/postgres/protocol/StackReader.zig +++ b/src/sql/postgres/protocol/StackReader.zig @@ -61,5 +61,5 @@ pub fn readZ(this: StackReader) AnyPostgresError!Data { const bun = @import("bun"); const AnyPostgresError = 
@import("../AnyPostgresError.zig").AnyPostgresError; -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const NewReader = @import("./NewReader.zig").NewReader; diff --git a/src/sql/postgres/protocol/StartupMessage.zig b/src/sql/postgres/protocol/StartupMessage.zig index c70f8c5b26..0115e4a3ba 100644 --- a/src/sql/postgres/protocol/StartupMessage.zig +++ b/src/sql/postgres/protocol/StartupMessage.zig @@ -39,7 +39,7 @@ pub fn writeInternal( pub const write = WriteWrap(@This(), writeInternal).write; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const NewWriter = @import("./NewWriter.zig").NewWriter; const WriteWrap = @import("./WriteWrap.zig").WriteWrap; const zFieldCount = @import("./zHelpers.zig").zFieldCount; diff --git a/src/sql/postgres/types/PostgresString.zig b/src/sql/postgres/types/PostgresString.zig index 4ca1c822ec..8d6caed69d 100644 --- a/src/sql/postgres/types/PostgresString.zig +++ b/src/sql/postgres/types/PostgresString.zig @@ -41,7 +41,7 @@ pub fn toJS( const bun = @import("bun"); const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const int_types = @import("./int_types.zig"); const short = int_types.short; diff --git a/src/sql/postgres/types/bytea.zig b/src/sql/postgres/types/bytea.zig index 42e453a2b2..8366ceacc3 100644 --- a/src/sql/postgres/types/bytea.zig +++ b/src/sql/postgres/types/bytea.zig @@ -14,7 +14,7 @@ pub fn toJS( const bun = @import("bun"); const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const int_types = @import("./int_types.zig"); const short = int_types.short; diff --git a/src/sql/postgres/types/date.zig b/src/sql/postgres/types/date.zig index 95be95e48d..8a5ec36144 
100644 --- a/src/sql/postgres/types/date.zig +++ b/src/sql/postgres/types/date.zig @@ -46,7 +46,7 @@ pub fn toJS( const bun = @import("bun"); const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const int_types = @import("./int_types.zig"); const short = int_types.short; diff --git a/src/sql/postgres/types/json.zig b/src/sql/postgres/types/json.zig index 14aad5fbe5..de5cf9be84 100644 --- a/src/sql/postgres/types/json.zig +++ b/src/sql/postgres/types/json.zig @@ -18,7 +18,7 @@ pub fn toJS( const bun = @import("bun"); const AnyPostgresError = @import("../AnyPostgresError.zig").AnyPostgresError; -const Data = @import("../Data.zig").Data; +const Data = @import("../../shared/Data.zig").Data; const int_types = @import("./int_types.zig"); const short = int_types.short; diff --git a/src/sql/postgres/PostgresCachedStructure.zig b/src/sql/shared/CachedStructure.zig similarity index 100% rename from src/sql/postgres/PostgresCachedStructure.zig rename to src/sql/shared/CachedStructure.zig diff --git a/src/sql/postgres/protocol/ColumnIdentifier.zig b/src/sql/shared/ColumnIdentifier.zig similarity index 95% rename from src/sql/postgres/protocol/ColumnIdentifier.zig rename to src/sql/shared/ColumnIdentifier.zig index 53e778b92f..48d5f4c03b 100644 --- a/src/sql/postgres/protocol/ColumnIdentifier.zig +++ b/src/sql/shared/ColumnIdentifier.zig @@ -35,4 +35,4 @@ pub const ColumnIdentifier = union(enum) { }; const std = @import("std"); -const Data = @import("../Data.zig").Data; +const Data = @import("../shared/Data.zig").Data; diff --git a/src/sql/postgres/ConnectionFlags.zig b/src/sql/shared/ConnectionFlags.zig similarity index 100% rename from src/sql/postgres/ConnectionFlags.zig rename to src/sql/shared/ConnectionFlags.zig diff --git a/src/sql/postgres/Data.zig b/src/sql/shared/Data.zig similarity index 52% rename from src/sql/postgres/Data.zig rename to src/sql/shared/Data.zig index ec2f5478a0..f94d5791c3 100644 
--- a/src/sql/postgres/Data.zig +++ b/src/sql/shared/Data.zig @@ -1,15 +1,32 @@ +// Represents data that can be either owned or temporary pub const Data = union(enum) { owned: bun.ByteList, temporary: []const u8, + inline_storage: std.BoundedArray(u8, 15), empty: void, pub const Empty: Data = .{ .empty = {} }; + pub fn create(possibly_inline_bytes: []const u8, allocator: std.mem.Allocator) !Data { + if (possibly_inline_bytes.len == 0) { + return .{ .empty = {} }; + } + + if (possibly_inline_bytes.len <= 15) { + var inline_storage = std.BoundedArray(u8, 15){}; + @memcpy(inline_storage.buffer[0..possibly_inline_bytes.len], possibly_inline_bytes); + inline_storage.len = @truncate(possibly_inline_bytes.len); + return .{ .inline_storage = inline_storage }; + } + return .{ .owned = bun.ByteList.init(try allocator.dupe(u8, possibly_inline_bytes)) }; + } + pub fn toOwned(this: @This()) !bun.ByteList { return switch (this) { .owned => this.owned, .temporary => bun.ByteList.init(try bun.default_allocator.dupe(u8, this.temporary)), .empty => bun.ByteList.init(&.{}), + .inline_storage => bun.ByteList.init(try bun.default_allocator.dupe(u8, this.inline_storage.slice())), }; } @@ -18,6 +35,7 @@ pub const Data = union(enum) { .owned => this.owned.deinitWithAllocator(bun.default_allocator), .temporary => {}, .empty => {}, + .inline_storage => {}, } } @@ -34,32 +52,37 @@ pub const Data = union(enum) { }, .temporary => {}, .empty => {}, + .inline_storage => {}, } } - pub fn slice(this: @This()) []const u8 { - return switch (this) { + pub fn slice(this: *const @This()) []const u8 { + return switch (this.*) { .owned => this.owned.slice(), .temporary => this.temporary, .empty => "", + .inline_storage => this.inline_storage.slice(), }; } - pub fn substring(this: @This(), start_index: usize, end_index: usize) Data { - return switch (this) { + pub fn substring(this: *const @This(), start_index: usize, end_index: usize) Data { + return switch (this.*) { .owned => .{ .temporary = 
this.owned.slice()[start_index..end_index] }, .temporary => .{ .temporary = this.temporary[start_index..end_index] }, .empty => .{ .empty = {} }, + .inline_storage => .{ .temporary = this.inline_storage.slice()[start_index..end_index] }, }; } - pub fn sliceZ(this: @This()) [:0]const u8 { - return switch (this) { + pub fn sliceZ(this: *const @This()) [:0]const u8 { + return switch (this.*) { .owned => this.owned.slice()[0..this.owned.len :0], .temporary => this.temporary[0..this.temporary.len :0], .empty => "", + .inline_storage => this.inline_storage.slice()[0..this.inline_storage.len :0], }; } }; const bun = @import("bun"); +const std = @import("std"); diff --git a/src/sql/postgres/ObjectIterator.zig b/src/sql/shared/ObjectIterator.zig similarity index 100% rename from src/sql/postgres/ObjectIterator.zig rename to src/sql/shared/ObjectIterator.zig diff --git a/src/sql/postgres/QueryBindingIterator.zig b/src/sql/shared/QueryBindingIterator.zig similarity index 100% rename from src/sql/postgres/QueryBindingIterator.zig rename to src/sql/shared/QueryBindingIterator.zig diff --git a/src/sql/shared/SQLDataCell.zig b/src/sql/shared/SQLDataCell.zig new file mode 100644 index 0000000000..1cf73d6edb --- /dev/null +++ b/src/sql/shared/SQLDataCell.zig @@ -0,0 +1,161 @@ +pub const SQLDataCell = extern struct { + tag: Tag, + + value: Value, + free_value: u8 = 0, + isIndexedColumn: u8 = 0, + index: u32 = 0, + + pub const Tag = enum(u8) { + null = 0, + string = 1, + float8 = 2, + int4 = 3, + int8 = 4, + bool = 5, + date = 6, + date_with_time_zone = 7, + bytea = 8, + json = 9, + array = 10, + typed_array = 11, + raw = 12, + uint4 = 13, + uint8 = 14, + }; + + pub const Value = extern union { + null: u8, + string: ?bun.WTF.StringImpl, + float8: f64, + int4: i32, + int8: i64, + bool: u8, + date: f64, + date_with_time_zone: f64, + bytea: [2]usize, + json: ?bun.WTF.StringImpl, + array: Array, + typed_array: TypedArray, + raw: Raw, + uint4: u32, + uint8: u64, + }; + + pub const Array = 
extern struct { + ptr: ?[*]SQLDataCell = null, + len: u32, + cap: u32, + pub fn slice(this: *Array) []SQLDataCell { + const ptr = this.ptr orelse return &.{}; + return ptr[0..this.len]; + } + + pub fn allocatedSlice(this: *Array) []SQLDataCell { + const ptr = this.ptr orelse return &.{}; + return ptr[0..this.cap]; + } + + pub fn deinit(this: *Array) void { + const allocated = this.allocatedSlice(); + this.ptr = null; + this.len = 0; + this.cap = 0; + bun.default_allocator.free(allocated); + } + }; + pub const Raw = extern struct { + ptr: ?[*]const u8 = null, + len: u64, + }; + pub const TypedArray = extern struct { + head_ptr: ?[*]u8 = null, + ptr: ?[*]u8 = null, + len: u32, + byte_len: u32, + type: JSValue.JSType, + + pub fn slice(this: *TypedArray) []u8 { + const ptr = this.ptr orelse return &.{}; + return ptr[0..this.len]; + } + + pub fn byteSlice(this: *TypedArray) []u8 { + const ptr = this.head_ptr orelse return &.{}; + return ptr[0..this.len]; + } + }; + + pub fn deinit(this: *SQLDataCell) void { + if (this.free_value == 0) return; + + switch (this.tag) { + .string => { + if (this.value.string) |str| { + str.deref(); + } + }, + .json => { + if (this.value.json) |str| { + str.deref(); + } + }, + .bytea => { + if (this.value.bytea[1] == 0) return; + const slice = @as([*]u8, @ptrFromInt(this.value.bytea[0]))[0..this.value.bytea[1]]; + bun.default_allocator.free(slice); + }, + .array => { + for (this.value.array.slice()) |*cell| { + cell.deinit(); + } + this.value.array.deinit(); + }, + .typed_array => { + bun.default_allocator.free(this.value.typed_array.byteSlice()); + }, + + else => {}, + } + } + + pub fn raw(optional_bytes: ?*const Data) SQLDataCell { + if (optional_bytes) |bytes| { + const bytes_slice = bytes.slice(); + return SQLDataCell{ + .tag = .raw, + .value = .{ .raw = .{ .ptr = @ptrCast(bytes_slice.ptr), .len = bytes_slice.len } }, + }; + } + // TODO: check empty and null fields + return SQLDataCell{ + .tag = .null, + .value = .{ .null = 0 }, + }; + } 
+ + pub const Flags = packed struct(u32) { + has_indexed_columns: bool = false, + has_named_columns: bool = false, + has_duplicate_columns: bool = false, + _: u29 = 0, + }; + + pub extern fn JSC__constructObjectFromDataCell( + *jsc.JSGlobalObject, + JSValue, + JSValue, + [*]SQLDataCell, + u32, + SQLDataCell.Flags, + u8, // result_mode + ?[*]jsc.JSObject.ExternColumnIdentifier, // names + u32, // names count + ) JSValue; +}; + +const bun = @import("bun"); +const Data = @import("./Data.zig").Data; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; diff --git a/src/sql/shared/SQLQueryResultMode.zig b/src/sql/shared/SQLQueryResultMode.zig new file mode 100644 index 0000000000..c584dab46e --- /dev/null +++ b/src/sql/shared/SQLQueryResultMode.zig @@ -0,0 +1,5 @@ +pub const SQLQueryResultMode = enum(u2) { + objects = 0, + values = 1, + raw = 2, +}; diff --git a/test/integration/bun-types/fixture/sql.ts b/test/integration/bun-types/fixture/sql.ts index 9128c3a708..ccac825fd6 100644 --- a/test/integration/bun-types/fixture/sql.ts +++ b/test/integration/bun-types/fixture/sql.ts @@ -271,5 +271,5 @@ expectType>(); // check some types exist expectType>; expectType; -expectType; +expectType; expectType>; diff --git a/test/internal/ban-limits.json b/test/internal/ban-limits.json index 4e4318a1d2..5ba0f7e51b 100644 --- a/test/internal/ban-limits.json +++ b/test/internal/ban-limits.json @@ -3,13 +3,13 @@ " == undefined": 0, "!= alloc.ptr": 0, "!= allocator.ptr": 0, - ".arguments_old(": 279, + ".arguments_old(": 276, ".jsBoolean(false)": 0, ".jsBoolean(true)": 0, ".stdDir()": 41, ".stdFile()": 18, "// autofix": 168, - ": [^=]+= undefined,$": 260, + ": [^=]+= undefined,$": 261, "== alloc.ptr": 0, "== allocator.ptr": 0, "@import(\"bun\").": 0, @@ -21,7 +21,7 @@ "allocator.ptr !=": 1, "allocator.ptr ==": 0, "global.hasException": 28, - "globalObject.hasException": 42, + "globalObject.hasException": 47, "globalThis.hasException": 133, "std.StringArrayHashMap(": 1, 
"std.StringArrayHashMapUnmanaged(": 12, diff --git a/test/js/sql/sql-mysql.helpers.test.ts b/test/js/sql/sql-mysql.helpers.test.ts new file mode 100644 index 0000000000..73aeccbf45 --- /dev/null +++ b/test/js/sql/sql-mysql.helpers.test.ts @@ -0,0 +1,124 @@ +import { SQL, randomUUIDv7 } from "bun"; +import { expect, test } from "bun:test"; +import { describeWithContainer } from "harness"; + +describeWithContainer( + "mysql", + { + image: "mysql:8", + env: { + MYSQL_ROOT_PASSWORD: "bun", + }, + }, + (port: number) => { + const options = { + url: `mysql://root:bun@localhost:${port}`, + max: 1, + bigint: true, + }; + test("insert helper", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`; + const result = await sql`SELECT * FROM ${sql(random_name)}`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("John"); + expect(result[0].age).toBe(30); + }); + test("update helper", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`; + await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id = 1`; + const result = await sql`SELECT * FROM ${sql(random_name)}`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + }); + + test("update helper with IN", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + 
const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`; + + await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id IN ${sql([1, 2])}`; + const result = await sql`SELECT * FROM ${sql(random_name)}`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + expect(result[1].id).toBe(2); + expect(result[1].name).toBe("Mary"); + expect(result[1].age).toBe(18); + }); + + test("update helper with IN and column name", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`; + + await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id IN ${sql(users, "id")}`; + const result = await sql`SELECT * FROM ${sql(random_name)}`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + expect(result[1].id).toBe(2); + expect(result[1].name).toBe("Mary"); + expect(result[1].age).toBe(18); + }); + + test("update multiple values no helper", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`; + await sql`UPDATE ${sql(random_name)} SET ${sql("name")} = ${"Mary"}, ${sql("age")} = ${18} WHERE id = 1`; + const result = await sql`SELECT * FROM ${sql(random_name)} WHERE id = 1`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + 
expect(result[0].age).toBe(18); + }); + + test("SELECT with IN and NOT IN", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`; + + const result = + await sql`SELECT * FROM ${sql(random_name)} WHERE id IN ${sql(users, "id")} and id NOT IN ${sql([3, 4, 5])}`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("John"); + expect(result[0].age).toBe(30); + expect(result[1].id).toBe(2); + expect(result[1].name).toBe("Jane"); + expect(result[1].age).toBe(25); + }); + + test("syntax error", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + + expect(() => sql`DELETE FROM ${sql(random_name)} ${sql(users, "id")}`.execute()).toThrow(SyntaxError); + }); + }, +); diff --git a/test/js/sql/sql-mysql.test.ts b/test/js/sql/sql-mysql.test.ts new file mode 100644 index 0000000000..b84f3fc488 --- /dev/null +++ b/test/js/sql/sql-mysql.test.ts @@ -0,0 +1,805 @@ +import { SQL, randomUUIDv7 } from "bun"; +import { describe, expect, mock, test } from "bun:test"; +import { describeWithContainer, tempDirWithFiles } from "harness"; +import net from "net"; +import path from "path"; +const dir = tempDirWithFiles("sql-test", { + "select-param.sql": `select ? 
as x`, + "select.sql": `select CAST(1 AS SIGNED) as x`, +}); +function rel(filename: string) { + return path.join(dir, filename); +} +describeWithContainer( + "mysql", + { + image: "mysql:8", + env: { + MYSQL_ROOT_PASSWORD: "bun", + }, + }, + (port: number) => { + const options = { + url: `mysql://root:bun@localhost:${port}`, + max: 1, + }; + const sql = new SQL(options); + describe("should work with more than the max inline capacity", () => { + for (let size of [50, 60, 62, 64, 70, 100]) { + for (let duplicated of [true, false]) { + test(`${size} ${duplicated ? "+ duplicated" : "unique"} fields`, async () => { + await using sql = new SQL(options); + const longQuery = `select ${Array.from({ length: size }, (_, i) => { + if (duplicated) { + return i % 2 === 0 ? `${i + 1} as f${i}, ${i} as f${i}` : `${i} as f${i}`; + } + return `${i} as f${i}`; + }).join(",\n")}`; + const result = await sql.unsafe(longQuery); + let value = 0; + for (const column of Object.values(result[0])) { + expect(column?.toString()).toEqual(value.toString()); + value++; + } + }); + } + } + }); + + test("Connection timeout works", async () => { + const onclose = mock(); + const onconnect = mock(); + await using sql = new SQL({ + ...options, + hostname: "example.com", + connection_timeout: 4, + onconnect, + onclose, + max: 1, + }); + let error: any; + try { + await sql`select SLEEP(8)`; + } catch (e) { + error = e; + } + expect(error.code).toBe(`ERR_MYSQL_CONNECTION_TIMEOUT`); + expect(error.message).toContain("Connection timeout after 4s"); + expect(onconnect).not.toHaveBeenCalled(); + expect(onclose).toHaveBeenCalledTimes(1); + }); + + test("Idle timeout works at start", async () => { + const onclose = mock(); + const onconnect = mock(); + await using sql = new SQL({ + ...options, + idle_timeout: 1, + onconnect, + onclose, + }); + let error: any; + try { + await sql`select SLEEP(2)`; + } catch (e) { + error = e; + } + expect(error.code).toBe(`ERR_MYSQL_IDLE_TIMEOUT`); + 
expect(onconnect).toHaveBeenCalled(); + expect(onclose).toHaveBeenCalledTimes(1); + }); + + test("Idle timeout is reset when a query is run", async () => { + const onClosePromise = Promise.withResolvers(); + const onclose = mock(err => { + onClosePromise.resolve(err); + }); + const onconnect = mock(); + await using sql = new SQL({ + ...options, + idle_timeout: 1, + onconnect, + onclose, + }); + expect(await sql`select 123 as x`).toEqual([{ x: 123 }]); + expect(onconnect).toHaveBeenCalledTimes(1); + expect(onclose).not.toHaveBeenCalled(); + const err = await onClosePromise.promise; + expect(err.code).toBe(`ERR_MYSQL_IDLE_TIMEOUT`); + }); + + test("Max lifetime works", async () => { + const onClosePromise = Promise.withResolvers(); + const onclose = mock(err => { + onClosePromise.resolve(err); + }); + const onconnect = mock(); + const sql = new SQL({ + ...options, + max_lifetime: 1, + onconnect, + onclose, + }); + let error: any; + expect(await sql`select 1 as x`).toEqual([{ x: 1 }]); + expect(onconnect).toHaveBeenCalledTimes(1); + try { + while (true) { + for (let i = 0; i < 100; i++) { + await sql`select SLEEP(1)`; + } + } + } catch (e) { + error = e; + } + + expect(onclose).toHaveBeenCalledTimes(1); + + expect(error.code).toBe(`ERR_MYSQL_LIFETIME_TIMEOUT`); + }); + + // Last one wins. 
+ test("Handles duplicate string column names", async () => { + const result = await sql`select 1 as x, 2 as x, 3 as x`; + expect(result).toEqual([{ x: 3 }]); + }); + + test("should not timeout in long results", async () => { + await using db = new SQL({ ...options, max: 1, idleTimeout: 5 }); + using sql = await db.reserve(); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text)`; + const promises: Promise[] = []; + for (let i = 0; i < 10_000; i++) { + promises.push(sql`INSERT INTO ${sql(random_name)} VALUES (${i}, ${"test" + i})`); + if (i % 50 === 0 && i > 0) { + await Promise.all(promises); + promises.length = 0; + } + } + await Promise.all(promises); + await sql`SELECT * FROM ${sql(random_name)}`; + await sql`SELECT * FROM ${sql(random_name)}`; + await sql`SELECT * FROM ${sql(random_name)}`; + + expect().pass(); + }, 10_000); + + test("Handles numeric column names", async () => { + // deliberately out of order + const result = await sql`select 1 as "1", 2 as "2", 3 as "3", 0 as "0"`; + expect(result).toEqual([{ "1": 1, "2": 2, "3": 3, "0": 0 }]); + + expect(Object.keys(result[0])).toEqual(["0", "1", "2", "3"]); + // Sanity check: ensure iterating through the properties doesn't crash. + Bun.inspect(result); + }); + + // Last one wins. + test("Handles duplicate numeric column names", async () => { + const result = await sql`select 1 as "1", 2 as "1", 3 as "1"`; + expect(result).toEqual([{ "1": 3 }]); + // Sanity check: ensure iterating through the properties doesn't crash. + Bun.inspect(result); + }); + + test("Handles mixed column names", async () => { + const result = await sql`select 1 as "1", 2 as "2", 3 as "3", 4 as x`; + expect(result).toEqual([{ "1": 1, "2": 2, "3": 3, x: 4 }]); + // Sanity check: ensure iterating through the properties doesn't crash. 
+ Bun.inspect(result); + }); + + test("Handles mixed column names with duplicates", async () => { + const result = await sql`select 1 as "1", 2 as "2", 3 as "3", 4 as "1", 1 as x, 2 as x`; + expect(result).toEqual([{ "1": 4, "2": 2, "3": 3, x: 2 }]); + // Sanity check: ensure iterating through the properties doesn't crash. + Bun.inspect(result); + + // Named columns are inserted first, but they appear from JS as last. + expect(Object.keys(result[0])).toEqual(["1", "2", "3", "x"]); + }); + + test("Handles mixed column names with duplicates at the end", async () => { + const result = await sql`select 1 as "1", 2 as "2", 3 as "3", 4 as "1", 1 as x, 2 as x, 3 as x, 4 as "y"`; + expect(result).toEqual([{ "1": 4, "2": 2, "3": 3, x: 3, y: 4 }]); + + // Sanity check: ensure iterating through the properties doesn't crash. + Bun.inspect(result); + }); + + test("Handles mixed column names with duplicates at the start", async () => { + const result = await sql`select 1 as "1", 2 as "1", 3 as "2", 4 as "3", 1 as x, 2 as x, 3 as x`; + expect(result).toEqual([{ "1": 2, "2": 3, "3": 4, x: 3 }]); + // Sanity check: ensure iterating through the properties doesn't crash. 
+ Bun.inspect(result); + }); + + test("Uses default database without slash", async () => { + const sql = new SQL("mysql://localhost"); + expect("mysql").toBe(sql.options.database); + }); + + test("Uses default database with slash", async () => { + const sql = new SQL("mysql://localhost/"); + expect("mysql").toBe(sql.options.database); + }); + + test("Result is array", async () => { + expect(await sql`select 1`).toBeArray(); + }); + + test("Create table", async () => { + await sql`create table test(id int)`; + await sql`drop table test`; + }); + + test("Drop table", async () => { + await sql`create table test(id int)`; + await sql`drop table test`; + // Verify that table is dropped + const result = await sql`select * from information_schema.tables where table_name = 'test'`; + expect(result).toBeArrayOfSize(0); + }); + + test("null", async () => { + expect((await sql`select ${null} as x`)[0].x).toBeNull(); + }); + + test("Unsigned Integer", async () => { + expect((await sql`select ${0x7fffffff + 2} as x`)[0].x).toBe(2147483649); + }); + + test("Signed Integer", async () => { + expect((await sql`select ${-1} as x`)[0].x).toBe(-1); + expect((await sql`select ${1} as x`)[0].x).toBe(1); + }); + + test("Double", async () => { + expect((await sql`select ${1.123456789} as x`)[0].x).toBe(1.123456789); + }); + + test("String", async () => { + expect((await sql`select ${"hello"} as x`)[0].x).toBe("hello"); + }); + + test("Boolean", async () => { + // Protocol will always return 0 or 1 for TRUE and FALSE when not using a table. 
+ expect((await sql`select ${false} as x`)[0].x).toBe(0); + expect((await sql`select ${true} as x`)[0].x).toBe(1); + const random_name = ("t_" + Bun.randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (a bool)`; + const values = [{ a: true }, { a: false }]; + await sql`INSERT INTO ${sql(random_name)} ${sql(values)}`; + const [[a], [b]] = await sql`select * from ${sql(random_name)}`.values(); + expect(a).toBe(true); + expect(b).toBe(false); + }); + + test("Date", async () => { + const now = new Date(); + const then = (await sql`select ${now} as x`)[0].x; + expect(then).toEqual(now); + }); + + test("Timestamp", async () => { + { + const result = (await sql`select DATE_ADD(FROM_UNIXTIME(0), INTERVAL -25 SECOND) as x`)[0].x; + expect(result.getTime()).toBe(-25000); + } + { + const result = (await sql`select DATE_ADD(FROM_UNIXTIME(0), INTERVAL 25 SECOND) as x`)[0].x; + expect(result.getSeconds()).toBe(25); + } + { + const result = (await sql`select DATE_ADD(FROM_UNIXTIME(0), INTERVAL 251000 MICROSECOND) as x`)[0].x; + expect(result.getMilliseconds()).toBe(251); + } + { + const result = (await sql`select DATE_ADD(FROM_UNIXTIME(0), INTERVAL -251000 MICROSECOND) as x`)[0].x; + expect(result.getTime()).toBe(-251); + } + }); + + test("JSON", async () => { + const x = (await sql`select CAST(${{ a: "hello", b: 42 }} AS JSON) as x`)[0].x; + expect(x).toEqual({ a: "hello", b: 42 }); + + const y = (await sql`select CAST('{"key": "value", "number": 123}' AS JSON) as x`)[0].x; + expect(y).toEqual({ key: "value", number: 123 }); + + const random_name = ("t_" + Bun.randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (a json)`; + const values = [{ a: { b: 1 } }, { a: { b: 2 } }]; + await sql`INSERT INTO ${sql(random_name)} ${sql(values)}`; + const [[a], [b]] = await sql`select * from ${sql(random_name)}`.values(); + expect(a).toEqual({ b: 1 }); + expect(b).toEqual({ b: 
2 }); + }); + + test("bulk insert nested sql()", async () => { + await sql`create table users (name text, age int)`; + const users = [ + { name: "Alice", age: 25 }, + { name: "Bob", age: 30 }, + ]; + try { + await sql`insert into users ${sql(users)}`; + const result = await sql`select * from users`; + expect(result).toEqual([ + { name: "Alice", age: 25 }, + { name: "Bob", age: 30 }, + ]); + } finally { + await sql`drop table users`; + } + }); + + test("Escapes", async () => { + expect(Object.keys((await sql`select 1 as ${sql('hej"hej')}`)[0])[0]).toBe('hej"hej'); + }); + + test("null for int", async () => { + const result = await sql`create table test (x int)`; + expect(result.count).toBe(0); + try { + await sql`insert into test values(${null})`; + const result2 = await sql`select * from test`; + expect(result2).toEqual([{ x: null }]); + } finally { + await sql`drop table test`; + } + }); + + test("should be able to execute different queries in the same connection #16774", async () => { + const sql = new SQL({ ...options, max: 1 }); + const random_table_name = `test_user_${Math.random().toString(36).substring(2, 15)}`; + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_table_name)} (id int, name text)`; + + const promises: Array> = []; + // POPULATE TABLE + for (let i = 0; i < 1_000; i++) { + promises.push(sql`insert into ${sql(random_table_name)} values (${i}, ${`test${i}`})`.execute()); + } + await Promise.all(promises); + + // QUERY TABLE using execute() to force executing the query immediately + { + for (let i = 0; i < 1_000; i++) { + // mix different parameters + switch (i % 3) { + case 0: + promises.push(sql`select id, name from ${sql(random_table_name)} where id = ${i}`.execute()); + break; + case 1: + promises.push(sql`select id from ${sql(random_table_name)} where id = ${i}`.execute()); + break; + case 2: + promises.push(sql`select 1, id, name from ${sql(random_table_name)} where id = ${i}`.execute()); + break; + } + } + await 
Promise.all(promises); + } + }); + + test("Prepared transaction", async () => { + await using sql = new SQL(options); + await sql`create table test (a int)`; + + try { + await sql.beginDistributed("tx1", async sql => { + await sql`insert into test values(1)`; + }); + await sql.commitDistributed("tx1"); + expect((await sql`select count(*) from test`).count).toBe(1); + } finally { + await sql`drop table test`; + } + }); + + test("Idle timeout retry works", async () => { + await using sql = new SQL({ ...options, idleTimeout: 1 }); + await sql`select 1`; + await Bun.sleep(1100); // 1.1 seconds so it should retry + await sql`select 1`; + expect().pass(); + }); + + test("Fragments in transactions", async () => { + const sql = new SQL({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); + expect((await sql.begin(sql => sql`select 1 as x where ${sql`1=1`}`))[0].x).toBe(1); + }); + + test("Helpers in Transaction", async () => { + const result = await sql.begin(async sql => await sql`select ${sql.unsafe("1 as x")}`); + expect(result[0].x).toBe(1); + }); + + test("Undefined values throws", async () => { + const result = await sql`select ${undefined} as x`; + expect(result[0].x).toBeNull(); + }); + + test("Null sets to null", async () => expect((await sql`select ${null} as x`)[0].x).toBeNull()); + + // Add code property. 
+ test("Throw syntax error", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const err = await sql`wat 1`.catch(x => x); + expect(err.code).toBe("ERR_MYSQL_SYNTAX_ERROR"); + }); + + test("should work with fragments", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = sql("test_" + randomUUIDv7("hex").replaceAll("-", "")); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${random_name} (id int, hotel_id int, created_at timestamp)`; + await sql`INSERT INTO ${random_name} VALUES (1, 1, '2024-01-01 10:00:00')`; + // single escaped identifier + { + const results = await sql`SELECT * FROM ${random_name}`; + expect(results).toEqual([{ id: 1, hotel_id: 1, created_at: new Date("2024-01-01T10:00:00.000Z") }]); + } + // multiple escaped identifiers + { + const results = await sql`SELECT ${random_name}.* FROM ${random_name}`; + expect(results).toEqual([{ id: 1, hotel_id: 1, created_at: new Date("2024-01-01T10:00:00.000Z") }]); + } + // even more complex fragment + { + const results = + await sql`SELECT ${random_name}.* FROM ${random_name} WHERE ${random_name}.hotel_id = ${1} ORDER BY ${random_name}.created_at DESC`; + expect(results).toEqual([{ id: 1, hotel_id: 1, created_at: new Date("2024-01-01T10:00:00.000Z") }]); + } + }); + test("should handle nested fragments", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const random_name = sql("test_" + randomUUIDv7("hex").replaceAll("-", "")); + + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${random_name} (id int, hotel_id int, created_at timestamp)`; + await sql`INSERT INTO ${random_name} VALUES (1, 1, '2024-01-01 10:00:00')`; + await sql`INSERT INTO ${random_name} VALUES (2, 1, '2024-01-02 10:00:00')`; + await sql`INSERT INTO ${random_name} VALUES (3, 2, '2024-01-03 10:00:00')`; + + // fragment containing another scape fragment for the field name + const orderBy = (field_name: string) => sql`ORDER BY ${sql(field_name)} DESC`; + + // dynamic 
information + const sortBy = { should_sort: true, field: "created_at" }; + const user = { hotel_id: 1 }; + + // query containing the fragments + const results = await sql` + SELECT ${random_name}.* + FROM ${random_name} + WHERE ${random_name}.hotel_id = ${user.hotel_id} + ${sortBy.should_sort ? orderBy(sortBy.field) : sql``}`; + expect(results).toEqual([ + { id: 2, hotel_id: 1, created_at: new Date("2024-01-02T10:00:00.000Z") }, + { id: 1, hotel_id: 1, created_at: new Date("2024-01-01T10:00:00.000Z") }, + ]); + }); + + test("Support dynamic password function", async () => { + await using sql = new SQL({ ...options, password: () => "bun", max: 1 }); + return expect((await sql`select 1 as x`)[0].x).toBe(1); + }); + + test("Support dynamic async resolved password function", async () => { + await using sql = new SQL({ + ...options, + password: () => Promise.resolve("bun"), + max: 1, + }); + return expect((await sql`select 1 as x`)[0].x).toBe(1); + }); + + test("Support dynamic async password function", async () => { + await using sql = new SQL({ + ...options, + max: 1, + password: async () => { + await Bun.sleep(10); + return "bun"; + }, + }); + return expect((await sql`select 1 as x`)[0].x).toBe(1); + }); + test("Support dynamic async rejected password function", async () => { + await using sql = new SQL({ + ...options, + password: () => Promise.reject(new Error("password error")), + max: 1, + }); + try { + await sql`select true as x`; + expect.unreachable(); + } catch (e: any) { + expect(e.message).toBe("password error"); + } + }); + test("Support dynamic async password function that throws", async () => { + await using sql = new SQL({ + ...options, + max: 1, + password: async () => { + await Bun.sleep(10); + throw new Error("password error"); + }, + }); + try { + await sql`select true as x`; + expect.unreachable(); + } catch (e: any) { + expect(e).toBeInstanceOf(Error); + expect(e.message).toBe("password error"); + } + }); + test("sql file", async () => { + await 
using sql = new SQL(options); + expect((await sql.file(rel("select.sql")))[0].x).toBe(1); + }); + + test("sql file throws", async () => { + await using sql = new SQL(options); + expect(await sql.file(rel("selectomondo.sql")).catch(x => x.code)).toBe("ENOENT"); + }); + test("Parameters in file", async () => { + await using sql = new SQL(options); + const result = await sql.file(rel("select-param.sql"), ["hello"]); + return expect(result[0].x).toBe("hello"); + }); + + test("Connection ended promise", async () => { + const sql = new SQL(options); + + await sql.end(); + + expect(await sql.end()).toBeUndefined(); + }); + + test("Connection ended timeout", async () => { + const sql = new SQL(options); + + await sql.end({ timeout: 10 }); + + expect(await sql.end()).toBeUndefined(); + }); + + test("Connection ended error", async () => { + const sql = new SQL(options); + await sql.end(); + return expect(await sql``.catch(x => x.code)).toBe("ERR_MYSQL_CONNECTION_CLOSED"); + }); + + test("Connection end does not cancel query", async () => { + const sql = new SQL(options); + + const promise = sql`select SLEEP(1) as x`.execute(); + await sql.end(); + return expect(await promise).toEqual([{ x: 0 }]); + }); + + test("Connection destroyed", async () => { + const sql = new SQL(options); + process.nextTick(() => sql.end({ timeout: 0 })); + expect(await sql``.catch(x => x.code)).toBe("ERR_MYSQL_CONNECTION_CLOSED"); + }); + + test("Connection destroyed with query before", async () => { + const sql = new SQL(options); + const error = sql`select SLEEP(0.2)`.catch(err => err.code); + + sql.end({ timeout: 0 }); + return expect(await error).toBe("ERR_MYSQL_CONNECTION_CLOSED"); + }); + + test("unsafe", async () => { + await sql`create table test (x int)`; + try { + await sql.unsafe("insert into test values (?)", [1]); + const [{ x }] = await sql`select * from test`; + expect(x).toBe(1); + } finally { + await sql`drop table test`; + } + }); + + test("unsafe simple", async () => { + await 
using sql = new SQL({ ...options, max: 1 }); + expect(await sql.unsafe("select 1 as x")).toEqual([{ x: 1 }]); + }); + + test("simple query with multiple statements", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const result = await sql`select 1 as x;select 2 as x`.simple(); + expect(result).toBeDefined(); + expect(result.length).toEqual(2); + expect(result[0][0].x).toEqual(1); + expect(result[1][0].x).toEqual(2); + }); + + test("simple query using unsafe with multiple statements", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const result = await sql.unsafe("select 1 as x;select 2 as x"); + expect(result).toBeDefined(); + expect(result.length).toEqual(2); + expect(result[0][0].x).toEqual(1); + expect(result[1][0].x).toEqual(2); + }); + + test("only allows one statement", async () => { + expect(await sql`select 1; select 2`.catch(e => e.message)).toBe( + "You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'select 2' at line 1", + ); + }); + + test("await sql() throws not tagged error", async () => { + try { + await sql("select 1"); + expect.unreachable(); + } catch (e: any) { + expect(e.code).toBe("ERR_MYSQL_NOT_TAGGED_CALL"); + } + }); + + test("sql().then throws not tagged error", async () => { + try { + await sql("select 1").then(() => { + /* noop */ + }); + expect.unreachable(); + } catch (e: any) { + expect(e.code).toBe("ERR_MYSQL_NOT_TAGGED_CALL"); + } + }); + + test("sql().catch throws not tagged error", async () => { + try { + sql("select 1").catch(() => { + /* noop */ + }); + expect.unreachable(); + } catch (e: any) { + expect(e.code).toBe("ERR_MYSQL_NOT_TAGGED_CALL"); + } + }); + + test("sql().finally throws not tagged error", async () => { + try { + sql("select 1").finally(() => { + /* noop */ + }); + expect.unreachable(); + } catch (e: any) { + expect(e.code).toBe("ERR_MYSQL_NOT_TAGGED_CALL"); + } + }); + + test("little bobby 
tables", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const name = "Robert'); DROP TABLE students;--"; + + try { + await sql`create table students (name text, age int)`; + await sql`insert into students (name) values (${name})`; + + expect((await sql`select name from students`)[0].name).toBe(name); + } finally { + await sql`drop table students`; + } + }); + + test("Connection errors are caught using begin()", async () => { + let error; + try { + const sql = new SQL({ host: "localhost", port: 1, adapter: "mysql" }); + + await sql.begin(async sql => { + await sql`insert into test (label, value) values (${1}, ${2})`; + }); + } catch (err) { + error = err; + } + expect(error.code).toBe("ERR_MYSQL_CONNECTION_CLOSED"); + }); + + test("dynamic table name", async () => { + await using sql = new SQL({ ...options, max: 1 }); + await sql`create table test(a int)`; + try { + return expect((await sql`select * from ${sql("test")}`).length).toBe(0); + } finally { + await sql`drop table test`; + } + }); + + test("dynamic column name", async () => { + await using sql = new SQL({ ...options, max: 1 }); + const result = await sql`select 1 as ${sql("!not_valid")}`; + expect(Object.keys(result[0])[0]).toBe("!not_valid"); + }); + + test("dynamic insert", async () => { + await using sql = new SQL({ ...options, max: 1 }); + await sql`create table test (a int, b text)`; + try { + const x = { a: 42, b: "the answer" }; + await sql`insert into test ${sql(x)}`; + const [{ b }] = await sql`select * from test`; + expect(b).toBe("the answer"); + } finally { + await sql`drop table test`; + } + }); + + test("dynamic insert pluck", async () => { + await using sql = new SQL({ ...options, max: 1 }); + try { + await sql`create table test2 (a int, b text)`; + const x = { a: 42, b: "the answer" }; + await sql`insert into test2 ${sql(x, "a")}`; + const [{ b, a }] = await sql`select * from test2`; + expect(b).toBeNull(); + expect(a).toBe(42); + } finally { + await sql`drop table 
test2`; + } + }); + + test("bigint is returned as String", async () => { + await using sql = new SQL(options); + expect(typeof (await sql`select 9223372036854777 as x`)[0].x).toBe("string"); + }); + + test("bigint is returned as BigInt", async () => { + await using sql = new SQL({ + ...options, + bigint: true, + }); + expect((await sql`select 9223372036854777 as x`)[0].x).toBe(9223372036854777n); + }); + + test("int is returned as Number", async () => { + await using sql = new SQL(options); + expect((await sql`select CAST(123 AS SIGNED) as x`)[0].x).toBe(123); + }); + + test("flush should work", async () => { + await using sql = new SQL(options); + await sql`select 1`; + sql.flush(); + }); + + test.each(["connect_timeout", "connectTimeout", "connectionTimeout", "connection_timeout"] as const)( + "connection timeout key %p throws", + async key => { + const server = net.createServer().listen(); + + const port = (server.address() as import("node:net").AddressInfo).port; + + const sql = new SQL({ adapter: "mysql", port, host: "127.0.0.1", [key]: 0.2 }); + + try { + await sql`select 1`; + throw new Error("should not reach"); + } catch (e) { + expect(e).toBeInstanceOf(Error); + expect(e.code).toBe("ERR_MYSQL_CONNECTION_TIMEOUT"); + expect(e.message).toMatch(/Connection timed out after 200ms/); + } finally { + sql.close(); + server.close(); + } + }, + { + timeout: 1000, + }, + ); + test("Array returns rows as arrays of columns", async () => { + await using sql = new SQL(options); + return [(await sql`select CAST(1 AS SIGNED) as x`.values())[0][0], 1]; + }); + }, +); diff --git a/test/js/sql/sql-mysql.transactions.test.ts b/test/js/sql/sql-mysql.transactions.test.ts new file mode 100644 index 0000000000..e38c57faef --- /dev/null +++ b/test/js/sql/sql-mysql.transactions.test.ts @@ -0,0 +1,183 @@ +import { SQL, randomUUIDv7 } from "bun"; +import { expect, test } from "bun:test"; +import { describeWithContainer } from "harness"; + +describeWithContainer( + "mysql", + { + 
image: "mysql:8", + env: { + MYSQL_ROOT_PASSWORD: "bun", + }, + }, + (port: number) => { + const options = { + url: `mysql://root:bun@localhost:${port}`, + max: 1, + bigint: true, + }; + + test("Transaction works", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + + await sql.begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql`insert into ${sql(random_name)} values(2)`; + }); + + expect((await sql`select a from ${sql(random_name)}`).count).toBe(2); + await sql.close(); + }); + + test("Throws on illegal transactions", async () => { + await using sql = new SQL({ ...options, max: 2 }); + const error = await sql`BEGIN`.catch(e => e); + return expect(error.code).toBe("ERR_MYSQL_UNSAFE_TRANSACTION"); + }); + + test("Transaction throws", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + expect( + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql`insert into ${sql(random_name)} values('hej')`; + }) + .catch(e => e.message), + ).toBe("Incorrect integer value: 'hej' for column 'a' at row 1"); + }); + + test("Transaction rolls back", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql`insert into ${sql(random_name)} values('hej')`; + }) + .catch(() => { + /* ignore */ + }); + + expect((await sql`select a from ${sql(random_name)}`).count).toBe(0); + }); + + test("Transaction throws on 
uncaught savepoint", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + expect( + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql.savepoint(async sql => { + await sql`insert into ${sql(random_name)} values(2)`; + throw new Error("fail"); + }); + }) + .catch(err => err.message), + ).toBe("fail"); + }); + + test("Transaction throws on uncaught named savepoint", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TEMPORARY TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + expect( + await sql + .begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql.savepoint("watpoint", async sql => { + await sql`insert into ${sql(random_name)} values(2)`; + throw new Error("fail"); + }); + }) + .catch(() => "fail"), + ).toBe("fail"); + }); + + test("Transaction succeeds on caught savepoint", async () => { + await using sql = new SQL(options); + const random_name = ("t_" + randomUUIDv7("hex").replaceAll("-", "")).toLowerCase(); + await sql`CREATE TABLE IF NOT EXISTS ${sql(random_name)} (a int)`; + try { + await sql.begin(async sql => { + await sql`insert into ${sql(random_name)} values(1)`; + await sql + .savepoint(async sql => { + await sql`insert into ${sql(random_name)} values(2)`; + throw new Error("please rollback"); + }) + .catch(() => { + /* ignore */ + }); + await sql`insert into ${sql(random_name)} values(3)`; + }); + expect((await sql`select count(1) as count from ${sql(random_name)}`)[0].count).toBe(2); + } finally { + await sql`DROP TABLE IF EXISTS ${sql(random_name)}`; + } + }); + + test("Savepoint returns Result", async () => { + let result; + await using sql = new SQL(options); + await sql.begin(async t => { + 
result = await t.savepoint(s => s`select 1 as x`); + }); + expect(result[0]?.x).toBe(1); + }); + + test("Uncaught transaction request errors bubbles to transaction", async () => { + await using sql = new SQL(options); + expect(await sql.begin(sql => [sql`select wat`, sql`select 1 as x, ${1} as a`]).catch(e => e.message)).toBe( + "Unknown column 'wat' in 'field list'", + ); + }); + + test("Transaction rejects with rethrown error", async () => { + await using sql = new SQL(options); + expect( + await sql + .begin(async sql => { + try { + await sql`select exception`; + } catch (ex) { + throw new Error("WAT"); + } + }) + .catch(e => e.message), + ).toBe("WAT"); + }); + + test("Parallel transactions", async () => { + await using sql = new SQL({ ...options, max: 2 }); + + expect( + (await Promise.all([sql.begin(sql => sql`select 1 as count`), sql.begin(sql => sql`select 1 as count`)])) + .map(x => x[0].count) + .join(""), + ).toBe("11"); + }); + + test("Many transactions at beginning of connection", async () => { + await using sql = new SQL({ ...options, max: 2 }); + const xs = await Promise.all(Array.from({ length: 30 }, () => sql.begin(sql => sql`select 1`))); + return expect(xs.length).toBe(30); + }); + + test("Transactions array", async () => { + await using sql = new SQL(options); + expect( + (await sql.begin(sql => [sql`select 1 as count`, sql`select 1 as count`])).map(x => x[0].count).join(""), + ).toBe("11"); + }); + }, +); diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index 963935f989..16930ff773 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -147,24 +147,24 @@ if (isDockerEnabled()) { // --- Expected pg_hba.conf --- process.env.DATABASE_URL = `postgres://bun_sql_test@localhost:${container.port}/bun_sql_test`; - const login: Bun.SQL.PostgresOptions = { + const login: Bun.SQL.PostgresOrMySQLOptions = { username: "bun_sql_test", port: container.port, }; - const login_md5: Bun.SQL.PostgresOptions = { + const login_md5: 
Bun.SQL.PostgresOrMySQLOptions = { username: "bun_sql_test_md5", password: "bun_sql_test_md5", port: container.port, }; - const login_scram: Bun.SQL.PostgresOptions = { + const login_scram: Bun.SQL.PostgresOrMySQLOptions = { username: "bun_sql_test_scram", password: "bun_sql_test_scram", port: container.port, }; - const options: Bun.SQL.PostgresOptions = { + const options: Bun.SQL.PostgresOrMySQLOptions = { db: "bun_sql_test", username: login.username, password: login.password, diff --git a/test/js/sql/sqlite-sql.test.ts b/test/js/sql/sqlite-sql.test.ts index a735e0e221..adf3b92b79 100644 --- a/test/js/sql/sqlite-sql.test.ts +++ b/test/js/sql/sqlite-sql.test.ts @@ -17,14 +17,6 @@ describe("Connection & Initialization", () => { expect(myapp.options.adapter).toBe("sqlite"); expect(myapp.options.filename).toBe("myapp.db"); - const myapp2 = new SQL("myapp.db", { adapter: "sqlite" }); - expect(myapp2.options.adapter).toBe("sqlite"); - expect(myapp2.options.filename).toBe("myapp.db"); - - expect(() => new SQL("myapp.db")).toThrowErrorMatchingInlineSnapshot( - `"Invalid URL 'myapp.db' for postgres. Did you mean to specify \`{ adapter: "sqlite" }\`?"`, - ); - const postgres = new SQL("postgres://user1:pass2@localhost:5432/mydb"); expect(postgres.options.adapter).not.toBe("sqlite"); }); @@ -611,18 +603,6 @@ describe("Connection & Initialization", () => { expect(sql.options.filename).toBe(":memory:"); sql.close(); }); - - test("should throw for invalid URL without adapter", () => { - expect(() => new SQL("not-a-url")).toThrowErrorMatchingInlineSnapshot( - `"Invalid URL 'not-a-url' for postgres. Did you mean to specify \`{ adapter: "sqlite" }\`?"`, - ); - }); - - test("should throw for postgres URL when sqlite adapter is expected", () => { - expect(() => new SQL("myapp.db")).toThrowErrorMatchingInlineSnapshot( - `"Invalid URL 'myapp.db' for postgres. 
Did you mean to specify \`{ adapter: "sqlite" }\`?"`, - ); - }); }); describe("Mixed Configurations", () => { @@ -690,8 +670,8 @@ describe("Connection & Initialization", () => { describe("Error Cases", () => { test("should throw for unsupported adapter", () => { - expect(() => new SQL({ adapter: "mysql" as any })).toThrowErrorMatchingInlineSnapshot( - `"Unsupported adapter: mysql. Supported adapters: "postgres", "sqlite""`, + expect(() => new SQL({ adapter: "mssql" as any })).toThrowErrorMatchingInlineSnapshot( + `"Unsupported adapter: mssql. Supported adapters: "postgres", "sqlite", "mysql""`, ); }); diff --git a/test/js/sql/sqlite-url-parsing.test.ts b/test/js/sql/sqlite-url-parsing.test.ts index 9f808e44d8..006bd73d50 100644 --- a/test/js/sql/sqlite-url-parsing.test.ts +++ b/test/js/sql/sqlite-url-parsing.test.ts @@ -307,6 +307,11 @@ describe("SQLite URL Parsing Matrix", () => { "http://example.com/test.db", "https://example.com/test.db", "ftp://example.com/test.db", + "localhost/test.db", + "localhost:5432/test.db", + "example.com:3306/db", + "example.com/test", + "localhost", "postgres://user:pass@localhost/db", "postgresql://user:pass@localhost/db", ]; @@ -317,12 +322,4 @@ describe("SQLite URL Parsing Matrix", () => { sql.close(); }); }); - - describe("Plain filenames without adapter should throw", () => { - test("plain filename without adapter throws", () => { - expect(() => new SQL("myapp.db")).toThrowErrorMatchingInlineSnapshot( - `"Invalid URL 'myapp.db' for postgres. Did you mean to specify \`{ adapter: "sqlite" }\`?"`, - ); - }); - }); });