mirror of
https://github.com/oven-sh/bun
synced 2026-03-12 18:27:35 +01:00
Compare commits
3 Commits
claude/bui
...
claude/sql
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
212d96e165 | ||
|
|
685020fcee | ||
|
|
40072ae8e3 |
134
packages/bun-types/sqlite.d.ts
vendored
134
packages/bun-types/sqlite.d.ts
vendored
@@ -553,6 +553,140 @@ declare module "bun:sqlite" {
|
||||
options?: { readonly?: boolean; strict?: boolean; safeIntegers?: boolean },
|
||||
): Database;
|
||||
|
||||
/**
|
||||
* Register a user-defined scalar function.
|
||||
*
|
||||
* Matches the `better-sqlite3` API for compatibility.
|
||||
*
|
||||
* @param name The name of the function to register in SQL
|
||||
* @param callback The JavaScript function to call
|
||||
* @returns The database instance (for chaining)
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const db = new Database(":memory:");
|
||||
* db.function("add2", (a, b) => a + b);
|
||||
* db.prepare("SELECT add2(1, 2)").pluck().get(); // => 3
|
||||
* ```
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* // With options
|
||||
* db.function("upper", { deterministic: true }, (s) => s.toUpperCase());
|
||||
* ```
|
||||
*/
|
||||
function(name: string, callback: (...args: any[]) => any): this;
|
||||
function(
|
||||
name: string,
|
||||
options: {
|
||||
/**
|
||||
* If `true`, the function accepts any number of arguments.
|
||||
* Otherwise, the number of arguments is determined by `callback.length`.
|
||||
*/
|
||||
varargs?: boolean;
|
||||
/**
|
||||
* If `true`, SQLite may cache results for the same inputs (optimization).
|
||||
*/
|
||||
deterministic?: boolean;
|
||||
/**
|
||||
* If `true`, the function can only be called from top-level SQL,
|
||||
* not from VIEWs, TRIGGERs, CHECK constraints, or DEFAULT clauses.
|
||||
*/
|
||||
directOnly?: boolean;
|
||||
/**
|
||||
* If `true`, integer arguments are passed as `BigInt` instead of `number`.
|
||||
*/
|
||||
safeIntegers?: boolean;
|
||||
},
|
||||
callback: (...args: any[]) => any,
|
||||
): this;
|
||||
|
||||
/**
|
||||
* Register a user-defined aggregate function.
|
||||
*
|
||||
* Matches the `better-sqlite3` API for compatibility.
|
||||
*
|
||||
* @param name The name of the aggregate function to register in SQL
|
||||
* @param options Configuration for the aggregate function
|
||||
* @returns The database instance (for chaining)
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const db = new Database(":memory:");
|
||||
*
|
||||
* db.aggregate("addAll", {
|
||||
* start: 0,
|
||||
* step: (total, nextValue) => total + nextValue,
|
||||
* });
|
||||
*
|
||||
* db.exec("CREATE TABLE expenses (dollars REAL)");
|
||||
* db.exec("INSERT INTO expenses VALUES (10), (20), (30)");
|
||||
* db.prepare("SELECT addAll(dollars) FROM expenses").pluck().get(); // => 60
|
||||
* ```
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* // With result transformation
|
||||
* db.aggregate("getAverage", {
|
||||
* start: () => [],
|
||||
* step: (array, nextValue) => { array.push(nextValue); },
|
||||
* result: (array) => array.reduce((a, b) => a + b, 0) / array.length,
|
||||
* });
|
||||
* ```
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* // As a window function (with inverse)
|
||||
* db.aggregate("addAll", {
|
||||
* start: 0,
|
||||
* step: (total, nextValue) => total + nextValue,
|
||||
* inverse: (total, droppedValue) => total - droppedValue,
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
aggregate<T = any>(
|
||||
name: string,
|
||||
options: {
|
||||
/**
|
||||
* The initial accumulator value. If a function, it's called to
|
||||
* produce a fresh initial value for each aggregate invocation.
|
||||
*/
|
||||
start?: T | (() => T);
|
||||
/**
|
||||
* Called once per row. Return the new accumulator value.
|
||||
* If `undefined` is returned, the accumulator is not replaced
|
||||
* (useful for in-place mutations like pushing to an array).
|
||||
*/
|
||||
step: (total: T, ...values: any[]) => T | void;
|
||||
/**
|
||||
* Optional. Transforms the final accumulator value before
|
||||
* returning it as the SQL result.
|
||||
*/
|
||||
result?: (total: T) => any;
|
||||
/**
|
||||
* Optional. If provided, the aggregate can be used as a window function.
|
||||
* Called when a row is removed from the window.
|
||||
*/
|
||||
inverse?: (total: T, ...dropped: any[]) => T | void;
|
||||
/**
|
||||
* If `true`, the function accepts any number of arguments per row.
|
||||
*/
|
||||
varargs?: boolean;
|
||||
/**
|
||||
* If `true`, SQLite may cache results for the same inputs.
|
||||
*/
|
||||
deterministic?: boolean;
|
||||
/**
|
||||
* If `true`, the function can only be called from top-level SQL.
|
||||
*/
|
||||
directOnly?: boolean;
|
||||
/**
|
||||
* If `true`, integer arguments are passed as `BigInt` instead of `number`.
|
||||
*/
|
||||
safeIntegers?: boolean;
|
||||
},
|
||||
): this;
|
||||
|
||||
/**
|
||||
* See `sqlite3_file_control` for more information.
|
||||
* @link https://www.sqlite.org/c3ref/file_control.html
|
||||
|
||||
@@ -266,10 +266,473 @@ extern "C" void Bun__closeAllSQLiteDatabasesForTermination()
|
||||
namespace WebCore {
|
||||
using namespace JSC;
|
||||
|
||||
/* ******************************************************************************** */
|
||||
// User-Defined Functions support
|
||||
/* ******************************************************************************** */
|
||||
|
||||
// GC-managed object that holds all JS references for a user-defined function.
|
||||
// WriteBarrier<> fields are traced by visitChildren, so no Strong<> handles
|
||||
// are needed for individual callbacks or accumulators.
|
||||
class JSUserDefinedFunction final : public JSDestructibleObject {
|
||||
public:
|
||||
using Base = JSDestructibleObject;
|
||||
static constexpr DestructionMode needsDestruction = NeedsDestruction;
|
||||
|
||||
WriteBarrier<JSObject> m_scalarFn;
|
||||
WriteBarrier<JSObject> m_stepFn;
|
||||
WriteBarrier<JSObject> m_resultFn;
|
||||
WriteBarrier<JSObject> m_inverseFn;
|
||||
WriteBarrier<Unknown> m_startValue;
|
||||
bool m_safeIntegers = false;
|
||||
bool m_startIsFunction = false;
|
||||
|
||||
// Per-group accumulator management for aggregates.
|
||||
// Each slot holds one group's current accumulator value.
|
||||
Vector<WriteBarrier<Unknown>> m_accumulators;
|
||||
Vector<size_t> m_freeSlots;
|
||||
|
||||
size_t allocAccumulator(VM& vm, JSValue initial)
|
||||
{
|
||||
size_t idx;
|
||||
if (!m_freeSlots.isEmpty()) {
|
||||
idx = m_freeSlots.takeLast();
|
||||
m_accumulators[idx].set(vm, this, initial);
|
||||
} else {
|
||||
idx = m_accumulators.size();
|
||||
m_accumulators.append(WriteBarrier<Unknown>());
|
||||
m_accumulators.last().set(vm, this, initial);
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
void freeAccumulator(size_t idx)
|
||||
{
|
||||
m_accumulators[idx].clear();
|
||||
m_freeSlots.append(idx);
|
||||
}
|
||||
|
||||
static JSUserDefinedFunction* create(VM& vm, JSGlobalObject* globalObject, Structure* structure)
|
||||
{
|
||||
auto* obj = new (NotNull, allocateCell<JSUserDefinedFunction>(vm)) JSUserDefinedFunction(vm, structure);
|
||||
obj->finishCreation(vm);
|
||||
return obj;
|
||||
}
|
||||
|
||||
static Structure* createStructure(VM& vm, JSGlobalObject* globalObject, JSValue prototype)
|
||||
{
|
||||
return Structure::create(vm, globalObject, prototype, TypeInfo(ObjectType, StructureFlags), info());
|
||||
}
|
||||
|
||||
DECLARE_INFO;
|
||||
DECLARE_VISIT_CHILDREN;
|
||||
|
||||
template<typename, SubspaceAccess mode>
|
||||
static GCClient::IsoSubspace* subspaceFor(VM& vm)
|
||||
{
|
||||
return subspaceForImpl<JSUserDefinedFunction, UseCustomHeapCellType::No>(
|
||||
vm,
|
||||
[](auto& spaces) { return spaces.m_clientSubspaceForJSUserDefinedFunction.get(); },
|
||||
[](auto& spaces, auto&& space) { spaces.m_clientSubspaceForJSUserDefinedFunction = std::forward<decltype(space)>(space); },
|
||||
[](auto& spaces) { return spaces.m_subspaceForJSUserDefinedFunction.get(); },
|
||||
[](auto& spaces, auto&& space) { spaces.m_subspaceForJSUserDefinedFunction = std::forward<decltype(space)>(space); });
|
||||
}
|
||||
|
||||
static void destroy(JSCell* cell)
|
||||
{
|
||||
static_cast<JSUserDefinedFunction*>(cell)->~JSUserDefinedFunction();
|
||||
}
|
||||
|
||||
private:
|
||||
JSUserDefinedFunction(VM& vm, Structure* structure)
|
||||
: Base(vm, structure)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
const ClassInfo JSUserDefinedFunction::s_info = { "UserDefinedFunction"_s, &Base::s_info, nullptr, nullptr, CREATE_METHOD_TABLE(JSUserDefinedFunction) };
|
||||
|
||||
template<typename Visitor>
|
||||
void JSUserDefinedFunction::visitChildrenImpl(JSCell* cell, Visitor& visitor)
|
||||
{
|
||||
auto* thisObject = jsCast<JSUserDefinedFunction*>(cell);
|
||||
ASSERT_GC_OBJECT_INHERITS(thisObject, info());
|
||||
Base::visitChildren(thisObject, visitor);
|
||||
|
||||
visitor.append(thisObject->m_scalarFn);
|
||||
visitor.append(thisObject->m_stepFn);
|
||||
visitor.append(thisObject->m_resultFn);
|
||||
visitor.append(thisObject->m_inverseFn);
|
||||
visitor.append(thisObject->m_startValue);
|
||||
|
||||
for (auto& acc : thisObject->m_accumulators)
|
||||
visitor.append(acc);
|
||||
}
|
||||
|
||||
DEFINE_VISIT_CHILDREN(JSUserDefinedFunction);
|
||||
|
||||
// Thin pointer wrapper that SQLite owns via sqlite3_user_data / xDestroy.
|
||||
// Holds a single Strong<> to prevent the GC object from being collected.
|
||||
struct UDFPointer {
|
||||
Strong<JSUserDefinedFunction> prevent_gc;
|
||||
|
||||
JSUserDefinedFunction* get() const { return prevent_gc.get(); }
|
||||
};
|
||||
|
||||
static void destroyUDFPointer(void* ptr)
|
||||
{
|
||||
delete static_cast<UDFPointer*>(ptr);
|
||||
}
|
||||
|
||||
// Per-group context allocated by sqlite3_aggregate_context()
|
||||
struct AggregateContext {
|
||||
size_t accumulatorIndex;
|
||||
bool initialized;
|
||||
};
|
||||
|
||||
// Convert a sqlite3_value to a JSValue
|
||||
static JSValue sqliteValueToJS(JSGlobalObject* globalObject, sqlite3_value* value, bool safeIntegers)
|
||||
{
|
||||
auto& vm = getVM(globalObject);
|
||||
switch (sqlite3_value_type(value)) {
|
||||
case SQLITE_INTEGER: {
|
||||
int64_t intVal = sqlite3_value_int64(value);
|
||||
if (safeIntegers) {
|
||||
return JSBigInt::createFrom(globalObject, intVal);
|
||||
}
|
||||
return jsNumber(intVal);
|
||||
}
|
||||
case SQLITE_FLOAT:
|
||||
return jsNumber(sqlite3_value_double(value));
|
||||
case SQLITE3_TEXT: {
|
||||
int len = sqlite3_value_bytes(value);
|
||||
const unsigned char* text = sqlite3_value_text(value);
|
||||
if (!text || len == 0)
|
||||
return jsEmptyString(vm);
|
||||
if (len < 64)
|
||||
return jsString(vm, WTF::String::fromUTF8({ text, static_cast<size_t>(len) }));
|
||||
auto encoded = Bun__encoding__toStringUTF8(text, len, globalObject);
|
||||
return JSValue::decode(encoded);
|
||||
}
|
||||
case SQLITE_BLOB: {
|
||||
int len = sqlite3_value_bytes(value);
|
||||
const void* blob = sqlite3_value_blob(value);
|
||||
if (len > 0 && blob) {
|
||||
auto* array = JSUint8Array::createUninitialized(globalObject, globalObject->m_typedArrayUint8.get(globalObject), len);
|
||||
if (array)
|
||||
memcpy(array->vector(), blob, len);
|
||||
return array ? array : jsNull();
|
||||
}
|
||||
auto array = JSUint8Array::create(globalObject, globalObject->m_typedArrayUint8.get(globalObject), 0);
|
||||
return array ? array : jsNull();
|
||||
}
|
||||
case SQLITE_NULL:
|
||||
default:
|
||||
return jsNull();
|
||||
}
|
||||
}
|
||||
|
||||
// Convert a JSValue to a sqlite3_result
|
||||
static void jsValueToSQLiteResult(JSGlobalObject* globalObject, sqlite3_context* ctx, JSValue value)
|
||||
{
|
||||
if (value.isUndefinedOrNull()) {
|
||||
sqlite3_result_null(ctx);
|
||||
} else if (value.isBoolean()) {
|
||||
sqlite3_result_int(ctx, value.asBoolean() ? 1 : 0);
|
||||
} else if (value.isAnyInt()) {
|
||||
int64_t val = value.asAnyInt();
|
||||
if (val >= INT_MIN && val <= INT_MAX) {
|
||||
sqlite3_result_int(ctx, static_cast<int>(val));
|
||||
} else {
|
||||
sqlite3_result_int64(ctx, val);
|
||||
}
|
||||
} else if (value.isNumber()) {
|
||||
sqlite3_result_double(ctx, value.asDouble());
|
||||
} else if (value.isString()) {
|
||||
auto* str = value.toStringOrNull(globalObject);
|
||||
if (!str) {
|
||||
sqlite3_result_null(ctx);
|
||||
return;
|
||||
}
|
||||
auto view = str->view(globalObject);
|
||||
auto utf8 = view->utf8();
|
||||
sqlite3_result_text(ctx, utf8.data(), utf8.length(), SQLITE_TRANSIENT);
|
||||
} else if (value.isHeapBigInt()) {
|
||||
sqlite3_result_int64(ctx, JSBigInt::toBigInt64(value));
|
||||
} else if (auto* buffer = jsDynamicCast<JSArrayBufferView*>(value)) {
|
||||
sqlite3_result_blob(ctx, buffer->vector(), buffer->byteLength(), SQLITE_TRANSIENT);
|
||||
} else {
|
||||
sqlite3_result_error(ctx, "User-defined function returned an unsupported type", -1);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to extract an error message from a pending exception, clear the
|
||||
// exception, and forward it to SQLite via sqlite3_result_error.
|
||||
// Must be called while the exception is still pending (before clearing).
|
||||
static void propagateExceptionToSQLite(JSGlobalObject* globalObject, VM& vm, ThrowScope& scope, sqlite3_context* ctx, const char* fallbackMessage)
|
||||
{
|
||||
JSC::Exception* exception = scope.exception();
|
||||
// Grab the value while the exception is still rooted by the VM.
|
||||
JSValue errorValue = exception->value();
|
||||
// Now clear so that we can safely re-enter JS for toString conversion.
|
||||
if (!scope.tryClearException()) {
|
||||
// Termination exception — can't clear, just set a generic error.
|
||||
sqlite3_result_error(ctx, "Terminated", -1);
|
||||
return;
|
||||
}
|
||||
|
||||
auto* errorString = errorValue.toStringOrNull(globalObject);
|
||||
if (errorString && !scope.exception()) {
|
||||
auto utf8 = errorString->view(globalObject)->utf8();
|
||||
if (!scope.exception()) {
|
||||
sqlite3_result_error(ctx, utf8.data(), utf8.length());
|
||||
} else {
|
||||
(void)scope.tryClearException();
|
||||
sqlite3_result_error(ctx, fallbackMessage, -1);
|
||||
}
|
||||
} else {
|
||||
if (scope.exception())
|
||||
(void)scope.tryClearException();
|
||||
sqlite3_result_error(ctx, fallbackMessage, -1);
|
||||
}
|
||||
}
|
||||
|
||||
static JSUserDefinedFunction* udfFromCtx(sqlite3_context* ctx)
|
||||
{
|
||||
return static_cast<UDFPointer*>(sqlite3_user_data(ctx))->get();
|
||||
}
|
||||
|
||||
static void scalarFunctionCallback(sqlite3_context* ctx, int argc, sqlite3_value** argv)
|
||||
{
|
||||
auto* udf = udfFromCtx(ctx);
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto& vm = getVM(globalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
MarkedArgumentBuffer args;
|
||||
for (int i = 0; i < argc; i++) {
|
||||
args.append(sqliteValueToJS(globalObject, argv[i], udf->m_safeIntegers));
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Failed to convert argument");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto callData = getCallData(udf->m_scalarFn.get());
|
||||
JSValue result = call(globalObject, udf->m_scalarFn.get(), callData, jsUndefined(), args);
|
||||
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "User-defined function threw an error");
|
||||
return;
|
||||
}
|
||||
|
||||
jsValueToSQLiteResult(globalObject, ctx, result);
|
||||
if (scope.exception())
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Failed to convert return value");
|
||||
}
|
||||
|
||||
// Returns the start value for an aggregate. Caller must check for exceptions
|
||||
// on the ThrowScope after calling this — if the start function throws, the
|
||||
// exception will be pending on the VM.
|
||||
static JSValue getAggregateStartValue(JSUserDefinedFunction* udf)
|
||||
{
|
||||
if (udf->m_startIsFunction) {
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto callData = getCallData(asObject(udf->m_startValue.get()));
|
||||
return call(globalObject, asObject(udf->m_startValue.get()), callData, jsUndefined(), MarkedArgumentBuffer());
|
||||
}
|
||||
return udf->m_startValue.get();
|
||||
}
|
||||
|
||||
static AggregateContext* getOrInitAggregateContext(sqlite3_context* ctx, JSUserDefinedFunction* udf, ThrowScope& scope)
|
||||
{
|
||||
auto* aggCtx = static_cast<AggregateContext*>(sqlite3_aggregate_context(ctx, sizeof(AggregateContext)));
|
||||
if (!aggCtx)
|
||||
return nullptr;
|
||||
if (!aggCtx->initialized) {
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto& vm = getVM(globalObject);
|
||||
JSValue start = getAggregateStartValue(udf);
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Aggregate start function threw an error");
|
||||
return nullptr;
|
||||
}
|
||||
aggCtx->accumulatorIndex = udf->allocAccumulator(vm, start);
|
||||
aggCtx->initialized = true;
|
||||
}
|
||||
return aggCtx;
|
||||
}
|
||||
|
||||
static void aggregateStepCallback(sqlite3_context* ctx, int argc, sqlite3_value** argv)
|
||||
{
|
||||
auto* udf = udfFromCtx(ctx);
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto& vm = getVM(globalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
auto* aggCtx = getOrInitAggregateContext(ctx, udf, scope);
|
||||
if (!aggCtx) {
|
||||
if (!scope.exception())
|
||||
sqlite3_result_error_nomem(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
JSValue accumulator = udf->m_accumulators[aggCtx->accumulatorIndex].get();
|
||||
|
||||
MarkedArgumentBuffer args;
|
||||
args.append(accumulator);
|
||||
for (int i = 0; i < argc; i++) {
|
||||
args.append(sqliteValueToJS(globalObject, argv[i], udf->m_safeIntegers));
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Failed to convert argument");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto callData = getCallData(udf->m_stepFn.get());
|
||||
JSValue result = call(globalObject, udf->m_stepFn.get(), callData, jsUndefined(), args);
|
||||
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Aggregate step function threw an error");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!result.isUndefined()) {
|
||||
udf->m_accumulators[aggCtx->accumulatorIndex].set(vm, udf, result);
|
||||
}
|
||||
}
|
||||
|
||||
static void aggregateFinalCallback(sqlite3_context* ctx)
|
||||
{
|
||||
auto* udf = udfFromCtx(ctx);
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto& vm = getVM(globalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
auto* aggCtx = static_cast<AggregateContext*>(sqlite3_aggregate_context(ctx, 0));
|
||||
JSValue accumulator;
|
||||
|
||||
if (aggCtx && aggCtx->initialized) {
|
||||
accumulator = udf->m_accumulators[aggCtx->accumulatorIndex].get();
|
||||
} else {
|
||||
accumulator = getAggregateStartValue(udf);
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Aggregate start function threw an error");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
JSValue result;
|
||||
if (udf->m_resultFn.get()) {
|
||||
MarkedArgumentBuffer args;
|
||||
args.append(accumulator);
|
||||
auto callData = getCallData(udf->m_resultFn.get());
|
||||
result = call(globalObject, udf->m_resultFn.get(), callData, jsUndefined(), args);
|
||||
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Aggregate result function threw an error");
|
||||
if (aggCtx && aggCtx->initialized)
|
||||
udf->freeAccumulator(aggCtx->accumulatorIndex);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
result = accumulator;
|
||||
}
|
||||
|
||||
jsValueToSQLiteResult(globalObject, ctx, result);
|
||||
if (scope.exception())
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Failed to convert return value");
|
||||
|
||||
if (aggCtx && aggCtx->initialized)
|
||||
udf->freeAccumulator(aggCtx->accumulatorIndex);
|
||||
}
|
||||
|
||||
static void windowValueCallback(sqlite3_context* ctx)
|
||||
{
|
||||
auto* udf = udfFromCtx(ctx);
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto& vm = getVM(globalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
auto* aggCtx = static_cast<AggregateContext*>(sqlite3_aggregate_context(ctx, 0));
|
||||
JSValue accumulator;
|
||||
|
||||
if (aggCtx && aggCtx->initialized) {
|
||||
accumulator = udf->m_accumulators[aggCtx->accumulatorIndex].get();
|
||||
} else {
|
||||
accumulator = getAggregateStartValue(udf);
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Aggregate start function threw an error");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
JSValue result;
|
||||
if (udf->m_resultFn.get()) {
|
||||
MarkedArgumentBuffer args;
|
||||
args.append(accumulator);
|
||||
auto callData = getCallData(udf->m_resultFn.get());
|
||||
result = call(globalObject, udf->m_resultFn.get(), callData, jsUndefined(), args);
|
||||
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Window value function threw an error");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
result = accumulator;
|
||||
}
|
||||
|
||||
jsValueToSQLiteResult(globalObject, ctx, result);
|
||||
if (scope.exception())
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Failed to convert return value");
|
||||
}
|
||||
|
||||
static void windowInverseCallback(sqlite3_context* ctx, int argc, sqlite3_value** argv)
|
||||
{
|
||||
auto* udf = udfFromCtx(ctx);
|
||||
auto* globalObject = udf->globalObject();
|
||||
auto& vm = getVM(globalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
auto* aggCtx = getOrInitAggregateContext(ctx, udf, scope);
|
||||
if (!aggCtx) {
|
||||
if (!scope.exception())
|
||||
sqlite3_result_error_nomem(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
JSValue accumulator = udf->m_accumulators[aggCtx->accumulatorIndex].get();
|
||||
|
||||
MarkedArgumentBuffer args;
|
||||
args.append(accumulator);
|
||||
for (int i = 0; i < argc; i++) {
|
||||
args.append(sqliteValueToJS(globalObject, argv[i], udf->m_safeIntegers));
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Failed to convert argument");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto callData = getCallData(udf->m_inverseFn.get());
|
||||
JSValue result = call(globalObject, udf->m_inverseFn.get(), callData, jsUndefined(), args);
|
||||
|
||||
if (scope.exception()) {
|
||||
propagateExceptionToSQLite(globalObject, vm, scope, ctx, "Window inverse function threw an error");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!result.isUndefined()) {
|
||||
udf->m_accumulators[aggCtx->accumulatorIndex].set(vm, udf, result);
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementPrepareStatementFunction);
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementExecuteFunction);
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementOpenStatementFunction);
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementIsInTransactionFunction);
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementCreateFunction);
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementCreateAggregateFunction);
|
||||
|
||||
JSC_DECLARE_HOST_FUNCTION(jsSQLStatementLoadExtensionFunction);
|
||||
|
||||
@@ -1841,6 +2304,216 @@ JSC_DEFINE_HOST_FUNCTION(jsSQLStatementFcntlFunction, (JSC::JSGlobalObject * lex
|
||||
return JSValue::encode(jsNumber(statusCode));
|
||||
}
|
||||
|
||||
// createFunction(dbIndex, name, nArgs, flags, callback, safeIntegers)
|
||||
JSC_DEFINE_HOST_FUNCTION(jsSQLStatementCreateFunction, (JSC::JSGlobalObject * lexicalGlobalObject, JSC::CallFrame* callFrame))
|
||||
{
|
||||
auto& vm = JSC::getVM(lexicalGlobalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
JSValue thisValue = callFrame->thisValue();
|
||||
JSSQLStatementConstructor* thisObject = jsDynamicCast<JSSQLStatementConstructor*>(thisValue.getObject());
|
||||
if (!thisObject) [[unlikely]] {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Expected SQL"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
if (callFrame->argumentCount() < 5) {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Expected at least 5 arguments"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
int32_t dbIndex = callFrame->argument(0).toInt32(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
if (dbIndex < 0 || dbIndex >= databases().size()) [[unlikely]] {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Invalid database handle"_s));
|
||||
return {};
|
||||
}
|
||||
sqlite3* db = databases()[dbIndex]->db;
|
||||
if (!db) [[unlikely]] {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Can't do this on a closed database"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
JSC::JSValue nameValue = callFrame->argument(1);
|
||||
if (!nameValue.isString()) {
|
||||
throwException(lexicalGlobalObject, scope, createTypeError(lexicalGlobalObject, "Expected function name to be a string"_s));
|
||||
return {};
|
||||
}
|
||||
auto nameStr = nameValue.toWTFString(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
auto nameUtf8 = nameStr.utf8();
|
||||
|
||||
int nArgs = callFrame->argument(2).toInt32(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
int flags = callFrame->argument(3).toInt32(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
|
||||
JSC::JSValue callbackValue = callFrame->argument(4);
|
||||
JSC::JSObject* callback = jsDynamicCast<JSC::JSObject*>(callbackValue);
|
||||
if (!callback || !callback->isCallable()) {
|
||||
throwException(lexicalGlobalObject, scope, createTypeError(lexicalGlobalObject, "Expected callback to be a function"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
bool safeIntegers = callFrame->argumentCount() > 5 && callFrame->argument(5).toBoolean(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
|
||||
Structure* udfStructure = JSUserDefinedFunction::createStructure(vm, lexicalGlobalObject, jsNull());
|
||||
auto* udf = JSUserDefinedFunction::create(vm, lexicalGlobalObject, udfStructure);
|
||||
udf->m_scalarFn.set(vm, udf, callback);
|
||||
udf->m_safeIntegers = safeIntegers;
|
||||
|
||||
auto* ptr = new UDFPointer { Strong<JSUserDefinedFunction>(vm, udf) };
|
||||
|
||||
int rc = sqlite3_create_function_v2(
|
||||
db,
|
||||
nameUtf8.data(),
|
||||
nArgs,
|
||||
SQLITE_UTF8 | flags,
|
||||
ptr,
|
||||
scalarFunctionCallback,
|
||||
nullptr,
|
||||
nullptr,
|
||||
destroyUDFPointer);
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
delete ptr;
|
||||
throwException(lexicalGlobalObject, scope, createSQLiteError(lexicalGlobalObject, db));
|
||||
return {};
|
||||
}
|
||||
|
||||
RELEASE_AND_RETURN(scope, JSValue::encode(jsUndefined()));
|
||||
}
|
||||
|
||||
// createAggregate(dbIndex, name, nArgs, flags, stepFn, resultFn, startValue, startIsFunction, inverseFn, safeIntegers)
|
||||
JSC_DEFINE_HOST_FUNCTION(jsSQLStatementCreateAggregateFunction, (JSC::JSGlobalObject * lexicalGlobalObject, JSC::CallFrame* callFrame))
|
||||
{
|
||||
auto& vm = JSC::getVM(lexicalGlobalObject);
|
||||
auto scope = DECLARE_THROW_SCOPE(vm);
|
||||
|
||||
JSValue thisValue = callFrame->thisValue();
|
||||
JSSQLStatementConstructor* thisObject = jsDynamicCast<JSSQLStatementConstructor*>(thisValue.getObject());
|
||||
if (!thisObject) [[unlikely]] {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Expected SQL"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
if (callFrame->argumentCount() < 8) {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Expected at least 8 arguments"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
int32_t dbIndex = callFrame->argument(0).toInt32(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
if (dbIndex < 0 || dbIndex >= databases().size()) [[unlikely]] {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Invalid database handle"_s));
|
||||
return {};
|
||||
}
|
||||
sqlite3* db = databases()[dbIndex]->db;
|
||||
if (!db) [[unlikely]] {
|
||||
throwException(lexicalGlobalObject, scope, createError(lexicalGlobalObject, "Can't do this on a closed database"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
JSC::JSValue nameValue = callFrame->argument(1);
|
||||
if (!nameValue.isString()) {
|
||||
throwException(lexicalGlobalObject, scope, createTypeError(lexicalGlobalObject, "Expected function name to be a string"_s));
|
||||
return {};
|
||||
}
|
||||
auto nameStr = nameValue.toWTFString(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
auto nameUtf8 = nameStr.utf8();
|
||||
|
||||
int nArgs = callFrame->argument(2).toInt32(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
int flags = callFrame->argument(3).toInt32(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
|
||||
// step function (required)
|
||||
JSC::JSObject* stepFn = jsDynamicCast<JSC::JSObject*>(callFrame->argument(4));
|
||||
if (!stepFn || !stepFn->isCallable()) {
|
||||
throwException(lexicalGlobalObject, scope, createTypeError(lexicalGlobalObject, "Expected step to be a function"_s));
|
||||
return {};
|
||||
}
|
||||
|
||||
// result function (optional, null if not provided)
|
||||
JSC::JSObject* resultFn = nullptr;
|
||||
JSC::JSValue resultFnValue = callFrame->argument(5);
|
||||
if (!resultFnValue.isNull() && !resultFnValue.isUndefined()) {
|
||||
resultFn = jsDynamicCast<JSC::JSObject*>(resultFnValue);
|
||||
if (!resultFn || !resultFn->isCallable()) {
|
||||
throwException(lexicalGlobalObject, scope, createTypeError(lexicalGlobalObject, "Expected result to be a function or null"_s));
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// start value
|
||||
JSC::JSValue startValueJS = callFrame->argument(6);
|
||||
bool startIsFunction = callFrame->argument(7).toBoolean(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
|
||||
// inverse function (optional)
|
||||
JSC::JSObject* inverseFn = nullptr;
|
||||
JSC::JSValue inverseFnValue = callFrame->argumentCount() > 8 ? callFrame->argument(8) : JSC::jsUndefined();
|
||||
if (!inverseFnValue.isNull() && !inverseFnValue.isUndefined()) {
|
||||
inverseFn = jsDynamicCast<JSC::JSObject*>(inverseFnValue);
|
||||
if (!inverseFn || !inverseFn->isCallable()) {
|
||||
throwException(lexicalGlobalObject, scope, createTypeError(lexicalGlobalObject, "Expected inverse to be a function or null"_s));
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
bool safeIntegers = callFrame->argumentCount() > 9 && callFrame->argument(9).toBoolean(lexicalGlobalObject);
|
||||
RETURN_IF_EXCEPTION(scope, {});
|
||||
|
||||
Structure* udfStructure = JSUserDefinedFunction::createStructure(vm, lexicalGlobalObject, jsNull());
|
||||
auto* udf = JSUserDefinedFunction::create(vm, lexicalGlobalObject, udfStructure);
|
||||
udf->m_stepFn.set(vm, udf, stepFn);
|
||||
if (resultFn)
|
||||
udf->m_resultFn.set(vm, udf, resultFn);
|
||||
udf->m_startValue.set(vm, udf, startValueJS);
|
||||
udf->m_startIsFunction = startIsFunction;
|
||||
if (inverseFn)
|
||||
udf->m_inverseFn.set(vm, udf, inverseFn);
|
||||
udf->m_safeIntegers = safeIntegers;
|
||||
|
||||
auto* ptr = new UDFPointer { Strong<JSUserDefinedFunction>(vm, udf) };
|
||||
|
||||
int rc;
|
||||
if (inverseFn) {
|
||||
rc = sqlite3_create_window_function(
|
||||
db,
|
||||
nameUtf8.data(),
|
||||
nArgs,
|
||||
SQLITE_UTF8 | flags,
|
||||
ptr,
|
||||
aggregateStepCallback,
|
||||
aggregateFinalCallback,
|
||||
windowValueCallback,
|
||||
windowInverseCallback,
|
||||
destroyUDFPointer);
|
||||
} else {
|
||||
rc = sqlite3_create_function_v2(
|
||||
db,
|
||||
nameUtf8.data(),
|
||||
nArgs,
|
||||
SQLITE_UTF8 | flags,
|
||||
ptr,
|
||||
nullptr,
|
||||
aggregateStepCallback,
|
||||
aggregateFinalCallback,
|
||||
destroyUDFPointer);
|
||||
}
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
delete ptr;
|
||||
throwException(lexicalGlobalObject, scope, createSQLiteError(lexicalGlobalObject, db));
|
||||
return {};
|
||||
}
|
||||
|
||||
RELEASE_AND_RETURN(scope, JSValue::encode(jsUndefined()));
|
||||
}
|
||||
|
||||
/* Hash table for constructor */
|
||||
static const HashTableValue JSSQLStatementConstructorTableValues[] = {
|
||||
{ "open"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsSQLStatementOpenStatementFunction, 2 } },
|
||||
@@ -1853,6 +2526,8 @@ static const HashTableValue JSSQLStatementConstructorTableValues[] = {
|
||||
{ "serialize"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsSQLStatementSerialize, 1 } },
|
||||
{ "deserialize"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsSQLStatementDeserialize, 2 } },
|
||||
{ "fcntl"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsSQLStatementFcntlFunction, 2 } },
|
||||
{ "createFunction"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsSQLStatementCreateFunction, 5 } },
|
||||
{ "createAggregate"_s, static_cast<unsigned>(JSC::PropertyAttribute::Function), NoIntrinsic, { HashTableValue::NativeFunctionType, jsSQLStatementCreateAggregateFunction, 8 } },
|
||||
};
|
||||
|
||||
const ClassInfo JSSQLStatementConstructor::s_info = { "SQLStatement"_s, &Base::s_info, nullptr, nullptr, CREATE_METHOD_TABLE(JSSQLStatementConstructor) };
|
||||
|
||||
@@ -95,6 +95,26 @@ typedef int (*lazy_sqlite3_stmt_busy_type)(sqlite3_stmt* pStmt);
|
||||
typedef int (*lazy_sqlite3_compileoption_used_type)(const char* zOptName);
|
||||
typedef int64_t (*lazy_sqlite3_last_insert_rowid_type)(sqlite3* db);
|
||||
|
||||
// User-defined functions
|
||||
typedef int (*lazy_sqlite3_create_function_v2_type)(sqlite3* db, const char* zFunctionName, int nArg, int eTextRep, void* pApp, void (*xFunc)(sqlite3_context*, int, sqlite3_value**), void (*xStep)(sqlite3_context*, int, sqlite3_value**), void (*xFinal)(sqlite3_context*), void (*xDestroy)(void*));
|
||||
typedef int (*lazy_sqlite3_create_window_function_type)(sqlite3* db, const char* zFunctionName, int nArg, int eTextRep, void* pApp, void (*xStep)(sqlite3_context*, int, sqlite3_value**), void (*xFinal)(sqlite3_context*), void (*xValue)(sqlite3_context*), void (*xInverse)(sqlite3_context*, int, sqlite3_value**), void (*xDestroy)(void*));
|
||||
typedef int (*lazy_sqlite3_value_type_type)(sqlite3_value*);
|
||||
typedef sqlite3_int64 (*lazy_sqlite3_value_int64_type)(sqlite3_value*);
|
||||
typedef double (*lazy_sqlite3_value_double_type)(sqlite3_value*);
|
||||
typedef const unsigned char* (*lazy_sqlite3_value_text_type)(sqlite3_value*);
|
||||
typedef const void* (*lazy_sqlite3_value_blob_type)(sqlite3_value*);
|
||||
typedef int (*lazy_sqlite3_value_bytes_type)(sqlite3_value*);
|
||||
typedef void* (*lazy_sqlite3_user_data_type)(sqlite3_context*);
|
||||
typedef void* (*lazy_sqlite3_aggregate_context_type)(sqlite3_context*, int nBytes);
|
||||
typedef void (*lazy_sqlite3_result_null_type)(sqlite3_context*);
|
||||
typedef void (*lazy_sqlite3_result_int_type)(sqlite3_context*, int);
|
||||
typedef void (*lazy_sqlite3_result_int64_type)(sqlite3_context*, sqlite3_int64);
|
||||
typedef void (*lazy_sqlite3_result_double_type)(sqlite3_context*, double);
|
||||
typedef void (*lazy_sqlite3_result_text_type)(sqlite3_context*, const char*, int, void (*)(void*));
|
||||
typedef void (*lazy_sqlite3_result_blob_type)(sqlite3_context*, const void*, int, void (*)(void*));
|
||||
typedef void (*lazy_sqlite3_result_error_type)(sqlite3_context*, const char*, int);
|
||||
typedef void (*lazy_sqlite3_result_error_nomem_type)(sqlite3_context*);
|
||||
|
||||
static lazy_sqlite3_bind_blob_type lazy_sqlite3_bind_blob;
|
||||
static lazy_sqlite3_bind_double_type lazy_sqlite3_bind_double;
|
||||
static lazy_sqlite3_bind_int_type lazy_sqlite3_bind_int;
|
||||
@@ -147,6 +167,24 @@ static lazy_sqlite3_memory_used_type lazy_sqlite3_memory_used;
|
||||
static lazy_sqlite3_bind_parameter_name_type lazy_sqlite3_bind_parameter_name;
|
||||
static lazy_sqlite3_total_changes_type lazy_sqlite3_total_changes;
|
||||
static lazy_sqlite3_last_insert_rowid_type lazy_sqlite3_last_insert_rowid;
|
||||
static lazy_sqlite3_create_function_v2_type lazy_sqlite3_create_function_v2;
|
||||
static lazy_sqlite3_create_window_function_type lazy_sqlite3_create_window_function;
|
||||
static lazy_sqlite3_value_type_type lazy_sqlite3_value_type;
|
||||
static lazy_sqlite3_value_int64_type lazy_sqlite3_value_int64;
|
||||
static lazy_sqlite3_value_double_type lazy_sqlite3_value_double;
|
||||
static lazy_sqlite3_value_text_type lazy_sqlite3_value_text;
|
||||
static lazy_sqlite3_value_blob_type lazy_sqlite3_value_blob;
|
||||
static lazy_sqlite3_value_bytes_type lazy_sqlite3_value_bytes;
|
||||
static lazy_sqlite3_user_data_type lazy_sqlite3_user_data;
|
||||
static lazy_sqlite3_aggregate_context_type lazy_sqlite3_aggregate_context;
|
||||
static lazy_sqlite3_result_null_type lazy_sqlite3_result_null;
|
||||
static lazy_sqlite3_result_int_type lazy_sqlite3_result_int;
|
||||
static lazy_sqlite3_result_int64_type lazy_sqlite3_result_int64;
|
||||
static lazy_sqlite3_result_double_type lazy_sqlite3_result_double;
|
||||
static lazy_sqlite3_result_text_type lazy_sqlite3_result_text;
|
||||
static lazy_sqlite3_result_blob_type lazy_sqlite3_result_blob;
|
||||
static lazy_sqlite3_result_error_type lazy_sqlite3_result_error;
|
||||
static lazy_sqlite3_result_error_nomem_type lazy_sqlite3_result_error_nomem;
|
||||
|
||||
#define sqlite3_bind_blob lazy_sqlite3_bind_blob
|
||||
#define sqlite3_bind_double lazy_sqlite3_bind_double
|
||||
@@ -199,6 +237,24 @@ static lazy_sqlite3_last_insert_rowid_type lazy_sqlite3_last_insert_rowid;
|
||||
#define sqlite3_bind_parameter_name lazy_sqlite3_bind_parameter_name
|
||||
#define sqlite3_total_changes lazy_sqlite3_total_changes
|
||||
#define sqlite3_last_insert_rowid lazy_sqlite3_last_insert_rowid
|
||||
#define sqlite3_create_function_v2 lazy_sqlite3_create_function_v2
|
||||
#define sqlite3_create_window_function lazy_sqlite3_create_window_function
|
||||
#define sqlite3_value_type lazy_sqlite3_value_type
|
||||
#define sqlite3_value_int64 lazy_sqlite3_value_int64
|
||||
#define sqlite3_value_double lazy_sqlite3_value_double
|
||||
#define sqlite3_value_text lazy_sqlite3_value_text
|
||||
#define sqlite3_value_blob lazy_sqlite3_value_blob
|
||||
#define sqlite3_value_bytes lazy_sqlite3_value_bytes
|
||||
#define sqlite3_user_data lazy_sqlite3_user_data
|
||||
#define sqlite3_aggregate_context lazy_sqlite3_aggregate_context
|
||||
#define sqlite3_result_null lazy_sqlite3_result_null
|
||||
#define sqlite3_result_int lazy_sqlite3_result_int
|
||||
#define sqlite3_result_int64 lazy_sqlite3_result_int64
|
||||
#define sqlite3_result_double lazy_sqlite3_result_double
|
||||
#define sqlite3_result_text lazy_sqlite3_result_text
|
||||
#define sqlite3_result_blob lazy_sqlite3_result_blob
|
||||
#define sqlite3_result_error lazy_sqlite3_result_error
|
||||
#define sqlite3_result_error_nomem lazy_sqlite3_result_error_nomem
|
||||
|
||||
#if !OS(WINDOWS)
|
||||
#define HMODULE void*
|
||||
@@ -285,6 +341,24 @@ static int lazyLoadSQLite()
|
||||
lazy_sqlite3_bind_parameter_name = (lazy_sqlite3_bind_parameter_name_type)dlsym(sqlite3_handle, "sqlite3_bind_parameter_name");
|
||||
lazy_sqlite3_total_changes = (lazy_sqlite3_total_changes_type)dlsym(sqlite3_handle, "sqlite3_total_changes");
|
||||
lazy_sqlite3_last_insert_rowid = (lazy_sqlite3_last_insert_rowid_type)dlsym(sqlite3_handle, "sqlite3_last_insert_rowid");
|
||||
lazy_sqlite3_create_function_v2 = (lazy_sqlite3_create_function_v2_type)dlsym(sqlite3_handle, "sqlite3_create_function_v2");
|
||||
lazy_sqlite3_create_window_function = (lazy_sqlite3_create_window_function_type)dlsym(sqlite3_handle, "sqlite3_create_window_function");
|
||||
lazy_sqlite3_value_type = (lazy_sqlite3_value_type_type)dlsym(sqlite3_handle, "sqlite3_value_type");
|
||||
lazy_sqlite3_value_int64 = (lazy_sqlite3_value_int64_type)dlsym(sqlite3_handle, "sqlite3_value_int64");
|
||||
lazy_sqlite3_value_double = (lazy_sqlite3_value_double_type)dlsym(sqlite3_handle, "sqlite3_value_double");
|
||||
lazy_sqlite3_value_text = (lazy_sqlite3_value_text_type)dlsym(sqlite3_handle, "sqlite3_value_text");
|
||||
lazy_sqlite3_value_blob = (lazy_sqlite3_value_blob_type)dlsym(sqlite3_handle, "sqlite3_value_blob");
|
||||
lazy_sqlite3_value_bytes = (lazy_sqlite3_value_bytes_type)dlsym(sqlite3_handle, "sqlite3_value_bytes");
|
||||
lazy_sqlite3_user_data = (lazy_sqlite3_user_data_type)dlsym(sqlite3_handle, "sqlite3_user_data");
|
||||
lazy_sqlite3_aggregate_context = (lazy_sqlite3_aggregate_context_type)dlsym(sqlite3_handle, "sqlite3_aggregate_context");
|
||||
lazy_sqlite3_result_null = (lazy_sqlite3_result_null_type)dlsym(sqlite3_handle, "sqlite3_result_null");
|
||||
lazy_sqlite3_result_int = (lazy_sqlite3_result_int_type)dlsym(sqlite3_handle, "sqlite3_result_int");
|
||||
lazy_sqlite3_result_int64 = (lazy_sqlite3_result_int64_type)dlsym(sqlite3_handle, "sqlite3_result_int64");
|
||||
lazy_sqlite3_result_double = (lazy_sqlite3_result_double_type)dlsym(sqlite3_handle, "sqlite3_result_double");
|
||||
lazy_sqlite3_result_text = (lazy_sqlite3_result_text_type)dlsym(sqlite3_handle, "sqlite3_result_text");
|
||||
lazy_sqlite3_result_blob = (lazy_sqlite3_result_blob_type)dlsym(sqlite3_handle, "sqlite3_result_blob");
|
||||
lazy_sqlite3_result_error = (lazy_sqlite3_result_error_type)dlsym(sqlite3_handle, "sqlite3_result_error");
|
||||
lazy_sqlite3_result_error_nomem = (lazy_sqlite3_result_error_nomem_type)dlsym(sqlite3_handle, "sqlite3_result_error_nomem");
|
||||
|
||||
if (!lazy_sqlite3_extended_result_codes) {
|
||||
lazy_sqlite3_extended_result_codes = [](sqlite3*, int) -> int {
|
||||
|
||||
@@ -24,6 +24,7 @@ public:
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForNapiPrototype;
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForJSSQLStatement;
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForJSSQLStatementConstructor;
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForJSUserDefinedFunction;
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForJSSinkConstructor;
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForJSSinkController;
|
||||
std::unique_ptr<GCClient::IsoSubspace> m_clientSubspaceForJSSink;
|
||||
|
||||
@@ -24,6 +24,7 @@ public:
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForNapiPrototype;
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForJSSQLStatement;
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForJSSQLStatementConstructor;
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForJSUserDefinedFunction;
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForJSSinkConstructor;
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForJSSinkController;
|
||||
std::unique_ptr<IsoSubspace> m_subspaceForJSSink;
|
||||
|
||||
@@ -122,6 +122,26 @@ interface CppSQL {
|
||||
fcntl(handle: TODO, ...args: TODO[]): TODO;
|
||||
close(handle: TODO, throwOnError: boolean): void;
|
||||
setCustomSQLite(path: string): void;
|
||||
createFunction(
|
||||
handle: TODO,
|
||||
name: string,
|
||||
nArgs: number,
|
||||
flags: number,
|
||||
callback: Function,
|
||||
safeIntegers: boolean,
|
||||
): void;
|
||||
createAggregate(
|
||||
handle: TODO,
|
||||
name: string,
|
||||
nArgs: number,
|
||||
flags: number,
|
||||
stepFn: Function,
|
||||
resultFn: Function | null,
|
||||
startValue: any,
|
||||
startIsFunction: boolean,
|
||||
inverseFn: Function | null,
|
||||
safeIntegers: boolean,
|
||||
): void;
|
||||
}
|
||||
|
||||
let SQL: CppSQL;
|
||||
@@ -487,6 +507,89 @@ class Database implements SqliteTypes.Database {
|
||||
return SQL.setCustomSQLite(path);
|
||||
}
|
||||
|
||||
function(name, optionsOrCallback?, callback?) {
|
||||
if (typeof name !== "string") {
|
||||
throw new TypeError("Expected function name to be a string");
|
||||
}
|
||||
|
||||
let fn;
|
||||
let options;
|
||||
|
||||
if (typeof optionsOrCallback === "function") {
|
||||
fn = optionsOrCallback;
|
||||
options = {};
|
||||
} else if (typeof optionsOrCallback === "object" && optionsOrCallback !== null) {
|
||||
options = optionsOrCallback;
|
||||
fn = callback;
|
||||
} else if (optionsOrCallback === undefined || optionsOrCallback === null) {
|
||||
fn = callback;
|
||||
options = {};
|
||||
} else {
|
||||
throw new TypeError("Expected second argument to be a function or options object");
|
||||
}
|
||||
|
||||
if (typeof fn !== "function") {
|
||||
throw new TypeError("Expected callback to be a function");
|
||||
}
|
||||
|
||||
if (!SQL) {
|
||||
initializeSQL();
|
||||
}
|
||||
|
||||
const nArgs = options.varargs ? -1 : fn.length;
|
||||
let flags = 0;
|
||||
if (options.deterministic) flags |= 0x000000800; // SQLITE_DETERMINISTIC
|
||||
if (options.directOnly) flags |= 0x000080000; // SQLITE_DIRECTONLY
|
||||
|
||||
SQL.createFunction(this.#handle, name, nArgs, flags, fn, !!options.safeIntegers);
|
||||
return this;
|
||||
}
|
||||
|
||||
aggregate(name, options?) {
|
||||
if (typeof name !== "string") {
|
||||
throw new TypeError("Expected aggregate name to be a string");
|
||||
}
|
||||
|
||||
if (typeof options !== "object" || options === null) {
|
||||
throw new TypeError("Expected options to be an object");
|
||||
}
|
||||
|
||||
const step = options.step;
|
||||
if (typeof step !== "function") {
|
||||
throw new TypeError("Expected step to be a function");
|
||||
}
|
||||
|
||||
if (!SQL) {
|
||||
initializeSQL();
|
||||
}
|
||||
|
||||
// Determine nArgs from step function (minus 1 for the accumulator parameter)
|
||||
const nArgs = options.varargs ? -1 : step.length > 0 ? step.length - 1 : 0;
|
||||
let flags = 0;
|
||||
if (options.deterministic) flags |= 0x000000800; // SQLITE_DETERMINISTIC
|
||||
if (options.directOnly) flags |= 0x000080000; // SQLITE_DIRECTONLY
|
||||
|
||||
const resultFn = typeof options.result === "function" ? options.result : null;
|
||||
const inverseFn = typeof options.inverse === "function" ? options.inverse : null;
|
||||
|
||||
let startValue = options.start !== undefined ? options.start : null;
|
||||
const startIsFunction = typeof startValue === "function";
|
||||
|
||||
SQL.createAggregate(
|
||||
this.#handle,
|
||||
name,
|
||||
nArgs,
|
||||
flags,
|
||||
step,
|
||||
resultFn,
|
||||
startValue,
|
||||
startIsFunction,
|
||||
inverseFn,
|
||||
!!options.safeIntegers,
|
||||
);
|
||||
return this;
|
||||
}
|
||||
|
||||
fileControl(_cmd, _arg) {
|
||||
const handle = this.#handle;
|
||||
|
||||
|
||||
411
test/js/bun/sqlite/sqlite-user-defined-functions.test.ts
Normal file
411
test/js/bun/sqlite/sqlite-user-defined-functions.test.ts
Normal file
@@ -0,0 +1,411 @@
|
||||
import { Database } from "bun:sqlite";
|
||||
import { describe, expect, test } from "bun:test";
|
||||
|
||||
describe("Database.prototype.function()", () => {
|
||||
test("basic scalar function", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("add2", (a: number, b: number) => a + b);
|
||||
|
||||
const result = db.prepare("SELECT add2(12, 4) as val").get() as any;
|
||||
expect(result.val).toBe(16);
|
||||
});
|
||||
|
||||
test("string concatenation", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("concat", (a: string, b: string) => a + b);
|
||||
|
||||
const result = db.prepare("SELECT concat('foo', 'bar') as val").get() as any;
|
||||
expect(result.val).toBe("foobar");
|
||||
});
|
||||
|
||||
test("returns null for undefined", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("void_fn", { varargs: true }, () => {});
|
||||
|
||||
const result = db.prepare("SELECT void_fn() as val").get() as any;
|
||||
expect(result.val).toBeNull();
|
||||
});
|
||||
|
||||
test("varargs function", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("sum_all", { varargs: true }, (...args: number[]) => {
|
||||
return args.reduce((a, b) => a + b, 0);
|
||||
});
|
||||
|
||||
expect((db.prepare("SELECT sum_all(1, 2, 3) as val").get() as any).val).toBe(6);
|
||||
expect((db.prepare("SELECT sum_all(10) as val").get() as any).val).toBe(10);
|
||||
expect((db.prepare("SELECT sum_all(1, 2, 3, 4, 5) as val").get() as any).val).toBe(15);
|
||||
});
|
||||
|
||||
test("deterministic option", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("double", { deterministic: true }, (x: number) => x * 2);
|
||||
|
||||
const result = db.prepare("SELECT double(21) as val").get() as any;
|
||||
expect(result.val).toBe(42);
|
||||
});
|
||||
|
||||
test("handles null arguments", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("is_null", (x: any) => (x === null ? 1 : 0));
|
||||
|
||||
expect((db.prepare("SELECT is_null(NULL) as val").get() as any).val).toBe(1);
|
||||
expect((db.prepare("SELECT is_null(42) as val").get() as any).val).toBe(0);
|
||||
});
|
||||
|
||||
test("handles blob arguments and return", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("reverse_blob", (x: Uint8Array) => {
|
||||
return new Uint8Array(x).reverse();
|
||||
});
|
||||
|
||||
db.exec("CREATE TABLE blobs (data BLOB)");
|
||||
db.exec("INSERT INTO blobs VALUES (X'010203')");
|
||||
|
||||
const result = db.prepare("SELECT reverse_blob(data) as val FROM blobs").get() as any;
|
||||
expect(result.val).toEqual(new Uint8Array([3, 2, 1]));
|
||||
});
|
||||
|
||||
test("handles float return", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("half", (x: number) => x / 2);
|
||||
|
||||
const result = db.prepare("SELECT half(7) as val").get() as any;
|
||||
expect(result.val).toBe(3.5);
|
||||
});
|
||||
|
||||
test("function error propagation", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("throw_err", () => {
|
||||
throw new Error("custom error");
|
||||
});
|
||||
|
||||
expect(() => db.prepare("SELECT throw_err() as val").get()).toThrow("custom error");
|
||||
});
|
||||
|
||||
test("returns this for chaining", () => {
|
||||
const db = new Database(":memory:");
|
||||
const result = db.function("noop", () => null);
|
||||
expect(result).toBe(db);
|
||||
});
|
||||
|
||||
test("overrides function with same name and arity", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("myfn", (x: number) => x * 2);
|
||||
expect((db.prepare("SELECT myfn(5) as val").get() as any).val).toBe(10);
|
||||
|
||||
db.function("myfn", (x: number) => x * 3);
|
||||
expect((db.prepare("SELECT myfn(5) as val").get() as any).val).toBe(15);
|
||||
});
|
||||
|
||||
test("multiple functions with different arities", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("myfn", (x: number) => x * 10);
|
||||
db.function("myfn", (x: number, y: number) => x + y);
|
||||
|
||||
expect((db.prepare("SELECT myfn(5) as val").get() as any).val).toBe(50);
|
||||
expect((db.prepare("SELECT myfn(5, 3) as val").get() as any).val).toBe(8);
|
||||
});
|
||||
|
||||
test("used in WHERE clause", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("is_even", (x: number) => (x % 2 === 0 ? 1 : 0));
|
||||
|
||||
db.exec("CREATE TABLE nums (val INTEGER)");
|
||||
db.exec("INSERT INTO nums VALUES (1), (2), (3), (4), (5), (6)");
|
||||
|
||||
const results = db.prepare("SELECT val FROM nums WHERE is_even(val)").all() as any[];
|
||||
expect(results.map(r => r.val)).toEqual([2, 4, 6]);
|
||||
});
|
||||
|
||||
test("validation errors", () => {
|
||||
const db = new Database(":memory:");
|
||||
|
||||
expect(() => (db as any).function(123, () => {})).toThrow();
|
||||
expect(() => (db as any).function("test", "not a function")).toThrow();
|
||||
expect(() => (db as any).function("test", {}, "not a function")).toThrow();
|
||||
});
|
||||
|
||||
test("safeIntegers option", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("bigint_check", { safeIntegers: true }, (x: any) => {
|
||||
return typeof x === "bigint" ? 1 : 0;
|
||||
});
|
||||
|
||||
db.exec("CREATE TABLE big (val INTEGER)");
|
||||
db.exec("INSERT INTO big VALUES (42)");
|
||||
|
||||
const result = db.prepare("SELECT bigint_check(val) as val FROM big").get() as any;
|
||||
expect(result.val).toBe(1);
|
||||
});
|
||||
|
||||
test("boolean return value", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("is_positive", (x: number) => x > 0);
|
||||
|
||||
// Booleans become 0/1 in SQLite
|
||||
expect((db.prepare("SELECT is_positive(5) as val").get() as any).val).toBe(1);
|
||||
expect((db.prepare("SELECT is_positive(-1) as val").get() as any).val).toBe(0);
|
||||
});
|
||||
|
||||
test("bigint return value", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.function("big_number", () => 9007199254740993n);
|
||||
|
||||
const db2 = new Database(":memory:");
|
||||
db2.function("big_number", { safeIntegers: true }, () => 9007199254740993n);
|
||||
|
||||
const result = db.prepare("SELECT big_number() as val").get() as any;
|
||||
// The bigint gets stored as int64, then retrieved as a number (may lose precision)
|
||||
expect(typeof result.val).toBe("number");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Database.prototype.aggregate()", () => {
|
||||
function makeDb() {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE expenses (category TEXT, dollars REAL)");
|
||||
db.exec("INSERT INTO expenses VALUES ('food', 10), ('food', 20), ('rent', 50), ('food', 12)");
|
||||
return db;
|
||||
}
|
||||
|
||||
test("basic sum aggregate", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("addAll", {
|
||||
start: 0,
|
||||
step: (total: number, nextValue: number) => total + nextValue,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT addAll(dollars) as val FROM expenses").get() as any;
|
||||
expect(result.val).toBe(92);
|
||||
});
|
||||
|
||||
test("aggregate with result transformation", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("getAverage", {
|
||||
start: () => [] as number[],
|
||||
step: (array: number[], nextValue: number) => {
|
||||
array.push(nextValue);
|
||||
},
|
||||
result: (array: number[]) => array.reduce((a, b) => a + b, 0) / array.length,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT getAverage(dollars) as val FROM expenses").get() as any;
|
||||
expect(result.val).toBe(23);
|
||||
});
|
||||
|
||||
test("aggregate with GROUP BY", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("addAll", {
|
||||
start: 0,
|
||||
step: (total: number, nextValue: number) => total + nextValue,
|
||||
});
|
||||
|
||||
const results = db
|
||||
.prepare("SELECT category, addAll(dollars) as total FROM expenses GROUP BY category ORDER BY category")
|
||||
.all() as any[];
|
||||
expect(results).toEqual([
|
||||
{ category: "food", total: 42 },
|
||||
{ category: "rent", total: 50 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("aggregate with no rows", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE empty (val REAL)");
|
||||
db.aggregate("addAll", {
|
||||
start: 0,
|
||||
step: (total: number, nextValue: number) => total + nextValue,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT addAll(val) as val FROM empty").get() as any;
|
||||
expect(result.val).toBe(0);
|
||||
});
|
||||
|
||||
test("aggregate with no rows and result function", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE empty (val REAL)");
|
||||
db.aggregate("custom_agg", {
|
||||
start: 0,
|
||||
step: (total: number, nextValue: number) => total + nextValue,
|
||||
result: (total: number) => total * 100,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT custom_agg(val) as val FROM empty").get() as any;
|
||||
expect(result.val).toBe(0);
|
||||
});
|
||||
|
||||
test("aggregate with start as function", () => {
|
||||
const db = makeDb();
|
||||
let startCallCount = 0;
|
||||
db.aggregate("collect", {
|
||||
start: () => {
|
||||
startCallCount++;
|
||||
return [] as number[];
|
||||
},
|
||||
step: (arr: number[], val: number) => {
|
||||
arr.push(val);
|
||||
},
|
||||
result: (arr: number[]) => arr.join(","),
|
||||
});
|
||||
|
||||
db.prepare("SELECT category, collect(dollars) as vals FROM expenses GROUP BY category ORDER BY category").all();
|
||||
// start should be called once per group
|
||||
expect(startCallCount).toBe(2);
|
||||
});
|
||||
|
||||
test("step returning undefined doesn't replace accumulator", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("collect", {
|
||||
start: () => [] as number[],
|
||||
step: (arr: number[], val: number) => {
|
||||
arr.push(val);
|
||||
// returning undefined means "don't replace" - array was mutated in place
|
||||
},
|
||||
result: (arr: number[]) => arr.length,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT collect(dollars) as val FROM expenses").get() as any;
|
||||
expect(result.val).toBe(4);
|
||||
});
|
||||
|
||||
test("aggregate with deterministic flag", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("det_sum", {
|
||||
start: 0,
|
||||
step: (total: number, next: number) => total + next,
|
||||
deterministic: true,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT det_sum(dollars) as val FROM expenses").get() as any;
|
||||
expect(result.val).toBe(92);
|
||||
});
|
||||
|
||||
test("aggregate error in step propagation", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("bad_agg", {
|
||||
start: 0,
|
||||
step: (_total: number, _next: number) => {
|
||||
throw new Error("step error");
|
||||
},
|
||||
});
|
||||
|
||||
expect(() => db.prepare("SELECT bad_agg(dollars) as val FROM expenses").get()).toThrow("step error");
|
||||
});
|
||||
|
||||
test("aggregate error in result propagation", () => {
|
||||
const db = makeDb();
|
||||
db.aggregate("bad_result", {
|
||||
start: 0,
|
||||
step: (total: number, next: number) => total + next,
|
||||
result: () => {
|
||||
throw new Error("result error");
|
||||
},
|
||||
});
|
||||
|
||||
expect(() => db.prepare("SELECT bad_result(dollars) as val FROM expenses").get()).toThrow("result error");
|
||||
});
|
||||
|
||||
test("aggregate returns this for chaining", () => {
|
||||
const db = new Database(":memory:");
|
||||
const result = db.aggregate("noop", {
|
||||
start: 0,
|
||||
step: (total: number) => total,
|
||||
});
|
||||
expect(result).toBe(db);
|
||||
});
|
||||
|
||||
test("aggregate validation errors", () => {
|
||||
const db = new Database(":memory:");
|
||||
|
||||
expect(() => (db as any).aggregate(123, {})).toThrow();
|
||||
expect(() => db.aggregate("test", {} as any)).toThrow();
|
||||
expect(() => db.aggregate("test", { step: "not a function" } as any)).toThrow();
|
||||
expect(() => (db as any).aggregate("test", null)).toThrow();
|
||||
});
|
||||
|
||||
test("aggregate with safeIntegers", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE nums (val INTEGER)");
|
||||
db.exec("INSERT INTO nums VALUES (1), (2), (3)");
|
||||
|
||||
db.aggregate("bigint_sum", {
|
||||
start: 0n,
|
||||
step: (total: bigint, next: bigint) => total + next,
|
||||
safeIntegers: true,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT bigint_sum(val) as val FROM nums").get() as any;
|
||||
expect(result.val).toBe(6);
|
||||
});
|
||||
|
||||
test("window function with inverse", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE t (x INTEGER)");
|
||||
db.exec("INSERT INTO t VALUES (1), (2), (3), (4), (5)");
|
||||
|
||||
db.aggregate("win_sum", {
|
||||
start: 0,
|
||||
step: (total: number, next: number) => total + next,
|
||||
inverse: (total: number, dropped: number) => total - dropped,
|
||||
});
|
||||
|
||||
const results = db
|
||||
.prepare("SELECT x, win_sum(x) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as val FROM t")
|
||||
.all() as any[];
|
||||
|
||||
expect(results).toEqual([
|
||||
{ x: 1, val: 3 }, // 1+2
|
||||
{ x: 2, val: 6 }, // 1+2+3
|
||||
{ x: 3, val: 9 }, // 2+3+4
|
||||
{ x: 4, val: 12 }, // 3+4+5
|
||||
{ x: 5, val: 9 }, // 4+5
|
||||
]);
|
||||
});
|
||||
|
||||
test("window function with inverse and result", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE t (x INTEGER)");
|
||||
db.exec("INSERT INTO t VALUES (1), (2), (3), (4), (5)");
|
||||
|
||||
db.aggregate("win_avg", {
|
||||
start: () => ({ sum: 0, count: 0 }),
|
||||
step: (acc: { sum: number; count: number }, next: number) => ({
|
||||
sum: acc.sum + next,
|
||||
count: acc.count + 1,
|
||||
}),
|
||||
inverse: (acc: { sum: number; count: number }, dropped: number) => ({
|
||||
sum: acc.sum - dropped,
|
||||
count: acc.count - 1,
|
||||
}),
|
||||
result: (acc: { sum: number; count: number }) => (acc.count > 0 ? acc.sum / acc.count : 0),
|
||||
});
|
||||
|
||||
const results = db
|
||||
.prepare("SELECT x, win_avg(x) OVER (ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as val FROM t")
|
||||
.all() as any[];
|
||||
|
||||
expect(results).toEqual([
|
||||
{ x: 1, val: 1.5 }, // avg(1,2)
|
||||
{ x: 2, val: 2 }, // avg(1,2,3)
|
||||
{ x: 3, val: 3 }, // avg(2,3,4)
|
||||
{ x: 4, val: 4 }, // avg(3,4,5)
|
||||
{ x: 5, val: 4.5 }, // avg(4,5)
|
||||
]);
|
||||
});
|
||||
|
||||
test("varargs aggregate", () => {
|
||||
const db = new Database(":memory:");
|
||||
db.exec("CREATE TABLE t (a INTEGER, b INTEGER)");
|
||||
db.exec("INSERT INTO t VALUES (1, 10), (2, 20), (3, 30)");
|
||||
|
||||
db.aggregate("sum_product", {
|
||||
start: 0,
|
||||
step: (total: number, a: number, b: number) => total + a * b,
|
||||
varargs: true,
|
||||
});
|
||||
|
||||
const result = db.prepare("SELECT sum_product(a, b) as val FROM t").get() as any;
|
||||
expect(result.val).toBe(140); // 1*10 + 2*20 + 3*30
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user