fix(sql) fix SQL fragments handling and expand support for helpers (#17691)

This commit is contained in:
Ciro Spaciari
2025-02-25 19:46:53 -08:00
committed by GitHub
parent 1322adbb16
commit caff4e6008
2 changed files with 417 additions and 177 deletions

View File

@@ -129,6 +129,307 @@ function getQueryHandle(query) {
}
return handle;
}
enum SQLCommand {
insert = 0,
update = 1,
updateSet = 2,
where = 3,
whereIn = 4,
none = -1,
}
function commandToString(command: SQLCommand): string {
switch (command) {
case SQLCommand.insert:
return "INSERT";
case SQLCommand.updateSet:
case SQLCommand.update:
return "UPDATE";
case SQLCommand.whereIn:
case SQLCommand.where:
return "WHERE";
default:
return "";
}
}
function detectCommand(query: string): SQLCommand {
const text = query.toLowerCase().trim();
const text_len = text.length;
let token = "";
let command = SQLCommand.none;
let quoted = false;
for (let i = 0; i < text_len; i++) {
const char = text[i];
switch (char) {
case " ": // Space
case "\n": // Line feed
case "\t": // Tab character
case "\r": // Carriage return
case "\f": // Form feed
case "\v": {
switch (token) {
case "insert": {
if (command === SQLCommand.none) {
return SQLCommand.insert;
}
return command;
}
case "update": {
if (command === SQLCommand.none) {
command = SQLCommand.update;
token = "";
continue; // try to find SET
}
return command;
}
case "where": {
command = SQLCommand.where;
token = "";
continue; // try to find IN
}
case "set": {
if (command === SQLCommand.update) {
command = SQLCommand.updateSet;
token = "";
continue; // try to find WHERE
}
return command;
}
case "in": {
if (command === SQLCommand.where) {
return SQLCommand.whereIn;
}
return command;
}
default: {
token = "";
continue;
}
}
}
default: {
// skip quoted commands
if (char === '"') {
quoted = !quoted;
continue;
}
if (!quoted) {
token += char;
}
}
}
}
if (token) {
switch (command) {
case SQLCommand.none: {
switch (token) {
case "insert":
return SQLCommand.insert;
case "update":
return SQLCommand.update;
case "where":
return SQLCommand.where;
default:
return SQLCommand.none;
}
}
case SQLCommand.update: {
if (token === "set") {
return SQLCommand.updateSet;
}
return SQLCommand.update;
}
case SQLCommand.where: {
if (token === "in") {
return SQLCommand.whereIn;
}
return SQLCommand.where;
}
}
}
return command;
}
function normalizeQuery(strings, values, binding_idx = 1) {
if (typeof strings === "string") {
// identifier or unsafe query
return [strings, values || []];
}
if (!$isArray(strings)) {
// we should not hit this path
throw new SyntaxError("Invalid query: SQL Fragment cannot be executed or was misused");
}
const str_len = strings.length;
if (str_len === 0) {
return ["", []];
}
let binding_values: any[] = [];
let query = "";
for (let i = 0; i < str_len; i++) {
const string = strings[i];
if (typeof string === "string") {
query += string;
if (values.length > i) {
const value = values[i];
if (value instanceof Query) {
const [sub_query, sub_values] = normalizeQuery(value[_strings], value[_values], binding_idx);
query += sub_query;
for (let j = 0; j < sub_values.length; j++) {
binding_values.push(sub_values[j]);
}
binding_idx += sub_values.length;
} else if (value instanceof SQLArrayParameter) {
const command = detectCommand(query);
// only selectIn, insert, update, updateSet are allowed
if (command === SQLCommand.none || command === SQLCommand.where) {
throw new SyntaxError("Helper are only allowed for INSERT, UPDATE and WHERE IN commands");
}
const { columns, value: items } = value as SQLArrayParameter;
const columnCount = columns.length;
if (columnCount === 0 && command !== SQLCommand.whereIn) {
throw new SyntaxError(`Cannot ${commandToString(command)} with no columns`);
}
const lastColumnIndex = columns.length - 1;
if (command === SQLCommand.insert) {
//
// insert into users ${sql(users)} or insert into users ${sql(user)}
//
query += "(";
for (let j = 0; j < columnCount; j++) {
query += escapeIdentifier(columns[j]);
if (j < lastColumnIndex) {
query += ", ";
}
}
query += ") VALUES";
if ($isArray(items)) {
const itemsCount = items.length;
const lastItemIndex = itemsCount - 1;
for (let j = 0; j < itemsCount; j++) {
query += "(";
const item = items[j];
for (let k = 0; k < columnCount; k++) {
const column = columns[k];
const columnValue = item[column];
query += `$${binding_idx++}${k < lastColumnIndex ? ", " : ""}`;
if (typeof columnValue === "undefined") {
binding_values.push(null);
} else {
binding_values.push(columnValue);
}
}
if (j < lastItemIndex) {
query += "),";
} else {
query += ") "; // the user can add RETURNING * or RETURNING id
}
}
} else {
query += "(";
const item = items;
for (let j = 0; j < columnCount; j++) {
const column = columns[j];
const columnValue = item[column];
query += `$${binding_idx++}${j < lastColumnIndex ? ", " : ""}`;
if (typeof columnValue === "undefined") {
binding_values.push(null);
} else {
binding_values.push(columnValue);
}
}
query += ") "; // the user can add RETURNING * or RETURNING id
}
} else if (command === SQLCommand.whereIn) {
// SELECT * FROM users WHERE id IN (${sql([1, 2, 3])})
if (!$isArray(items)) {
throw new SyntaxError("An array of values is required for WHERE IN helper");
}
const itemsCount = items.length;
const lastItemIndex = itemsCount - 1;
query += "(";
for (let j = 0; j < itemsCount; j++) {
query += `$${binding_idx++}${j < lastItemIndex ? ", " : ""}`;
if (columnCount > 0) {
// we must use a key from a object
if (columnCount > 1) {
// we should not pass multiple columns here
throw new SyntaxError("Cannot use WHERE IN helper with multiple columns");
}
// SELECT * FROM users WHERE id IN (${sql(users, "id")})
const value = items[j];
if (typeof value === "undefined") {
binding_values.push(null);
} else {
const value_from_key = value[columns[0]];
if (typeof value_from_key === "undefined") {
binding_values.push(null);
} else {
binding_values.push(value_from_key);
}
}
} else {
const value = items[j];
if (typeof value === "undefined") {
binding_values.push(null);
} else {
binding_values.push(value);
}
}
}
query += ") "; // more conditions can be added after this
} else {
// UPDATE users SET ${sql({ name: "John", age: 31 })} WHERE id = 1
let item;
if ($isArray(items)) {
if (items.length > 1) {
throw new SyntaxError("Cannot use array of objects for UPDATE");
}
item = items[0];
} else {
item = items;
}
// no need to include if is updateSet
if (command === SQLCommand.update) {
query += " SET ";
}
for (let i = 0; i < columnCount; i++) {
const column = columns[i];
const columnValue = item[column];
query += `${escapeIdentifier(column)} = $${binding_idx++}${i < lastColumnIndex ? ", " : ""}`;
if (typeof columnValue === "undefined") {
binding_values.push(null);
} else {
binding_values.push(columnValue);
}
}
query += " "; // the user can add where clause after this
}
} else {
//TODO: handle sql.array parameters
query += `$${binding_idx++} `;
if (typeof value === "undefined") {
binding_values.push(null);
} else {
binding_values.push(value);
}
}
}
} else {
throw new SyntaxError("Invalid query: SQL Fragment cannot be executed or was misused");
}
}
return [query, binding_values];
}
class Query extends PublicPromise {
[_resolve];
[_reject];
@@ -199,7 +500,6 @@ class Query extends PublicPromise {
this.reject(err);
}
}
get active() {
return (this[_queryStatus] & QueryStatus.active) != 0;
}
@@ -985,165 +1285,8 @@ async function createConnection(options, onConnected, onClose) {
}
}
var hasSQLArrayParameter = false;
function normalizeStrings(strings, values) {
hasSQLArrayParameter = false;
if ($isArray(strings)) {
const count = strings.length;
if (count === 0) {
return "";
}
var out = strings[0];
// For now, only support insert queries with array parameters
//
// insert into users ${sql(users)}
//
if (values.length > 0 && typeof values[0] === "object" && values[0] && values[0] instanceof SQLArrayParameter) {
if (values.length > 1) {
throw new Error("Cannot mix array parameters with other values");
}
hasSQLArrayParameter = true;
const { columns, value } = values[0];
const groupCount = value.length;
out += `values `;
let columnIndex = 1;
let columnCount = columns.length;
let lastColumnIndex = columnCount - 1;
for (var i = 0; i < groupCount; i++) {
out += i > 0 ? `, (` : `(`;
for (var j = 0; j < lastColumnIndex; j++) {
out += `$${columnIndex++}, `;
}
out += `$${columnIndex++})`;
}
for (var i = 1; i < count; i++) {
out += strings[i];
}
return out;
}
for (var i = 1; i < count; i++) {
// this space in between is important
out += `$${i} ${strings[i]}`;
}
return out;
}
return strings + "";
}
function hasQuery(value: any) {
return value instanceof Query;
}
function handleQueryFragment(strings, values) {
let sqlString;
let final_values: Array<any>;
let final_strings = [];
if ($isArray(strings) && values.some(hasQuery)) {
// we need to handle fragments of queries
final_values = [];
let strings_idx = 0;
for (let i = 0; i < values.length; i++) {
const value = values[i];
if (value instanceof Query) {
let sub_strings = value[_strings];
var is_unsafe = value[_flags] & SQLQueryFlags.unsafe;
if (typeof sub_strings === "string") {
if (final_strings.length === 0) {
// we are the first value
let final_string_value = strings[strings_idx] + sub_strings;
strings_idx++;
if (strings_idx < strings.length) {
final_string_value += strings[strings_idx];
strings_idx++;
}
//@ts-ignore
final_strings.push(final_string_value);
} else {
// merge the strings with current string
const current_idx = final_strings.length - 1;
final_strings[current_idx] = final_strings[current_idx] + sub_strings;
if (strings_idx < strings.length) {
final_strings[current_idx] += strings[strings_idx];
strings_idx++;
}
}
// in this case we dont have values to merge
} else {
// complex fragment, we need to merge values
let sub_values = value[_values];
if (sub_values.some(hasQuery)) {
const { final_strings: sub_final_strings, final_values: sub_final_values } = handleQueryFragment(
sub_strings,
sub_values,
);
sub_strings = sub_final_strings;
sub_values = sub_final_values;
}
if (final_strings.length > 0) {
// complex not the first
const current_idx = final_strings.length - 1;
final_strings[current_idx] = final_strings[current_idx] + sub_strings[0];
if (sub_strings.length > 1) {
final_strings.push(...sub_strings.slice(1));
}
final_values.push(...sub_values);
} else {
// complex the first
final_strings.push(strings[strings_idx] + sub_strings[0]);
strings_idx += 1;
final_values.push(...sub_values);
if (sub_strings.length > 1) {
final_strings.push(...sub_strings.slice(1));
}
}
}
} else {
// for each value we have 2 strings
//@ts-ignore
final_strings.push(strings[strings_idx]);
strings_idx += 1;
if (strings_idx + 1 < strings.length) {
//@ts-ignore
final_strings.push(strings[strings_idx + 1]);
strings_idx += 1;
}
final_values.push(value);
}
}
} else {
final_strings = strings;
final_values = values;
}
return { final_strings, final_values };
}
function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint, simple) {
let columns;
let { final_strings, final_values } = handleQueryFragment(strings, values);
const sqlString = normalizeStrings(final_strings, final_values);
if (hasSQLArrayParameter) {
hasSQLArrayParameter = false;
const v = final_values[0];
columns = v.columns;
final_values = v.value;
}
const [sqlString, final_values] = normalizeQuery(strings, values);
if (!allowUnsafeTransaction) {
if (poolSize !== 1) {
const upperCaseSqlString = sqlString.toUpperCase().trim();
@@ -1152,7 +1295,8 @@ function doCreateQuery(strings, values, allowUnsafeTransaction, poolSize, bigint
}
}
}
return createQuery(sqlString, final_values, new SQLResultArray(), columns, !!bigint, !!simple);
return createQuery(sqlString, final_values, new SQLResultArray(), undefined, !!bigint, !!simple);
}
class SQLArrayParameter {
@@ -1179,7 +1323,7 @@ class SQLArrayParameter {
}
}
throw new Error(`Invalid key: ${key}`);
throw new Error(`Keys must be strings or numbers: ${key}`);
}
}
@@ -1576,7 +1720,8 @@ function SQL(o, e = {}) {
return Promise.reject(connectionClosedError());
}
if ($isArray(strings)) {
if (strings[0] && typeof strings[0] === "object") {
// detect if is tagged template
if (!$isArray(strings.raw)) {
return new SQLArrayParameter(strings, values);
}
} else if (
@@ -1908,7 +2053,8 @@ function SQL(o, e = {}) {
return Promise.reject(connectionClosedError());
}
if ($isArray(strings)) {
if (strings[0] && typeof strings[0] === "object") {
// detect if is tagged template
if (!$isArray(strings.raw)) {
return new SQLArrayParameter(strings, values);
}
} else if (
@@ -2140,21 +2286,9 @@ function SQL(o, e = {}) {
}
}
function sql(strings, ...values) {
/**
* const users = [
* {
* name: "Alice",
* age: 25,
* },
* {
* name: "Bob",
* age: 30,
* },
* ]
* sql`insert into users ${sql(users)}`
*/
if ($isArray(strings)) {
if (strings[0] && typeof strings[0] === "object") {
// detect if is tagged template
if (!$isArray(strings.raw)) {
return new SQLArrayParameter(strings, values);
}
} else if (typeof strings === "object" && !(strings instanceof Query) && !(strings instanceof SQLArrayParameter)) {

View File

@@ -11014,4 +11014,110 @@ CREATE TABLE ${table_name} (
expect(results[1].price).toBe("0.0123");
});
});
describe("helpers", () => {
test("insert helper", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`;
const result = await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })} RETURNING *`;
expect(result[0].id).toBe(1);
expect(result[0].name).toBe("John");
expect(result[0].age).toBe(30);
});
test("update helper", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`;
await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`;
const result =
await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id = 1 RETURNING *`;
expect(result[0].id).toBe(1);
expect(result[0].name).toBe("Mary");
expect(result[0].age).toBe(18);
});
test("update helper with IN", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`;
const users = [
{ id: 1, name: "John", age: 30 },
{ id: 2, name: "Jane", age: 25 },
];
await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`;
const result =
await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id IN ${sql([1, 2])} RETURNING *`;
expect(result[0].id).toBe(1);
expect(result[0].name).toBe("Mary");
expect(result[0].age).toBe(18);
expect(result[1].id).toBe(2);
expect(result[1].name).toBe("Mary");
expect(result[1].age).toBe(18);
});
test("update helper with IN and column name", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`;
const users = [
{ id: 1, name: "John", age: 30 },
{ id: 2, name: "Jane", age: 25 },
];
await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`;
const result =
await sql`UPDATE ${sql(random_name)} SET ${sql({ name: "Mary", age: 18 })} WHERE id IN ${sql(users, "id")} RETURNING *`;
expect(result[0].id).toBe(1);
expect(result[0].name).toBe("Mary");
expect(result[0].age).toBe(18);
expect(result[1].id).toBe(2);
expect(result[1].name).toBe("Mary");
expect(result[1].age).toBe(18);
});
test("update multiple values no helper", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`;
await sql`INSERT INTO ${sql(random_name)} ${sql({ id: 1, name: "John", age: 30 })}`;
await sql`UPDATE ${sql(random_name)} SET ${sql("name")} = ${"Mary"}, ${sql("age")} = ${18} WHERE id = 1`;
const result = await sql`SELECT * FROM ${sql(random_name)} WHERE id = 1`;
expect(result[0].id).toBe(1);
expect(result[0].name).toBe("Mary");
expect(result[0].age).toBe(18);
});
test("SELECT with IN and NOT IN", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
await sql`CREATE TEMPORARY TABLE ${sql(random_name)} (id int, name text, age int)`;
const users = [
{ id: 1, name: "John", age: 30 },
{ id: 2, name: "Jane", age: 25 },
];
await sql`INSERT INTO ${sql(random_name)} ${sql(users)}`;
const result =
await sql`SELECT * FROM ${sql(random_name)} WHERE id IN ${sql(users, "id")} and id NOT IN ${sql([3, 4, 5])}`;
expect(result[0].id).toBe(1);
expect(result[0].name).toBe("John");
expect(result[0].age).toBe(30);
expect(result[1].id).toBe(2);
expect(result[1].name).toBe("Jane");
expect(result[1].age).toBe(25);
});
test("syntax error", async () => {
await using sql = postgres({ ...options, max: 1 });
const random_name = "test_" + randomUUIDv7("hex").replaceAll("-", "");
const users = [
{ id: 1, name: "John", age: 30 },
{ id: 2, name: "Jane", age: 25 },
];
expect(() => sql`DELETE FROM ${sql(random_name)} ${sql(users, "id")}`.execute()).toThrow(SyntaxError);
});
});
}