From caff4e6008360713eaa62f5b083dfe82509bbb3e Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Tue, 25 Feb 2025 19:46:53 -0800 Subject: [PATCH] fix(sql) fix SQL fragments handling and expand support for helpers (#17691) --- src/js/bun/sql.ts | 488 +++++++++++++++++++++++++--------------- test/js/sql/sql.test.ts | 106 +++++++++ 2 files changed, 417 insertions(+), 177 deletions(-) diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index 7f889e1229..e16faf010d 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -129,6 +129,307 @@ function getQueryHandle(query) { } return handle; } + +enum SQLCommand { + insert = 0, + update = 1, + updateSet = 2, + where = 3, + whereIn = 4, + none = -1, +} + +function commandToString(command: SQLCommand): string { + switch (command) { + case SQLCommand.insert: + return "INSERT"; + case SQLCommand.updateSet: + case SQLCommand.update: + return "UPDATE"; + case SQLCommand.whereIn: + case SQLCommand.where: + return "WHERE"; + default: + return ""; + } +} + +function detectCommand(query: string): SQLCommand { + const text = query.toLowerCase().trim(); + const text_len = text.length; + + let token = ""; + let command = SQLCommand.none; + let quoted = false; + for (let i = 0; i < text_len; i++) { + const char = text[i]; + switch (char) { + case " ": // Space + case "\n": // Line feed + case "\t": // Tab character + case "\r": // Carriage return + case "\f": // Form feed + case "\v": { + switch (token) { + case "insert": { + if (command === SQLCommand.none) { + return SQLCommand.insert; + } + return command; + } + case "update": { + if (command === SQLCommand.none) { + command = SQLCommand.update; + token = ""; + continue; // try to find SET + } + return command; + } + case "where": { + command = SQLCommand.where; + token = ""; + continue; // try to find IN + } + case "set": { + if (command === SQLCommand.update) { + command = SQLCommand.updateSet; + token = ""; + continue; // try to find WHERE + } + return command; + } + case "in": { + if (command === SQLCommand.where) { + return SQLCommand.whereIn; + } + return command; + } + default: { + token = ""; + continue; + } + } + } + default: { + // skip quoted commands + if (char === '"') { + quoted = !quoted; + continue; + } + if (!quoted) { + token += char; + } + } + } + } + if (token) { + switch (command) { + case SQLCommand.none: { + switch (token) { + case "insert": + return SQLCommand.insert; + case "update": + return SQLCommand.update; + case "where": + return SQLCommand.where; + default: + return SQLCommand.none; + } + } + case SQLCommand.update: { + if (token === "set") { + return SQLCommand.updateSet; + } + return SQLCommand.update; + } + case SQLCommand.where: { + if (token === "in") { + return SQLCommand.whereIn; + } + return SQLCommand.where; + } + } + } + + return command; +} + +function normalizeQuery(strings, values, binding_idx = 1) { + if (typeof strings === "string") { + // identifier or unsafe query + return [strings, values || []]; + } + if (!$isArray(strings)) { + // we should not hit this path + throw new SyntaxError("Invalid query: SQL Fragment cannot be executed or was misused"); + } + const str_len = strings.length; + if (str_len === 0) { + return ["", []]; + } + let binding_values: any[] = []; + let query = ""; + for (let i = 0; i < str_len; i++) { + const string = strings[i]; + + if (typeof string === "string") { + query += string; + if (values.length > i) { + const value = values[i]; + if (value instanceof Query) { + const [sub_query, sub_values] = normalizeQuery(value[_strings], value[_values], binding_idx); + query += sub_query; + for (let j = 0; j < sub_values.length; j++) { + binding_values.push(sub_values[j]); + } + binding_idx += sub_values.length; + } else if (value instanceof SQLArrayParameter) { + const command = detectCommand(query); + // only selectIn, insert, update, updateSet are allowed + if (command === SQLCommand.none || command === SQLCommand.where) { + throw new SyntaxError("Helper are only allowed for INSERT, UPDATE and WHERE IN commands"); + } + const { columns, value: items } = value as SQLArrayParameter; + const columnCount = columns.length; + if (columnCount === 0 && command !== SQLCommand.whereIn) { + throw new SyntaxError(`Cannot ${commandToString(command)} with no columns`); + } + const lastColumnIndex = columns.length - 1; + + if (command === SQLCommand.insert) { + // + // insert into users ${sql(users)} or insert into users ${sql(user)} + // + + query += "("; + for (let j = 0; j < columnCount; j++) { + query += escapeIdentifier(columns[j]); + if (j < lastColumnIndex) { + query += ", "; + } + } + query += ") VALUES"; + if ($isArray(items)) { + const itemsCount = items.length; + const lastItemIndex = itemsCount - 1; + for (let j = 0; j < itemsCount; j++) { + query += "("; + const item = items[j]; + for (let k = 0; k < columnCount; k++) { + const column = columns[k]; + const columnValue = item[column]; + query += `$${binding_idx++}${k < lastColumnIndex ? ", " : ""}`; + if (typeof columnValue === "undefined") { + binding_values.push(null); + } else { + binding_values.push(columnValue); + } + } + if (j < lastItemIndex) { + query += "),"; + } else { + query += ") "; // the user can add RETURNING * or RETURNING id + } + } + } else { + query += "("; + const item = items; + for (let j = 0; j < columnCount; j++) { + const column = columns[j]; + const columnValue = item[column]; + query += `$${binding_idx++}${j < lastColumnIndex ? ", " : ""}`; + if (typeof columnValue === "undefined") { + binding_values.push(null); + } else { + binding_values.push(columnValue); + } + } + query += ") "; // the user can add RETURNING * or RETURNING id + } + } else if (command === SQLCommand.whereIn) { + // SELECT * FROM users WHERE id IN (${sql([1, 2, 3])}) + if (!$isArray(items)) { + throw new SyntaxError("An array of values is required for WHERE IN helper"); + } + const itemsCount = items.length; + const lastItemIndex = itemsCount - 1; + query += "("; + for (let j = 0; j < itemsCount; j++) { + query += `$${binding_idx++}${j < lastItemIndex ? ", " : ""}`; + if (columnCount > 0) { + // we must use a key from a object + if (columnCount > 1) { + // we should not pass multiple columns here + throw new SyntaxError("Cannot use WHERE IN helper with multiple columns"); + } + // SELECT * FROM users WHERE id IN (${sql(users, "id")}) + const value = items[j]; + if (typeof value === "undefined") { + binding_values.push(null); + } else { + const value_from_key = value[columns[0]]; + + if (typeof value_from_key === "undefined") { + binding_values.push(null); + } else { + binding_values.push(value_from_key); + } + } + } else { + const value = items[j]; + if (typeof value === "undefined") { + binding_values.push(null); + } else { + binding_values.push(value); + } + } + } + query += ") "; // more conditions can be added after this + } else { + // UPDATE users SET ${sql({ name: "John", age: 31 })} WHERE id = 1 + let item; + if ($isArray(items)) { + if (items.length > 1) { + throw new SyntaxError("Cannot use array of objects for UPDATE"); + } + item = items[0]; + } else { + item = items; + } + // no need to include if is updateSet + if (command === SQLCommand.update) { + query += " SET "; + } + for (let i = 0; i < columnCount; i++) { + const column = columns[i]; + const columnValue = item[column]; + query += `${escapeIdentifier(column)} = $${binding_idx++}${i < lastColumnIndex ? ", " : ""}`; + if (typeof columnValue === "undefined") { + binding_values.push(null); + } else { + binding_values.push(columnValue); + } + } + query += " "; // the user can add where clause after this + } + } else { + //TODO: handle sql.array parameters + query += `$${binding_idx++} `; + if (typeof value === "undefined") { + binding_values.push(null); + } else { + binding_values.push(value); + } + } + } + } else { + throw new SyntaxError("Invalid query: SQL Fragment cannot be executed or was misused"); + } + } + + return [query, binding_values]; +} + class Query extends PublicPromise { [_resolve]; [_reject]; @@ -199,7 +500,6 @@ class Query extends PublicPromise { this.reject(err); } } - get active() { return (this[_queryStatus] & QueryStatus.active) != 0; } @@ -985,165 +1285,8 @@ async function createConnection(options, onConnected, onClose) { } } -var hasSQLArrayParameter = false; -function normalizeStrings(strings, values) { - hasSQLArrayParameter = false; - - if ($isArray(strings)) { - const count = strings.length; - - if (count === 0) { - return ""; - } - - var out = strings[0]; - - // For now, only support insert queries with array parameters - // - // insert into users ${sql(users)} - // - if (values.length > 0 && typeof values[0] === "object" && values[0] && values[0] instanceof SQLArrayParameter) { - if (values.length > 1) { - throw new Error("Cannot mix array parameters with other values"); - } - hasSQLArrayParameter = true; - const { columns, value } = values[0]; - const groupCount = value.length; - out += `values `; - - let columnIndex = 1; - let columnCount = columns.length; - let lastColumnIndex = columnCount - 1; - - for (var i = 0; i < groupCount; i++) { - out += i > 0 ? `, (` : `(`; - - for (var j = 0; j < lastColumnIndex; j++) { - out += `$${columnIndex++}, `; - } - - out += `$${columnIndex++})`; - } - - for (var i = 1; i < count; i++) { - out += strings[i]; - } - - return out; - } - - for (var i = 1; i < count; i++) { - // this space in between is important - out += `$${i} ${strings[i]}`; - } - return out; - } - return strings + ""; -} -function hasQuery(value: any) { - return value instanceof Query; -} -function handleQueryFragment(strings, values) { - let sqlString; - let final_values: Array; - let final_strings = []; - - if ($isArray(strings) && values.some(hasQuery)) { - // we need to handle fragments of queries - final_values = []; - let strings_idx = 0; - - for (let i = 0; i < values.length; i++) { - const value = values[i]; - if (value instanceof Query) { - let sub_strings = value[_strings]; - var is_unsafe = value[_flags] & SQLQueryFlags.unsafe; - if (typeof sub_strings === "string") { - if (final_strings.length === 0) { - // we are the first value - let final_string_value = strings[strings_idx] + sub_strings; - strings_idx++; - if (strings_idx < strings.length) { - final_string_value += strings[strings_idx]; - strings_idx++; - } - //@ts-ignore - final_strings.push(final_string_value); - } else { - // merge the strings with current string - const current_idx = final_strings.length - 1; - final_strings[current_idx] = final_strings[current_idx] + sub_strings; - if (strings_idx < strings.length) { - final_strings[current_idx] += strings[strings_idx]; - strings_idx++; - } - } - // in this case we dont have values to merge - } else { - // complex fragment, we need to merge values - let sub_values = value[_values]; - - if (sub_values.some(hasQuery)) { - const { final_strings: sub_final_strings, final_values: sub_final_values } = handleQueryFragment( - sub_strings, - sub_values, - ); - sub_strings = sub_final_strings; - sub_values = sub_final_values; - } - - if (final_strings.length > 0) { - // complex not the first - const current_idx = final_strings.length - 1; - final_strings[current_idx] = final_strings[current_idx] + sub_strings[0]; - - if (sub_strings.length > 1) { - final_strings.push(...sub_strings.slice(1)); - } - final_values.push(...sub_values); - } else { - // complex the first - final_strings.push(strings[strings_idx] + sub_strings[0]); - strings_idx += 1; - final_values.push(...sub_values); - if (sub_strings.length > 1) { - final_strings.push(...sub_strings.slice(1)); - } - } - } - } else { - // for each value we have 2 strings - //@ts-ignore - final_strings.push(strings[strings_idx]); - strings_idx += 1; - if (strings_idx + 1 < strings.length) { - //@ts-ignore - final_strings.push(strings[strings_idx + 1]); - strings_idx += 1; - } - - final_values.push(value); - } - } - } else { - final_strings = strings; - final_values = values; - } - - return { final_strings, final_values }; -} function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint, simple) { - let columns; - - let { final_strings, final_values } = handleQueryFragment(strings, values); - - const sqlString = normalizeStrings(final_strings, final_values); - if (hasSQLArrayParameter) { - hasSQLArrayParameter = false; - const v = final_values[0]; - columns = v.columns; - final_values = v.value; - } + const [sqlString, final_values] = normalizeQuery(strings, values); if (!allowUnsafeTransaction) { if (poolSize !== 1) { const upperCaseSqlString = sqlString.toUpperCase().trim(); @@ -1152,7 +1295,8 @@ function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint } } } - return createQuery(sqlString, final_values, new SQLResultArray(), columns, !!bigint, !!simple); + + return createQuery(sqlString, final_values, new SQLResultArray(), undefined, !!bigint, !!simple); } class SQLArrayParameter { @@ -1179,7 +1323,7 @@ class SQLArrayParameter { } } - throw new Error(`Invalid key: ${key}`); + throw new Error(`Keys must be strings or numbers: ${key}`); } } @@ -1576,7 +1720,8 @@ function SQL(o, e = {}) { return Promise.reject(connectionClosedError()); } if ($isArray(strings)) { - if (strings[0] && typeof strings[0] === "object") { + // detect if is tagged template + if (!$isArray(strings.raw)) { return new SQLArrayParameter(strings, values); } } else if ( @@ -1908,7 +2053,8 @@ function SQL(o, e = {}) { return Promise.reject(connectionClosedError()); } if ($isArray(strings)) { - if (strings[0] && typeof strings[0] === "object") { + // detect if is tagged template + if (!$isArray(strings.raw)) { return new SQLArrayParameter(strings, values); } } else if ( @@ -2140,21 +2286,9 @@ function SQL(o, e = {}) { } } function sql(strings, ...values) { - /** - * const users = [ - * { - * name: "Alice", - * age: 25, - * }, - * { - * name: "Bob", - * age: 30, - * }, - * ] - * sql`insert into users ${sql(users)}` - */ if ($isArray(strings)) { - if (strings[0] && typeof strings[0] === "object") { + // detect if is tagged template + if (!$isArray(strings.raw)) { return new SQLArrayParameter(strings, values); } } else if (typeof strings === "object" && !(strings instanceof Query) && !(strings instanceof SQLArrayParameter)) { diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index b0c48a8191..8d741642fb 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -11014,4 +11014,110 @@ CREATE TABLE ${table_name} ( expect(results[1].price).toBe("0.0123"); }); }); + + describe("helpers", () => { + test("insert helper", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + const result = await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })} RETURNING *`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("John"); + expect(result[0].age).toBe(30); + }); + test("update helper", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`; + const result = + await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id = 1 RETURNING *`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + }); + + test("update helper with IN", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`; + + const result = + await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id IN ${sql([1, 2])} RETURNING *`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + expect(result[1].id).toBe(2); + expect(result[1].name).toBe("Mary"); + expect(result[1].age).toBe(18); + }); + + test("update helper with IN and column name", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`; + + const result = + await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id IN ${sql(users, "id")} RETURNING *`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + expect(result[1].id).toBe(2); + expect(result[1].name).toBe("Mary"); + expect(result[1].age).toBe(18); + }); + + test("update multiple values no helper", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`; + await sql`UPDATE ${sql(random_name)} SET ${sql("name")} = ${"Mary"}, ${sql("age")} = ${18} WHERE id = 1`; + const result = await sql`SELECT * FROM ${sql(random_name)} WHERE id = 1`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("Mary"); + expect(result[0].age).toBe(18); + }); + + test("SELECT with IN and NOT IN", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`; + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`; + + const result = + await sql`SELECT * FROM ${sql(random_name)} WHERE id IN ${sql(users, "id")} and id NOT IN ${sql([3, 4, 5])}`; + expect(result[0].id).toBe(1); + expect(result[0].name).toBe("John"); + expect(result[0].age).toBe(30); + expect(result[1].id).toBe(2); + expect(result[1].name).toBe("Jane"); + expect(result[1].age).toBe(25); + }); + + test("syntax error", async () => { + await using sql = postgres({ ...options, max: 1 }); + const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", ""); + const users = [ + { id: 1, name: "John", age: 30 }, + { id: 2, name: "Jane", age: 25 }, + ]; + + expect(() => sql`DELETE FROM ${sql(random_name)} ${sql(users, "id")}`.execute()).toThrow(SyntaxError); + }); + }); }