diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index 2f98380180..75b937c2eb 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -1252,6 +1252,7 @@ async function createConnection(options, onConnected, onClose) { maxLifetime = 0, prepare = true, } = options; + let password = options.password; try { if (typeof password === "function") { @@ -1390,6 +1391,8 @@ function loadOptions(o) { url = new URL(o); } o ||= {}; + query = ""; + if (url) { ({ hostname, port, username, password, adapter } = o); // object overrides url @@ -1404,12 +1407,15 @@ function loadOptions(o) { } const queryObject = url.searchParams.toJSON(); - query = ""; for (const key in queryObject) { if (key.toLowerCase() === "sslmode") { sslMode = normalizeSSLMode(queryObject[key]); } else { - query += `${encodeURIComponent(key)}=${encodeURIComponent(queryObject[key])} `; + // this is valid for postgres for other databases it might not be valid + // check adapter then implement for other databases + // encode string with \0 as finalizer + // must be key\0value\0 + query += `${key}\0${queryObject[key]}\0`; } } query = query.trim(); @@ -1419,7 +1425,14 @@ function loadOptions(o) { username ||= o.username || o.user || env.PGUSERNAME || env.PGUSER || env.USER || env.USERNAME || "postgres"; database ||= o.database || o.db || decodeIfValid((url?.pathname ?? "").slice(1)) || env.PGDATABASE || username; password ||= o.password || o.pass || env.PGPASSWORD || ""; - + const connection = o.connection; + if (connection && $isObject(connection)) { + for (const key in connection) { + if (connection[key] !== undefined) { + query += `${key}\0${connection[key]}\0`; + } + } + } tls ||= o.tls || o.ssl; adapter ||= o.adapter || "postgres"; max = o.max; diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig index f20cd461dc..c4347c9015 100644 --- a/src/sql/postgres/postgres_protocol.zig +++ b/src/sql/postgres/postgres_protocol.zig @@ -1330,8 +1330,7 @@ pub const StartupMessage = struct { const user = this.user.slice(); const database = this.database.slice(); const options = this.options.slice(); - - const count: usize = @sizeOf((int4)) + @sizeOf((int4)) + zFieldCount("user", user) + zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + zFieldCount("", options) + 1; + const count: usize = @sizeOf((int4)) + @sizeOf((int4)) + zFieldCount("user", user) + zFieldCount("database", database) + zFieldCount("client_encoding", "UTF8") + options.len + 1; const header = toBytes(Int32(@as(u32, @truncate(count)))); try writer.write(&header); @@ -1349,13 +1348,11 @@ pub const StartupMessage = struct { } else { try writer.string(database); } - try writer.string("client_encoding"); try writer.string("UTF8"); - - if (options.len > 0) - try writer.string(options); - + if (options.len > 0) { + try writer.write(options); + } try writer.write(&[_]u8{0}); } diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index 8d741642fb..c022e19246 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -11120,4 +11120,19 @@ CREATE TABLE ${table_name} ( expect(() => sql`DELETE FROM ${sql(random_name)} ${sql(users, "id")}`.execute()).toThrow(SyntaxError); }); }); + + describe("connection options", () => { + test("connection", async () => { + await using sql = postgres({ ...options, max: 1, connection: { search_path: "information_schema" } }); + const [item] = await sql`SELECT COUNT(*)::INT FROM columns LIMIT 1`.values(); + expect(item[0]).toBeGreaterThan(0); + }); + test("query string", async () => { + await using sql = postgres(process.env.DATABASE_URL + "?search_path=information_schema", { + max: 1, + }); + const [item] = await sql`SELECT COUNT(*)::INT FROM columns LIMIT 1`.values(); + expect(item[0]).toBeGreaterThan(0); + }); + }); }