From 17b503b389d13ce28355f852b45ae638ce90a166 Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Fri, 26 Sep 2025 03:06:18 -0700 Subject: [PATCH] Redis PUB/SUB 2.0 (#22568) ### What does this PR do? **This PR is created because [the previous PR I opened](https://github.com/oven-sh/bun/pull/21728) had some concerning issues.** Thanks @Jarred-Sumner for the help. The goal of this PR is to introduce PUB/SUB functionality to the built-in Redis client. Based on the fact that the current Redis API does not appear to have compatibility with `io-redis` or `redis-node`, I've decided to do away with existing APIs and API compatibility with these existing libraries. I have decided to base my implementation on the [`redis-node` pub/sub API](https://github.com/redis/node-redis/blob/master/docs/pub-sub.md). #### Random Things That Happened - [x] Refactored the build scripts so that `valgrind` can be disabled. - [x] Added a `numeric` namespace in `harness.ts` with useful mathematical libraries. - [x] Added a mechanism in `cppbind.ts` to disable static assertions (specifically to allow `check_slow` even when returning a `JSValue`). Implemented via `// NOLINT[NEXTLINE]?\(.*\)` macros. - [x] Fixed inconsistencies in error handling of `JSMap`. ### How did you verify your code works? I've written a set of unit tests to hopefully catch the major use-cases of this feature. They all appear to pass. #### Future Improvements I would have a lot more confidence in our Redis implementation if we tested it with a test suite running over a network which emulates a high network failure rate. There are large amounts of edge cases that are worthwhile to grab, but I think we can roll that out in a future PR. ### Future Tasks - [ ] Tests over flaky network - [ ] Use the custom private members over `_`. --------- Co-authored-by: Jarred Sumner --- build.zig | 12 +- cmake/Options.cmake | 6 +- cmake/targets/BuildBun.cmake | 31 +- cmake/tools/SetupZig.cmake | 1 + docs/api/redis.md | 99 +++- docs/nav.ts | 2 +- packages/bun-types/redis.d.ts | 262 +++++++-- src/bun.js/api/BunObject.zig | 8 +- src/bun.js/api/valkey.classes.ts | 4 +- src/bun.js/bindings/ErrorCode.ts | 31 +- src/bun.js/bindings/JSGlobalObject.zig | 7 + src/bun.js/bindings/JSMap.zig | 43 +- src/bun.js/bindings/JSPromise.zig | 11 + src/bun.js/bindings/JSRef.zig | 9 + src/bun.js/bindings/bindings.cpp | 53 +- src/bun.js/bindings/headers.h | 7 +- src/bun.js/event_loop.zig | 13 + src/deps/uws/us_socket_t.zig | 18 +- src/memory.zig | 30 + src/string/StringBuilder.zig | 9 + src/valkey/ValkeyCommand.zig | 3 +- src/valkey/js_valkey.zig | 611 ++++++++++++++++++-- src/valkey/js_valkey_functions.zig | 474 +++++++++++++-- src/valkey/valkey.zig | 184 +++++- src/valkey/valkey_protocol.zig | 12 + test/_util/numeric.ts | 192 ++++++ test/harness.ts | 13 +- test/integration/bun-types/fixture/redis.ts | 31 + test/js/valkey/test-utils.ts | 108 +++- test/js/valkey/valkey.failing-subscriber.ts | 48 ++ test/js/valkey/valkey.test.ts | 577 +++++++++++++++++- 31 files changed, 2643 insertions(+), 266 deletions(-) create mode 100644 test/_util/numeric.ts create mode 100644 test/integration/bun-types/fixture/redis.ts create mode 100644 test/js/valkey/valkey.failing-subscriber.ts diff --git a/build.zig b/build.zig index e368cac8d9..239df85c53 100644 --- a/build.zig +++ b/build.zig @@ -48,6 +48,7 @@ const BunBuildOptions = struct { /// enable debug logs in release builds enable_logs: bool = false, enable_asan: bool, + enable_valgrind: bool, tracy_callstack_depth: u16, reported_nodejs_version: Version, /// To make iterating on some '@embedFile's faster, we load them at runtime @@ -94,6 +95,7 @@ const BunBuildOptions = struct { opts.addOption(bool, "baseline", this.isBaseline()); opts.addOption(bool, "enable_logs", this.enable_logs); opts.addOption(bool, "enable_asan", this.enable_asan); + opts.addOption(bool, "enable_valgrind", this.enable_valgrind); opts.addOption([]const u8, "reported_nodejs_version", b.fmt("{}", .{this.reported_nodejs_version})); opts.addOption(bool, "zig_self_hosted_backend", this.no_llvm); opts.addOption(bool, "override_no_export_cpp_apis", this.override_no_export_cpp_apis); @@ -213,26 +215,21 @@ pub fn build(b: *Build) !void { var build_options = BunBuildOptions{ .target = target, .optimize = optimize, - .os = os, .arch = arch, - .codegen_path = codegen_path, .codegen_embed = codegen_embed, .no_llvm = no_llvm, .override_no_export_cpp_apis = override_no_export_cpp_apis, - .version = try Version.parse(bun_version), .canary_revision = canary: { const rev = b.option(u32, "canary", "Treat this as a canary build") orelse 0; break :canary if (rev == 0) null else rev; }, - .reported_nodejs_version = try Version.parse( b.option([]const u8, "reported_nodejs_version", "Reported Node.js version") orelse "0.0.0-unset", ), - .sha = sha: { const sha_buildoption = b.option([]const u8, "sha", "Force the git sha"); const sha_github = b.graph.env_map.get("GITHUB_SHA"); @@ -268,10 +265,10 @@ pub fn build(b: *Build) !void { break :sha sha; }, - .tracy_callstack_depth = b.option(u16, "tracy_callstack_depth", "") orelse 10, .enable_logs = b.option(bool, "enable_logs", "Enable logs in release") orelse false, .enable_asan = b.option(bool, "enable_asan", "Enable asan") orelse false, + .enable_valgrind = b.option(bool, "enable_valgrind", "Enable valgrind") orelse false, }; // zig build obj @@ -500,6 +497,7 @@ fn addMultiCheck( .codegen_path = root_build_options.codegen_path, .no_llvm = root_build_options.no_llvm, .enable_asan = root_build_options.enable_asan, + .enable_valgrind = root_build_options.enable_valgrind, .override_no_export_cpp_apis = root_build_options.override_no_export_cpp_apis, }; @@ -636,7 +634,7 @@ fn configureObj(b: *Build, opts: *BunBuildOptions, obj: *Compile) void { obj.link_function_sections = true; obj.link_data_sections = true; - if (opts.optimize == .Debug) { + if (opts.optimize == .Debug and opts.enable_valgrind) { obj.root_module.valgrind = true; } } diff --git a/cmake/Options.cmake b/cmake/Options.cmake index 3dd5220cc5..1e9b664321 100644 --- a/cmake/Options.cmake +++ b/cmake/Options.cmake @@ -60,10 +60,10 @@ endif() # Windows Code Signing Option if(WIN32) optionx(ENABLE_WINDOWS_CODESIGNING BOOL "Enable Windows code signing with DigiCert KeyLocker" DEFAULT OFF) - + if(ENABLE_WINDOWS_CODESIGNING) message(STATUS "Windows code signing: ENABLED") - + # Check for required environment variables if(NOT DEFINED ENV{SM_API_KEY}) message(WARNING "SM_API_KEY not set - code signing may fail") @@ -114,8 +114,10 @@ endif() if(DEBUG AND ((APPLE AND ARCH STREQUAL "aarch64") OR LINUX)) set(DEFAULT_ASAN ON) + set(DEFAULT_VALGRIND OFF) else() set(DEFAULT_ASAN OFF) + set(DEFAULT_VALGRIND OFF) endif() optionx(ENABLE_ASAN BOOL "If ASAN support should be enabled" DEFAULT ${DEFAULT_ASAN}) diff --git a/cmake/targets/BuildBun.cmake b/cmake/targets/BuildBun.cmake index 5e8795c1bd..888c360739 100644 --- a/cmake/targets/BuildBun.cmake +++ b/cmake/targets/BuildBun.cmake @@ -2,6 +2,8 @@ include(PathUtils) if(DEBUG) set(bun bun-debug) +elseif(ENABLE_ASAN AND ENABLE_VALGRIND) + set(bun bun-asan-valgrind) elseif(ENABLE_ASAN) set(bun bun-asan) elseif(ENABLE_VALGRIND) @@ -619,6 +621,7 @@ register_command( -Dcpu=${ZIG_CPU} -Denable_logs=$,true,false> -Denable_asan=$,true,false> + -Denable_valgrind=$,true,false> -Dversion=${VERSION} -Dreported_nodejs_version=${NODEJS_VERSION} -Dcanary=${CANARY_REVISION} @@ -886,12 +889,8 @@ if(NOT WIN32) endif() if(ENABLE_ASAN) - target_compile_options(${bun} PUBLIC - -fsanitize=address - ) - target_link_libraries(${bun} PUBLIC - -fsanitize=address - ) + target_compile_options(${bun} PUBLIC -fsanitize=address) + target_link_libraries(${bun} PUBLIC -fsanitize=address) endif() target_compile_options(${bun} PUBLIC @@ -930,12 +929,8 @@ if(NOT WIN32) ) if(ENABLE_ASAN) - target_compile_options(${bun} PUBLIC - -fsanitize=address - ) - target_link_libraries(${bun} PUBLIC - -fsanitize=address - ) + target_compile_options(${bun} PUBLIC -fsanitize=address) + target_link_libraries(${bun} PUBLIC -fsanitize=address) endif() endif() else() @@ -1063,7 +1058,7 @@ if(LINUX) ) endif() - if (NOT DEBUG AND NOT ENABLE_ASAN) + if (NOT DEBUG AND NOT ENABLE_ASAN AND NOT ENABLE_VALGRIND) target_link_options(${bun} PUBLIC -Wl,-icf=safe ) @@ -1366,12 +1361,20 @@ if(NOT BUN_CPP_ONLY) if(ENABLE_BASELINE) set(bunTriplet ${bunTriplet}-baseline) endif() - if(ENABLE_ASAN) + + if (ENABLE_ASAN AND ENABLE_VALGRIND) + set(bunTriplet ${bunTriplet}-asan-valgrind) + set(bunPath ${bunTriplet}) + elseif (ENABLE_VALGRIND) + set(bunTriplet ${bunTriplet}-valgrind) + set(bunPath ${bunTriplet}) + elseif(ENABLE_ASAN) set(bunTriplet ${bunTriplet}-asan) set(bunPath ${bunTriplet}) else() string(REPLACE bun ${bunTriplet} bunPath ${bun}) endif() + set(bunFiles ${bunExe} features.json) if(WIN32) list(APPEND bunFiles ${bun}.pdb) diff --git a/cmake/tools/SetupZig.cmake b/cmake/tools/SetupZig.cmake index cbcb50a867..5143353ca0 100644 --- a/cmake/tools/SetupZig.cmake +++ b/cmake/tools/SetupZig.cmake @@ -90,6 +90,7 @@ register_command( -DZIG_PATH=${ZIG_PATH} -DZIG_COMMIT=${ZIG_COMMIT} -DENABLE_ASAN=${ENABLE_ASAN} + -DENABLE_VALGRIND=${ENABLE_VALGRIND} -DZIG_COMPILER_SAFE=${ZIG_COMPILER_SAFE} -P ${CWD}/cmake/scripts/DownloadZig.cmake SOURCES diff --git a/docs/api/redis.md b/docs/api/redis.md index 929babb318..3c663cb0a0 100644 --- a/docs/api/redis.md +++ b/docs/api/redis.md @@ -161,6 +161,102 @@ const randomTag = await redis.srandmember("tags"); const poppedTag = await redis.spop("tags"); ``` +## Pub/Sub + +Bun provides native bindings for the [Redis +Pub/Sub](https://redis.io/docs/latest/develop/pubsub/) protocol. **New in Bun +1.2.23** + +{% callout %} +**🚧** — The Redis Pub/Sub feature is experimental. Although we expect it to be +stable, we're currently actively looking for feedback and areas for improvement. +{% /callout %} + +### Basic Usage + +To get started publishing messages, you can set up a publisher in +`publisher.ts`: + +```typescript#publisher.ts +import { RedisClient } from "bun"; + +const writer = new RedisClient("redis://localhost:6739"); +await writer.connect(); + +writer.publish("general", "Hello everyone!"); + +writer.close(); +``` + +In another file, create the subscriber in `subscriber.ts`: + +```typescript#subscriber.ts +import { RedisClient } from "bun"; + +const listener = new RedisClient("redis://localhost:6739"); +await listener.connect(); + +await listener.subscribe("general", (message, channel) => { + console.log(`Received: ${message}`); +}); +``` + +In one shell, run your subscriber: + +```bash +bun run subscriber.ts +``` + +and, in another, run your publisher: + +```bash +bun run publisher.ts +``` + +{% callout %} +**Note:** The subscription mode takes over the `RedisClient` connection. A +client with subscriptions can only call `RedisClient.prototype.subscribe()`. In +other words, applications which need to message Redis need a separate +connection, acquirable through `.duplicate()`: + +```typescript +import { RedisClient } from "bun"; + +const redis = new RedisClient("redis://localhost:6379"); +await redis.connect(); +const subscriber = await redis.duplicate(); + +await subscriber.subscribe("foo", () => {}); +await redis.set("bar", "baz"); +``` + +{% /callout %} + +### Publishing + +Publishing messages is done through the `publish()` method: + +```typescript +await client.publish(channelName, message); +``` + +### Subscriptions + +The Bun `RedisClient` allows you to subscribe to channels through the +`.subscribe()` method: + +```typescript +await client.subscribe(channel, (message, channel) => {}); +``` + +You can unsubscribe through the `.unsubscribe()` method: + +```typescript +await client.unsubscribe(); // Unsubscribe from all channels. +await client.unsubscribe(channel); // Unsubscribe a particular channel. +await client.unsubscribe(channel, listener); // Unsubscribe a particular listener. +``` + ## Advanced Usage ### Command Execution and Pipelining @@ -482,9 +578,10 @@ When connecting to Redis servers using older versions that don't support RESP3, Current limitations of the Redis client we are planning to address in future versions: -- [ ] No dedicated API for pub/sub functionality (though you can use the raw command API) - [ ] Transactions (MULTI/EXEC) must be done through raw commands for now - [ ] Streams are supported but without dedicated methods +- [ ] Pub/Sub does not currently support binary data, nor pattern-based + subscriptions. Unsupported features: diff --git a/docs/nav.ts b/docs/nav.ts index 0e7c91bed0..3d776961a0 100644 --- a/docs/nav.ts +++ b/docs/nav.ts @@ -359,7 +359,7 @@ export default { page("api/file-io", "File I/O", { description: `Read and write files fast with Bun's heavily optimized file system API.`, }), // "`Bun.write`"), - page("api/redis", "Redis client", { + page("api/redis", "Redis Client", { description: `Bun provides a fast, native Redis client with automatic command pipelining for better performance.`, }), page("api/import-meta", "import.meta", { diff --git a/packages/bun-types/redis.d.ts b/packages/bun-types/redis.d.ts index 9cf228c92a..414b01288b 100644 --- a/packages/bun-types/redis.d.ts +++ b/packages/bun-types/redis.d.ts @@ -52,21 +52,25 @@ declare module "bun" { export namespace RedisClient { type KeyLike = string | ArrayBufferView | Blob; + type StringPubSubListener = (message: string, channel: string) => void; + + // Buffer subscriptions are not yet implemented + // type BufferPubSubListener = (message: Uint8Array, channel: string) => void; } export class RedisClient { /** * Creates a new Redis client - * @param url URL to connect to, defaults to process.env.VALKEY_URL, process.env.REDIS_URL, or "valkey://localhost:6379" + * + * @param url URL to connect to, defaults to `process.env.VALKEY_URL`, + * `process.env.REDIS_URL`, or `"valkey://localhost:6379"` * @param options Additional options * * @example * ```ts - * const valkey = new RedisClient(); - * - * await valkey.set("hello", "world"); - * - * console.log(await valkey.get("hello")); + * const redis = new RedisClient(); + * await redis.set("hello", "world"); + * console.log(await redis.get("hello")); * ``` */ constructor(url?: string, options?: RedisOptions); @@ -88,12 +92,14 @@ declare module "bun" { /** * Callback fired when the client disconnects from the Redis server + * * @param error The error that caused the disconnection */ onclose: ((this: RedisClient, error: Error) => void) | null; /** * Connect to the Redis server + * * @returns A promise that resolves when connected */ connect(): Promise; @@ -152,10 +158,12 @@ declare module "bun" { set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, px: "PX", milliseconds: number): Promise<"OK">; /** - * Set key to hold the string value with expiration at a specific Unix timestamp + * Set key to hold the string value with expiration at a specific Unix + * timestamp * @param key The key to set * @param value The value to set - * @param exat Set the specified Unix time at which the key will expire, in seconds + * @param exat Set the specified Unix time at which the key will expire, in + * seconds * @returns Promise that resolves with "OK" on success */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, exat: "EXAT", timestampSeconds: number): Promise<"OK">; @@ -179,7 +187,8 @@ declare module "bun" { * @param key The key to set * @param value The value to set * @param nx Only set the key if it does not already exist - * @returns Promise that resolves with "OK" on success, or null if the key already exists + * @returns Promise that resolves with "OK" on success, or null if the key + * already exists */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, nx: "NX"): Promise<"OK" | null>; @@ -188,7 +197,8 @@ declare module "bun" { * @param key The key to set * @param value The value to set * @param xx Only set the key if it already exists - * @returns Promise that resolves with "OK" on success, or null if the key does not exist + * @returns Promise that resolves with "OK" on success, or null if the key + * does not exist */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, xx: "XX"): Promise<"OK" | null>; @@ -196,8 +206,10 @@ declare module "bun" { * Set key to hold the string value and return the old value * @param key The key to set * @param value The value to set - * @param get Return the old string stored at key, or null if key did not exist - * @returns Promise that resolves with the old value, or null if key did not exist + * @param get Return the old string stored at key, or null if key did not + * exist + * @returns Promise that resolves with the old value, or null if key did not + * exist */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, get: "GET"): Promise; @@ -243,7 +255,8 @@ declare module "bun" { /** * Determine if a key exists * @param key The key to check - * @returns Promise that resolves with true if the key exists, false otherwise + * @returns Promise that resolves with true if the key exists, false + * otherwise */ exists(key: RedisClient.KeyLike): Promise; @@ -258,7 +271,8 @@ declare module "bun" { /** * Get the time to live for a key in seconds * @param key The key to get the TTL for - * @returns Promise that resolves with the TTL, -1 if no expiry, or -2 if key doesn't exist + * @returns Promise that resolves with the TTL, -1 if no expiry, or -2 if + * key doesn't exist */ ttl(key: RedisClient.KeyLike): Promise; @@ -290,7 +304,8 @@ declare module "bun" { * Check if a value is a member of a set * @param key The set key * @param member The member to check - * @returns Promise that resolves with true if the member exists, false otherwise + * @returns Promise that resolves with true if the member exists, false + * otherwise */ sismember(key: RedisClient.KeyLike, member: string): Promise; @@ -298,7 +313,8 @@ declare module "bun" { * Add a member to a set * @param key The set key * @param member The member to add - * @returns Promise that resolves with 1 if the member was added, 0 if it already existed + * @returns Promise that resolves with 1 if the member was added, 0 if it + * already existed */ sadd(key: RedisClient.KeyLike, member: string): Promise; @@ -306,7 +322,8 @@ declare module "bun" { * Remove a member from a set * @param key The set key * @param member The member to remove - * @returns Promise that resolves with 1 if the member was removed, 0 if it didn't exist + * @returns Promise that resolves with 1 if the member was removed, 0 if it + * didn't exist */ srem(key: RedisClient.KeyLike, member: string): Promise; @@ -320,14 +337,16 @@ declare module "bun" { /** * Get a random member from a set * @param key The set key - * @returns Promise that resolves with a random member, or null if the set is empty + * @returns Promise that resolves with a random member, or null if the set + * is empty */ srandmember(key: RedisClient.KeyLike): Promise; /** * Remove and return a random member from a set * @param key The set key - * @returns Promise that resolves with the removed member, or null if the set is empty + * @returns Promise that resolves with the removed member, or null if the + * set is empty */ spop(key: RedisClient.KeyLike): Promise; @@ -394,28 +413,32 @@ declare module "bun" { /** * Remove and get the first element in a list * @param key The list key - * @returns Promise that resolves with the first element, or null if the list is empty + * @returns Promise that resolves with the first element, or null if the + * list is empty */ lpop(key: RedisClient.KeyLike): Promise; /** * Remove the expiration from a key * @param key The key to persist - * @returns Promise that resolves with 1 if the timeout was removed, 0 if the key doesn't exist or has no timeout + * @returns Promise that resolves with 1 if the timeout was removed, 0 if + * the key doesn't exist or has no timeout */ persist(key: RedisClient.KeyLike): Promise; /** * Get the expiration time of a key as a UNIX timestamp in milliseconds * @param key The key to check - * @returns Promise that resolves with the timestamp, or -1 if the key has no expiration, or -2 if the key doesn't exist + * @returns Promise that resolves with the timestamp, or -1 if the key has + * no expiration, or -2 if the key doesn't exist */ pexpiretime(key: RedisClient.KeyLike): Promise; /** * Get the time to live for a key in milliseconds * @param key The key to check - * @returns Promise that resolves with the TTL in milliseconds, or -1 if the key has no expiration, or -2 if the key doesn't exist + * @returns Promise that resolves with the TTL in milliseconds, or -1 if the + * key has no expiration, or -2 if the key doesn't exist */ pttl(key: RedisClient.KeyLike): Promise; @@ -429,42 +452,48 @@ declare module "bun" { /** * Get the number of members in a set * @param key The set key - * @returns Promise that resolves with the cardinality (number of elements) of the set + * @returns Promise that resolves with the cardinality (number of elements) + * of the set */ scard(key: RedisClient.KeyLike): Promise; /** * Get the length of the value stored in a key * @param key The key to check - * @returns Promise that resolves with the length of the string value, or 0 if the key doesn't exist + * @returns Promise that resolves with the length of the string value, or 0 + * if the key doesn't exist */ strlen(key: RedisClient.KeyLike): Promise; /** * Get the number of members in a sorted set * @param key The sorted set key - * @returns Promise that resolves with the cardinality (number of elements) of the sorted set + * @returns Promise that resolves with the cardinality (number of elements) + * of the sorted set */ zcard(key: RedisClient.KeyLike): Promise; /** * Remove and return members with the highest scores in a sorted set * @param key The sorted set key - * @returns Promise that resolves with the removed member and its score, or null if the set is empty + * @returns Promise that resolves with the removed member and its score, or + * null if the set is empty */ zpopmax(key: RedisClient.KeyLike): Promise; /** * Remove and return members with the lowest scores in a sorted set * @param key The sorted set key - * @returns Promise that resolves with the removed member and its score, or null if the set is empty + * @returns Promise that resolves with the removed member and its score, or + * null if the set is empty */ zpopmin(key: RedisClient.KeyLike): Promise; /** * Get one or multiple random members from a sorted set * @param key The sorted set key - * @returns Promise that resolves with a random member, or null if the set is empty + * @returns Promise that resolves with a random member, or null if the set + * is empty */ zrandmember(key: RedisClient.KeyLike): Promise; @@ -472,7 +501,8 @@ declare module "bun" { * Append a value to a key * @param key The key to append to * @param value The value to append - * @returns Promise that resolves with the length of the string after the append operation + * @returns Promise that resolves with the length of the string after the + * append operation */ append(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -480,7 +510,8 @@ declare module "bun" { * Set the value of a key and return its old value * @param key The key to set * @param value The value to set - * @returns Promise that resolves with the old value, or null if the key didn't exist + * @returns Promise that resolves with the old value, or null if the key + * didn't exist */ getset(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -488,7 +519,8 @@ declare module "bun" { * Prepend one or multiple values to a list * @param key The list key * @param value The value to prepend - * @returns Promise that resolves with the length of the list after the push operation + * @returns Promise that resolves with the length of the list after the push + * operation */ lpush(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -496,7 +528,8 @@ declare module "bun" { * Prepend a value to a list, only if the list exists * @param key The list key * @param value The value to prepend - * @returns Promise that resolves with the length of the list after the push operation, or 0 if the list doesn't exist + * @returns Promise that resolves with the length of the list after the push + * operation, or 0 if the list doesn't exist */ lpushx(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -504,7 +537,8 @@ declare module "bun" { * Add one or more members to a HyperLogLog * @param key The HyperLogLog key * @param element The element to add - * @returns Promise that resolves with 1 if the HyperLogLog was altered, 0 otherwise + * @returns Promise that resolves with 1 if the HyperLogLog was altered, 0 + * otherwise */ pfadd(key: RedisClient.KeyLike, element: string): Promise; @@ -512,7 +546,8 @@ declare module "bun" { * Append one or multiple values to a list * @param key The list key * @param value The value to append - * @returns Promise that resolves with the length of the list after the push operation + * @returns Promise that resolves with the length of the list after the push + * operation */ rpush(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -520,7 +555,8 @@ declare module "bun" { * Append a value to a list, only if the list exists * @param key The list key * @param value The value to append - * @returns Promise that resolves with the length of the list after the push operation, or 0 if the list doesn't exist + * @returns Promise that resolves with the length of the list after the push + * operation, or 0 if the list doesn't exist */ rpushx(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -528,7 +564,8 @@ declare module "bun" { * Set the value of a key, only if the key does not exist * @param key The key to set * @param value The value to set - * @returns Promise that resolves with 1 if the key was set, 0 if the key was not set + * @returns Promise that resolves with 1 if the key was set, 0 if the key + * was not set */ setnx(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -536,14 +573,16 @@ declare module "bun" { * Get the score associated with the given member in a sorted set * @param key The sorted set key * @param member The member to get the score for - * @returns Promise that resolves with the score of the member as a string, or null if the member or key doesn't exist + * @returns Promise that resolves with the score of the member as a string, + * or null if the member or key doesn't exist */ zscore(key: RedisClient.KeyLike, member: string): Promise; /** * Get the values of all specified keys * @param keys The keys to get - * @returns Promise that resolves with an array of values, with null for keys that don't exist + * @returns Promise that resolves with an array of values, with null for + * keys that don't exist */ mget(...keys: RedisClient.KeyLike[]): Promise<(string | null)[]>; @@ -557,37 +596,46 @@ declare module "bun" { /** * Return a serialized version of the value stored at the specified key * @param key The key to dump - * @returns Promise that resolves with the serialized value, or null if the key doesn't exist + * @returns Promise that resolves with the serialized value, or null if the + * key doesn't exist */ dump(key: RedisClient.KeyLike): Promise; /** * Get the expiration time of a key as a UNIX timestamp in seconds + * * @param key The key to check - * @returns Promise that resolves with the timestamp, or -1 if the key has no expiration, or -2 if the key doesn't exist + * @returns Promise that resolves with the timestamp, or -1 if the key has + * no expiration, or -2 if the key doesn't exist */ expiretime(key: RedisClient.KeyLike): Promise; /** * Get the value of a key and delete the key + * * @param key The key to get and delete - * @returns Promise that resolves with the value of the key, or null if the key doesn't exist + * @returns Promise that resolves with the value of the key, or null if the + * key doesn't exist */ getdel(key: RedisClient.KeyLike): Promise; /** * Get the value of a key and optionally set its expiration + * * @param key The key to get - * @returns Promise that resolves with the value of the key, or null if the key doesn't exist + * @returns Promise that resolves with the value of the key, or null if the + * key doesn't exist */ getex(key: RedisClient.KeyLike): Promise; /** * Get the value of a key and set its expiration in seconds + * * @param key The key to get * @param ex Set the specified expire time, in seconds * @param seconds The number of seconds until expiration - * @returns Promise that resolves with the value of the key, or null if the key doesn't exist + * @returns Promise that resolves with the value of the key, or null if the + * key doesn't exist */ getex(key: RedisClient.KeyLike, ex: "EX", seconds: number): Promise; @@ -602,6 +650,7 @@ declare module "bun" { /** * Get the value of a key and set its expiration at a specific Unix timestamp in seconds + * * @param key The key to get * @param exat Set the specified Unix time at which the key will expire, in seconds * @param timestampSeconds The Unix timestamp in seconds @@ -611,6 +660,7 @@ declare module "bun" { /** * Get the value of a key and set its expiration at a specific Unix timestamp in milliseconds + * * @param key The key to get * @param pxat Set the specified Unix time at which the key will expire, in milliseconds * @param timestampMilliseconds The Unix timestamp in milliseconds @@ -620,6 +670,7 @@ declare module "bun" { /** * Get the value of a key and remove its expiration + * * @param key The key to get * @param persist Remove the expiration from the key * @returns Promise that resolves with the value of the key, or null if the key doesn't exist @@ -634,10 +685,133 @@ declare module "bun" { /** * Ping the server with a message + * * @param message The message to send to the server * @returns Promise that resolves with the message if the server is reachable, or throws an error if the server is not reachable */ ping(message: RedisClient.KeyLike): Promise; + + /** + * Publish a message to a Redis channel. + * + * @param channel The channel to publish to. + * @param message The message to publish. + * + * @returns The number of clients that received the message. Note that in a + * cluster this returns the total number of clients in the same node. + */ + publish(channel: string, message: string): Promise; + + /** + * Subscribe to a Redis channel. + * + * Subscribing disables automatic pipelining, so all commands will be + * received immediately. + * + * Subscribing moves the channel to a dedicated subscription state which + * prevents most other commands from being executed until unsubscribed. Only + * {@link ping `.ping()`}, {@link subscribe `.subscribe()`}, and + * {@link unsubscribe `.unsubscribe()`} are legal to invoke in a subscribed + * upon channel. + * + * @param channel The channel to subscribe to. + * @param listener The listener to call when a message is received on the + * channel. The listener will receive the message as the first argument and + * the channel as the second argument. + * + * @example + * ```ts + * await client.subscribe("my-channel", (message, channel) => { + * console.log(`Received message on ${channel}: ${message}`); + * }); + * ``` + */ + subscribe(channel: string, listener: RedisClient.StringPubSubListener): Promise; + + /** + * Subscribe to multiple Redis channels. + * + * Subscribing disables automatic pipelining, so all commands will be + * received immediately. + * + * Subscribing moves the channels to a dedicated subscription state in which + * only a limited set of commands can be executed. + * + * @param channels An array of channels to subscribe to. + * @param listener The listener to call when a message is received on any of + * the subscribed channels. The listener will receive the message as the + * first argument and the channel as the second argument. + */ + subscribe(channels: string[], listener: RedisClient.StringPubSubListener): Promise; + + /** + * Unsubscribe from a singular Redis channel. + * + * @param channel The channel to unsubscribe from. + * + * If there are no more channels subscribed to, the client automatically + * re-enables pipelining if it was previously enabled. + * + * Unsubscribing moves the channel back to a normal state out of the + * subscription state if all channels have been unsubscribed from. For + * further details on the subscription state, see + * {@link subscribe `.subscribe()`}. + */ + unsubscribe(channel: string): Promise; + + /** + * Remove a listener from a given Redis channel. + * + * If there are no more channels subscribed to, the client automatically + * re-enables pipelining if it was previously enabled. + * + * Unsubscribing moves the channel back to a normal state out of the + * subscription state if all channels have been unsubscribed from. For + * further details on the subscription state, see + * {@link subscribe `.subscribe()`}. + * + * @param channel The channel to unsubscribe from. + * @param listener The listener to remove. This is tested against + * referential equality so you must pass the exact same listener instance as + * when subscribing. + */ + unsubscribe(channel: string, listener: RedisClient.StringPubSubListener): Promise; + + /** + * Unsubscribe from all registered Redis channels. + * + * The client will automatically re-enable pipelining if it was previously + * enabled. + * + * Unsubscribing moves the channel back to a normal state out of the + * subscription state if all channels have been unsubscribed from. For + * further details on the subscription state, see + * {@link subscribe `.subscribe()`}. + */ + unsubscribe(): Promise; + + /** + * Unsubscribe from multiple Redis channels. + * + * @param channels An array of channels to unsubscribe from. + * + * If there are no more channels subscribed to, the client automatically + * re-enables pipelining if it was previously enabled. + * + * Unsubscribing moves the channel back to a normal state out of the + * subscription state if all channels have been unsubscribed from. For + * further details on the subscription state, see + * {@link subscribe `.subscribe()`}. + */ + unsubscribe(channels: string[]): Promise; + + /** + * @brief Create a new RedisClient instance with the same configuration as + * the current instance. + * + * This will open up a new connection to the Redis server. + */ + duplicate(): Promise; } /** diff --git a/src/bun.js/api/BunObject.zig b/src/bun.js/api/BunObject.zig index 8cd0dfe71e..a5af653677 100644 --- a/src/bun.js/api/BunObject.zig +++ b/src/bun.js/api/BunObject.zig @@ -1294,7 +1294,7 @@ pub fn setTLSDefaultCiphers(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject, c } pub fn getValkeyDefaultClient(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) jsc.JSValue { - const valkey = jsc.API.Valkey.create(globalThis, &.{.js_undefined}) catch |err| { + var valkey = jsc.API.Valkey.createNoJs(globalThis, &.{.js_undefined}) catch |err| { if (err != error.JSError) { _ = globalThis.throwError(err, "Failed to create Redis client") catch {}; return .zero; @@ -1302,7 +1302,11 @@ pub fn getValkeyDefaultClient(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) return .zero; }; - return valkey.toJS(globalThis); + const as_js = valkey.toJS(globalThis); + + valkey.this_value = jsc.JSRef.initWeak(as_js); + + return as_js; } pub fn getValkeyClientConstructor(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) jsc.JSValue { diff --git a/src/bun.js/api/valkey.classes.ts b/src/bun.js/api/valkey.classes.ts index abcfeaa395..3828db06d1 100644 --- a/src/bun.js/api/valkey.classes.ts +++ b/src/bun.js/api/valkey.classes.ts @@ -4,6 +4,7 @@ export default [ define({ name: "RedisClient", construct: true, + constructNeedsThis: true, call: false, finalize: true, configurable: false, @@ -226,11 +227,12 @@ export default [ zrank: { fn: "zrank" }, zrevrank: { fn: "zrevrank" }, subscribe: { fn: "subscribe" }, + duplicate: { fn: "duplicate" }, psubscribe: { fn: "psubscribe" }, unsubscribe: { fn: "unsubscribe" }, punsubscribe: { fn: "punsubscribe" }, pubsub: { fn: "pubsub" }, }, - values: ["onconnect", "onclose", "connectionPromise", "hello"], + values: ["onconnect", "onclose", "connectionPromise", "hello", "subscriptionCallbackMap"], }), ]; diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index 37a9ce660b..ecd0bb332c 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -276,24 +276,25 @@ const errors: ErrorCodeMapping = [ ["ERR_OSSL_EVP_INVALID_DIGEST", Error], ["ERR_KEY_GENERATION_JOB_FAILED", Error], ["ERR_MISSING_OPTION", TypeError], - ["ERR_REDIS_CONNECTION_CLOSED", Error, "RedisError"], - ["ERR_REDIS_INVALID_RESPONSE", Error, "RedisError"], - ["ERR_REDIS_INVALID_BULK_STRING", Error, "RedisError"], - ["ERR_REDIS_INVALID_ARRAY", Error, "RedisError"], - ["ERR_REDIS_INVALID_INTEGER", Error, "RedisError"], - ["ERR_REDIS_INVALID_SIMPLE_STRING", Error, "RedisError"], - ["ERR_REDIS_INVALID_ERROR_STRING", Error, "RedisError"], - ["ERR_REDIS_TLS_NOT_AVAILABLE", Error, "RedisError"], - ["ERR_REDIS_TLS_UPGRADE_FAILED", Error, "RedisError"], ["ERR_REDIS_AUTHENTICATION_FAILED", Error, "RedisError"], - ["ERR_REDIS_INVALID_PASSWORD", Error, "RedisError"], - ["ERR_REDIS_INVALID_USERNAME", Error, "RedisError"], - ["ERR_REDIS_INVALID_DATABASE", Error, "RedisError"], - ["ERR_REDIS_INVALID_COMMAND", Error, "RedisError"], - ["ERR_REDIS_INVALID_ARGUMENT", Error, "RedisError"], - ["ERR_REDIS_INVALID_RESPONSE_TYPE", Error, "RedisError"], + ["ERR_REDIS_CONNECTION_CLOSED", Error, "RedisError"], ["ERR_REDIS_CONNECTION_TIMEOUT", Error, "RedisError"], ["ERR_REDIS_IDLE_TIMEOUT", Error, "RedisError"], + ["ERR_REDIS_INVALID_ARGUMENT", Error, "RedisError"], + ["ERR_REDIS_INVALID_ARRAY", Error, "RedisError"], + ["ERR_REDIS_INVALID_BULK_STRING", Error, "RedisError"], + ["ERR_REDIS_INVALID_COMMAND", Error, "RedisError"], + ["ERR_REDIS_INVALID_DATABASE", Error, "RedisError"], + ["ERR_REDIS_INVALID_ERROR_STRING", Error, "RedisError"], + ["ERR_REDIS_INVALID_INTEGER", Error, "RedisError"], + ["ERR_REDIS_INVALID_PASSWORD", Error, "RedisError"], + ["ERR_REDIS_INVALID_RESPONSE", Error, "RedisError"], + ["ERR_REDIS_INVALID_RESPONSE_TYPE", Error, "RedisError"], + ["ERR_REDIS_INVALID_SIMPLE_STRING", Error, "RedisError"], + ["ERR_REDIS_INVALID_STATE", Error, "RedisError"], + ["ERR_REDIS_INVALID_USERNAME", Error, "RedisError"], + ["ERR_REDIS_TLS_NOT_AVAILABLE", Error, "RedisError"], + ["ERR_REDIS_TLS_UPGRADE_FAILED", Error, "RedisError"], ["HPE_UNEXPECTED_CONTENT_LENGTH", Error], ["HPE_INVALID_TRANSFER_ENCODING", Error], ["HPE_INVALID_EOF_STATE", Error], diff --git a/src/bun.js/bindings/JSGlobalObject.zig b/src/bun.js/bindings/JSGlobalObject.zig index 5642205a9d..101ccc8781 100644 --- a/src/bun.js/bindings/JSGlobalObject.zig +++ b/src/bun.js/bindings/JSGlobalObject.zig @@ -367,6 +367,10 @@ pub const JSGlobalObject = opaque { return this.throwValue(err); } + /// Throw an Error from a formatted string. + /// + /// Note: If you are throwing an error within somewhere in the Bun API, + /// chances are you should be using `.ERR(...).throw()` instead. pub fn throw(this: *JSGlobalObject, comptime fmt: [:0]const u8, args: anytype) JSError { const instance = this.createErrorInstance(fmt, args); bun.assert(instance != .zero); @@ -789,6 +793,9 @@ pub const JSGlobalObject = opaque { return .{ .globalObject = this }; } + /// Throw an error from within the Bun runtime. + /// + /// The set of errors accepted by `ERR()` is defined in `ErrorCode.ts`. pub fn ERR(global: *JSGlobalObject, comptime code: jsc.Error, comptime fmt: [:0]const u8, args: anytype) @import("ErrorCode").ErrorBuilder(code, fmt, @TypeOf(args)) { return .{ .global = global, .args = args }; } diff --git a/src/bun.js/bindings/JSMap.zig b/src/bun.js/bindings/JSMap.zig index 5c4ce35be9..0abbd06173 100644 --- a/src/bun.js/bindings/JSMap.zig +++ b/src/bun.js/bindings/JSMap.zig @@ -1,34 +1,26 @@ +/// Opaque type for working with JavaScript `Map` objects. pub const JSMap = opaque { - extern fn JSC__JSMap__create(*JSGlobalObject) JSValue; + pub const create = bun.cpp.JSC__JSMap__create; + pub const set = bun.cpp.JSC__JSMap__set; - pub fn create(globalObject: *JSGlobalObject) JSValue { - return JSC__JSMap__create(globalObject); - } + /// Retrieve a value from this JS Map object. + /// + /// Note this shares semantics with the JS `Map.prototype.get` method, and + /// will return .js_undefined if a value is not found. + pub const get = bun.cpp.JSC__JSMap__get; - pub fn set(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue, value: JSValue) void { - return bun.cpp.JSC__JSMap__set(this, globalObject, key, value); - } + /// Test whether this JS Map object has a given key. + pub const has = bun.cpp.JSC__JSMap__has; - pub fn get_(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) JSValue { - return bun.cpp.JSC__JSMap__get_(this, globalObject, key); - } + /// Attempt to remove a key from this JS Map object. + pub const remove = bun.cpp.JSC__JSMap__remove; - pub fn get(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) ?JSValue { - const value = get_(this, globalObject, key); - if (value.isEmpty()) { - return null; - } - return value; - } - - pub fn has(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) bool { - return bun.cpp.JSC__JSMap__has(this, globalObject, key); - } - - pub fn remove(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) bool { - return bun.cpp.JSC__JSMap__remove(this, globalObject, key); - } + /// Retrieve the number of entries in this JS Map object. + pub const size = bun.cpp.JSC__JSMap__size; + /// Attempt to convert a `JSValue` to a `*JSMap`. + /// + /// Returns `null` if the value is not a Map. pub fn fromJS(value: JSValue) ?*JSMap { if (value.jsTypeLoose() == .Map) { return bun.cast(*JSMap, value.asEncoded().asPtr.?); @@ -41,5 +33,4 @@ pub const JSMap = opaque { const bun = @import("bun"); const jsc = bun.jsc; -const JSGlobalObject = jsc.JSGlobalObject; const JSValue = jsc.JSValue; diff --git a/src/bun.js/bindings/JSPromise.zig b/src/bun.js/bindings/JSPromise.zig index 35a581b106..3a0313241b 100644 --- a/src/bun.js/bindings/JSPromise.zig +++ b/src/bun.js/bindings/JSPromise.zig @@ -220,6 +220,9 @@ pub const JSPromise = opaque { bun.cpp.JSC__JSPromise__setHandled(this, vm); } + /// Create a new resolved promise resolving to a given value. + /// + /// Note: If you want the result as a JSValue, use `JSPromise.resolvedPromiseValue` instead. pub fn resolvedPromise(globalThis: *JSGlobalObject, value: JSValue) *JSPromise { return JSC__JSPromise__resolvedPromise(globalThis, value); } @@ -230,6 +233,9 @@ pub const JSPromise = opaque { return JSC__JSPromise__resolvedPromiseValue(globalThis, value); } + /// Create a new rejected promise rejecting to a given value. + /// + /// Note: If you want the result as a JSValue, use `JSPromise.rejectedPromiseValue` instead. pub fn rejectedPromise(globalThis: *JSGlobalObject, value: JSValue) *JSPromise { return JSC__JSPromise__rejectedPromise(globalThis, value); } @@ -275,6 +281,11 @@ pub const JSPromise = opaque { bun.cpp.JSC__JSPromise__rejectAsHandled(this, globalThis, value) catch return bun.debugAssert(false); // TODO: properly propagate exception upwards } + /// Create a new pending promise. + /// + /// Note: You should use `JSPromise.resolvedPromise` or + /// `JSPromise.rejectedPromise` if you want to create a promise that + /// is already resolved or rejected. pub fn create(globalThis: *JSGlobalObject) *JSPromise { return JSC__JSPromise__create(globalThis); } diff --git a/src/bun.js/bindings/JSRef.zig b/src/bun.js/bindings/JSRef.zig index c76a4056eb..08928aa73e 100644 --- a/src/bun.js/bindings/JSRef.zig +++ b/src/bun.js/bindings/JSRef.zig @@ -1,3 +1,7 @@ +/// Holds a reference to a JSValue. +/// +/// This reference can be either weak (a JSValue) or may be strong, in which +/// case it prevents the garbage collector from collecting the value. pub const JSRef = union(enum) { weak: jsc.JSValue, strong: jsc.Strong.Optional, @@ -91,6 +95,11 @@ pub const JSRef = union(enum) { }; } + /// Test whether this reference is a strong reference. + pub fn isStrong(this: *const @This()) bool { + return this.* == .strong; + } + pub fn deinit(this: *@This()) void { switch (this.*) { .weak => { diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp index 496c5e14f8..6111976187 100644 --- a/src/bun.js/bindings/bindings.cpp +++ b/src/bun.js/bindings/bindings.cpp @@ -1,3 +1,12 @@ +/** + * Source code for JavaScriptCore bindings used by bind. + * + * This file is processed by cppbind.ts. + * + * @see cppbind.ts holds helpful tips on how to add and implement new bindings. + * Note that cppbind.ts also automatically runs some error-checking which + * can be disabled if necessary. Consult cppbind.ts for details. + */ #include "root.h" #include "JavaScriptCore/ErrorType.h" @@ -37,6 +46,7 @@ #include "JavaScriptCore/JSArrayInlines.h" #include "JavaScriptCore/ErrorInstanceInlines.h" #include "JavaScriptCore/BigIntObject.h" +#include "JavaScriptCore/OrderedHashTableHelper.h" #include "JavaScriptCore/JSCallbackObject.h" #include "JavaScriptCore/JSClassRef.h" @@ -6463,32 +6473,49 @@ CPP_DECL WebCore::DOMFormData* WebCore__DOMFormData__fromJS(JSC::EncodedJSValue #pragma mark - JSC::JSMap -CPP_DECL JSC::EncodedJSValue JSC__JSMap__create(JSC::JSGlobalObject* arg0) +CPP_DECL [[ZIG_EXPORT(nothrow)]] JSC::EncodedJSValue JSC__JSMap__create(JSC::JSGlobalObject* arg0) { - JSC::JSMap* map = JSC::JSMap::create(arg0->vm(), arg0->mapStructure()); - return JSC::JSValue::encode(map); + return JSC::JSValue::encode(JSC::JSMap::create(arg0->vm(), arg0->mapStructure())); } -CPP_DECL [[ZIG_EXPORT(nothrow)]] JSC::EncodedJSValue JSC__JSMap__get_(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) -{ - JSC::JSValue value = JSC::JSValue::decode(JSValue2); - return JSC::JSValue::encode(map->get(arg1, value)); -} -CPP_DECL [[ZIG_EXPORT(nothrow)]] bool JSC__JSMap__has(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) +// JSMap::get never returns JSValue::zero, even in the case of an exception. The +// best we can, therefore, do is manually test for exceptions. +// NOLINTNEXTLINE(bun-bindgen-force-zero_is_throw-for-jsvalue) +CPP_DECL [[ZIG_EXPORT(zero_is_throw)]] JSC::EncodedJSValue JSC__JSMap__get(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) { - JSC::JSValue value = JSC::JSValue::decode(JSValue2); + auto& vm = JSC::getVM(arg1); + const JSC::JSValue key = JSC::JSValue::decode(JSValue2); + + // JSMap::get never returns JSValue::zero, even in the case of an exception. + // It will return JSValue::undefined and set an exception on the VM. + auto scope = DECLARE_THROW_SCOPE(vm); + const JSValue value = map->get(arg1, key); + RETURN_IF_EXCEPTION(scope, {}); + return JSC::JSValue::encode(value); +} + +CPP_DECL [[ZIG_EXPORT(check_slow)]] bool JSC__JSMap__has(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) +{ + const JSC::JSValue value = JSC::JSValue::decode(JSValue2); return map->has(arg1, value); } -CPP_DECL [[ZIG_EXPORT(nothrow)]] bool JSC__JSMap__remove(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) + +CPP_DECL [[ZIG_EXPORT(check_slow)]] bool JSC__JSMap__remove(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) { - JSC::JSValue value = JSC::JSValue::decode(JSValue2); + const JSC::JSValue value = JSC::JSValue::decode(JSValue2); return map->remove(arg1, value); } -CPP_DECL [[ZIG_EXPORT(nothrow)]] void JSC__JSMap__set(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3) + +CPP_DECL [[ZIG_EXPORT(check_slow)]] void JSC__JSMap__set(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3) { map->set(arg1, JSC::JSValue::decode(JSValue2), JSC::JSValue::decode(JSValue3)); } +CPP_DECL [[ZIG_EXPORT(check_slow)]] uint32_t JSC__JSMap__size(JSC::JSMap* map, JSC::JSGlobalObject* arg1) +{ + return map->size(); +} + CPP_DECL void JSC__VM__setControlFlowProfiler(JSC::VM* vm, bool isEnabled) { if (isEnabled) { diff --git a/src/bun.js/bindings/headers.h b/src/bun.js/bindings/headers.h index 4f3d3454b4..f02d203054 100644 --- a/src/bun.js/bindings/headers.h +++ b/src/bun.js/bindings/headers.h @@ -8,7 +8,7 @@ #define AUTO_EXTERN_C extern "C" #ifdef WIN32 - #define AUTO_EXTERN_C_ZIG extern "C" + #define AUTO_EXTERN_C_ZIG extern "C" #else #define AUTO_EXTERN_C_ZIG extern "C" __attribute__((weak)) #endif @@ -129,7 +129,7 @@ CPP_DECL void WebCore__AbortSignal__cleanNativeBindings(WebCore::AbortSignal* ar CPP_DECL JSC::EncodedJSValue WebCore__AbortSignal__create(JSC::JSGlobalObject* arg0); CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__fromJS(JSC::EncodedJSValue JSValue0); CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__ref(WebCore::AbortSignal* arg0); -CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__signal(WebCore::AbortSignal* arg0, JSC::JSGlobalObject*, uint8_t abortReason); +CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__signal(WebCore::AbortSignal* arg0, JSC::JSGlobalObject*, uint8_t abortReason); CPP_DECL JSC::EncodedJSValue WebCore__AbortSignal__toJS(WebCore::AbortSignal* arg0, JSC::JSGlobalObject* arg1); CPP_DECL void WebCore__AbortSignal__unref(WebCore::AbortSignal* arg0); @@ -186,10 +186,11 @@ CPP_DECL JSC::VM* JSC__JSGlobalObject__vm(JSC::JSGlobalObject* arg0); #pragma mark - JSC::JSMap CPP_DECL JSC::EncodedJSValue JSC__JSMap__create(JSC::JSGlobalObject* arg0); -CPP_DECL JSC::EncodedJSValue JSC__JSMap__get_(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); +CPP_DECL JSC::EncodedJSValue JSC__JSMap__get(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); CPP_DECL bool JSC__JSMap__has(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); CPP_DECL bool JSC__JSMap__remove(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); CPP_DECL void JSC__JSMap__set(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3); +CPP_DECL uint32_t JSC__JSMap__size(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1); #pragma mark - JSC::JSValue diff --git a/src/bun.js/event_loop.zig b/src/bun.js/event_loop.zig index dd4451e81f..b84f9c7d41 100644 --- a/src/bun.js/event_loop.zig +++ b/src/bun.js/event_loop.zig @@ -57,12 +57,25 @@ pub const Debug = if (Environment.isDebug) struct { pub inline fn exit(_: Debug) void {} }; +/// Before your code enters JavaScript at the top of the event loop, call +/// `loop.enter()`. If running a single callback, prefer `runCallback` instead. +/// +/// When we call into JavaScript, we must drain process.nextTick & microtasks +/// afterwards (so that promises run). We must only do that once per task in the +/// event loop. To make that work, we count enter/exit calls and once that +/// counter reaches 0, we drain the microtasks. +/// +/// This function increments the counter for the number of times we've entered +/// the event loop. pub fn enter(this: *EventLoop) void { log("enter() = {d}", .{this.entered_event_loop_count}); this.entered_event_loop_count += 1; this.debug.enter(); } +/// "exit" a microtask context in the event loop. +/// +/// See the documentation for `enter` for more information. pub fn exit(this: *EventLoop) void { const count = this.entered_event_loop_count; log("exit() = {d}", .{count - 1}); diff --git a/src/deps/uws/us_socket_t.zig b/src/deps/uws/us_socket_t.zig index 0af82ca70e..a67efb9286 100644 --- a/src/deps/uws/us_socket_t.zig +++ b/src/deps/uws/us_socket_t.zig @@ -13,7 +13,7 @@ pub const us_socket_t = opaque { }; pub fn open(this: *us_socket_t, comptime is_ssl: bool, is_client: bool, ip_addr: ?[]const u8) void { - debug("us_socket_open({d}, is_client: {})", .{ @intFromPtr(this), is_client }); + debug("us_socket_open({p}, is_client: {})", .{ this, is_client }); const ssl = @intFromBool(is_ssl); if (ip_addr) |ip| { @@ -25,22 +25,22 @@ pub const us_socket_t = opaque { } pub fn pause(this: *us_socket_t, ssl: bool) void { - debug("us_socket_pause({d})", .{@intFromPtr(this)}); + debug("us_socket_pause({p})", .{this}); c.us_socket_pause(@intFromBool(ssl), this); } pub fn @"resume"(this: *us_socket_t, ssl: bool) void { - debug("us_socket_resume({d})", .{@intFromPtr(this)}); + debug("us_socket_resume({p})", .{this}); c.us_socket_resume(@intFromBool(ssl), this); } pub fn close(this: *us_socket_t, ssl: bool, code: CloseCode) void { - debug("us_socket_close({d}, {s})", .{ @intFromPtr(this), @tagName(code) }); + debug("us_socket_close({p}, {s})", .{ this, @tagName(code) }); _ = c.us_socket_close(@intFromBool(ssl), this, code, null); } pub fn shutdown(this: *us_socket_t, ssl: bool) void { - debug("us_socket_shutdown({d})", .{@intFromPtr(this)}); + debug("us_socket_shutdown({p})", .{this}); c.us_socket_shutdown(@intFromBool(ssl), this); } @@ -128,25 +128,25 @@ pub const us_socket_t = opaque { pub fn write(this: *us_socket_t, ssl: bool, data: []const u8) i32 { const rc = c.us_socket_write(@intFromBool(ssl), this, data.ptr, @intCast(data.len)); - debug("us_socket_write({d}, {d}) = {d}", .{ @intFromPtr(this), data.len, rc }); + debug("us_socket_write({p}, {d}) = {d}", .{ this, data.len, rc }); return rc; } pub fn writeFd(this: *us_socket_t, data: []const u8, file_descriptor: bun.FD) i32 { if (bun.Environment.isWindows) @compileError("TODO: implement writeFd on Windows"); const rc = c.us_socket_ipc_write_fd(this, data.ptr, @intCast(data.len), file_descriptor.native()); - debug("us_socket_ipc_write_fd({d}, {d}, {d}) = {d}", .{ @intFromPtr(this), data.len, file_descriptor.native(), rc }); + debug("us_socket_ipc_write_fd({p}, {d}, {d}) = {d}", .{ this, data.len, file_descriptor.native(), rc }); return rc; } pub fn write2(this: *us_socket_t, ssl: bool, first: []const u8, second: []const u8) i32 { const rc = c.us_socket_write2(@intFromBool(ssl), this, first.ptr, first.len, second.ptr, second.len); - debug("us_socket_write2({d}, {d}, {d}) = {d}", .{ @intFromPtr(this), first.len, second.len, rc }); + debug("us_socket_write2({p}, {d}, {d}) = {d}", .{ this, first.len, second.len, rc }); return rc; } pub fn rawWrite(this: *us_socket_t, ssl: bool, data: []const u8) i32 { - debug("us_socket_raw_write({d}, {d})", .{ @intFromPtr(this), data.len }); + debug("us_socket_raw_write({p}, {d})", .{ this, data.len }); return c.us_socket_raw_write(@intFromBool(ssl), this, data.ptr, @intCast(data.len)); } diff --git a/src/memory.zig b/src/memory.zig index 47e54a7a65..5ed890ee3e 100644 --- a/src/memory.zig +++ b/src/memory.zig @@ -75,6 +75,36 @@ pub fn deinit(ptr_or_slice: anytype) void { } } +/// Rebase a slice from one memory buffer to another buffer. +/// +/// Given a slice which points into a memory buffer with base `old_base`, return +/// a slice which points to the same offset in a new memory buffer with base +/// `new_base`, preserving the length of the slice. +/// +/// +/// ``` +/// const old_base = [6]u8{}; +/// assert(@ptrToInt(&old_base) == 0x32); +/// +/// 0x32 0x33 0x34 0x35 0x36 0x37 +/// old_base |????|????|????|????|????|????| +/// ^ +/// |<-- slice --->| +/// +/// const new_base = [6]u8{}; +/// assert(@ptrToInt(&new_base) == 0x74); +/// const output = rebaseSlice(slice, old_base, new_base) +/// +/// 0x74 0x75 0x76 0x77 0x78 0x79 +/// new_base |????|????|????|????|????|????| +/// ^ +/// |<-- output -->| +/// ``` +pub fn rebaseSlice(slice: []const u8, old_base: [*]const u8, new_base: [*]const u8) []const u8 { + const offset = @intFromPtr(slice.ptr) - @intFromPtr(old_base); + return new_base[offset..][0..slice.len]; +} + const std = @import("std"); const Allocator = std.mem.Allocator; diff --git a/src/string/StringBuilder.zig b/src/string/StringBuilder.zig index 91e8b1f250..ae89ee4f74 100644 --- a/src/string/StringBuilder.zig +++ b/src/string/StringBuilder.zig @@ -236,6 +236,15 @@ pub fn writable(this: *StringBuilder) []u8 { return ptr[this.len..this.cap]; } +/// Transfer ownership of the underlying memory to a slice. +/// +/// After calling this, you are responsible for freeing the underlying memory. +/// This StringBuilder should not be used after calling this function. +pub fn moveToSlice(this: *StringBuilder, into_slice: *[]u8) void { + into_slice.* = this.allocatedSlice(); + this.* = .{}; +} + const std = @import("std"); const Allocator = std.mem.Allocator; diff --git a/src/valkey/ValkeyCommand.zig b/src/valkey/ValkeyCommand.zig index dfc9c448e6..2349ecad4a 100644 --- a/src/valkey/ValkeyCommand.zig +++ b/src/valkey/ValkeyCommand.zig @@ -137,7 +137,7 @@ pub const Promise = struct { self.promise.resolve(globalObject, js_value); } - pub fn reject(self: *Promise, globalObject: *jsc.JSGlobalObject, jsvalue: jsc.JSValue) void { + pub fn reject(self: *Promise, globalObject: *jsc.JSGlobalObject, jsvalue: JSError!jsc.JSValue) void { self.promise.reject(globalObject, jsvalue); } @@ -162,6 +162,7 @@ const protocol = @import("./valkey_protocol.zig"); const std = @import("std"); const bun = @import("bun"); +const JSError = bun.JSError; const jsc = bun.jsc; const node = bun.api.node; const Slice = jsc.ZigString.Slice; diff --git a/src/valkey/js_valkey.zig b/src/valkey/js_valkey.zig index edde503c59..65406c5b48 100644 --- a/src/valkey/js_valkey.zig +++ b/src/valkey/js_valkey.zig @@ -1,9 +1,233 @@ +pub const SubscriptionCtx = struct { + const Self = @This(); + + // TODO(markovejnovic): Consider using refactoring this to use + // @fieldParentPtr. The reason this was not implemented is because there is + // no support for optional fields yet. + // + // See: https://github.com/ziglang/zig/issues/25241 + // + // An alternative is to hold a flag within the context itself, indicating + // whether it is active or not, but that feels less clean. + _parent: *JSValkeyClient, + + original_enable_offline_queue: bool, + original_enable_auto_pipelining: bool, + + const ParentJS = JSValkeyClient.js; + + pub fn init(parent: *JSValkeyClient, enable_offline_queue: bool, enable_auto_pipelining: bool) bun.JSError!Self { + const callback_map = jsc.JSMap.create(parent.globalObject); + const parent_this = parent.this_value.tryGet() orelse unreachable; + + ParentJS.gc.set(.subscriptionCallbackMap, parent_this, parent.globalObject, callback_map); + + const self = Self{ + ._parent = parent, + .original_enable_offline_queue = enable_offline_queue, + .original_enable_auto_pipelining = enable_auto_pipelining, + }; + return self; + } + + fn subscriptionCallbackMap(this: *Self) *jsc.JSMap { + const parent_this = this._parent.this_value.tryGet() orelse unreachable; + + const value_js = ParentJS.gc.get(.subscriptionCallbackMap, parent_this).?; + return jsc.JSMap.fromJS(value_js).?; + } + + /// Get the total number of channels that this subscription context is subscribed to. + pub fn channelsSubscribedToCount(this: *Self, globalObject: *jsc.JSGlobalObject) bun.JSError!u32 { + return this.subscriptionCallbackMap().size(globalObject); + } + + /// Test whether this context has any subscriptions. It is mandatory to + /// guard deinit with this function. + pub fn hasSubscriptions(this: *Self, globalObject: *jsc.JSGlobalObject) bun.JSError!bool { + return (try this.channelsSubscribedToCount(globalObject)) > 0; + } + + pub fn clearReceiveHandlers( + this: *Self, + globalObject: *jsc.JSGlobalObject, + channelName: JSValue, + ) bun.JSError!void { + const map = this.subscriptionCallbackMap(); + _ = try map.remove(globalObject, channelName); + } + + /// Remove a specific receive handler. + /// + /// Returns: The total number of remaining handlers for this channel, or null if here were no listeners originally + /// registered. + /// + /// Note: This function will empty out the map entry if there are no more handlers registered. + pub fn removeReceiveHandler( + this: *Self, + globalObject: *jsc.JSGlobalObject, + channelName: JSValue, + callback: JSValue, + ) !?usize { + const map = this.subscriptionCallbackMap(); + + const existing = try map.get(globalObject, channelName); + if (existing.isUndefinedOrNull()) { + // Nothing to remove. + return null; + } + + // Existing is guaranteed to be an array of callbacks. + // This check is necessary because crossing between Zig and C++ is necessary because Zig doesn't know that C++ + // is side-effect-free. + if (comptime bun.Environment.isDebug) { + bun.assert(existing.isArray()); + } + + // TODO(markovejnovic): I can't find a better way to do this... I generate a new array, + // filtering out the callback we want to remove. This is woefully inefficient for large + // sets (and surprisingly fast for small sets of callbacks). + // + // Perhaps there is an avenue to build a generic iterator pattern? @taylor.fish and I have + // briefly expressed a desire for this, and I promised her I would look into it, but at + // this moment have no proposal. + var array_it = try existing.arrayIterator(globalObject); + const updated_array = try jsc.JSArray.createEmpty(globalObject, 0); + while (try array_it.next()) |iter| { + if (iter == callback) + continue; + + try updated_array.push(globalObject, iter); + } + + // Otherwise, we have ourselves an array of callbacks. We need to remove the element in the + // array that matches the callback. + _ = try map.remove(globalObject, channelName); + + // Only populate the map if we have remaining callbacks for this channel. + const new_length = try updated_array.getLength(globalObject); + + if (new_length != 0) { + try map.set(globalObject, channelName, updated_array); + } + + return new_length; + } + + /// Add a handler for receiving messages on a specific channel + pub fn upsertReceiveHandler( + this: *Self, + globalObject: *jsc.JSGlobalObject, + channelName: JSValue, + callback: JSValue, + ) bun.JSError!void { + defer this._parent.onNewSubscriptionCallbackInsert(); + const map = this.subscriptionCallbackMap(); + + var handlers_array: JSValue = undefined; + var is_new_channel = false; + const existing_handler_arr = try map.get(globalObject, channelName); + if (existing_handler_arr != .js_undefined) { + debug("Adding a new receive handler.", .{}); + // Note that we need to cover this case because maps in JSC can return undefined when the key has never been + // set. + if (existing_handler_arr.isUndefined()) { + // Create a new array if the existing_handler_arr is undefined/null + handlers_array = try jsc.JSArray.createEmpty(globalObject, 0); + is_new_channel = true; + } else if (existing_handler_arr.isArray()) { + // Use the existing array + handlers_array = existing_handler_arr; + } else unreachable; + } else { + // No existing_handler_arr exists, create a new array + handlers_array = try jsc.JSArray.createEmpty(globalObject, 0); + is_new_channel = true; + } + + // Append the new callback to the array + try handlers_array.push(globalObject, callback); + + // Set the updated array back in the map + try map.set(globalObject, channelName, handlers_array); + } + + pub fn getCallbacks(this: *Self, globalObject: *jsc.JSGlobalObject, channelName: JSValue) bun.JSError!?JSValue { + const result = try this.subscriptionCallbackMap().get(globalObject, channelName); + if (result == .js_undefined) { + return null; + } + + return result; + } + + /// Invoke callbacks for a channel with the given arguments + /// Handles both single callbacks and arrays of callbacks + pub fn invokeCallbacks( + this: *Self, + globalObject: *jsc.JSGlobalObject, + channelName: JSValue, + args: []const JSValue, + ) bun.JSError!void { + const callbacks = try this.getCallbacks(globalObject, channelName) orelse { + debug("No callbacks found for channel {s}", .{channelName.asString().getZigString(globalObject)}); + return; + }; + + if (comptime bun.Environment.isDebug) { + bun.assert(callbacks.isArray()); + } + + const vm = jsc.VirtualMachine.get(); + const event_loop = vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + + // After we go through every single callback, we will have to update the poll ref. + // The user may, for example, unsubscribe in the callbacks, or even stop the client. + defer this._parent.updatePollRef(); + + // If callbacks is an array, iterate and call each one + var iter = try callbacks.arrayIterator(globalObject); + while (try iter.next()) |callback| { + if (comptime bun.Environment.isDebug) { + bun.assert(callback.isCallable()); + } + + event_loop.runCallback(callback, globalObject, .js_undefined, args); + } + } + + /// Return whether the subscription context is ready to be deleted by the JS garbage collector. + pub fn isDeletable(this: *Self, global_object: *jsc.JSGlobalObject) bun.JSError!bool { + // The user may request .close(), in which case we can dispose of the subscription object. If that is the case, + // finalized will be true. Otherwise, we should treat the object as disposable if there are no active + // subscriptions. + return this._parent.client.flags.finalized or !(try this.hasSubscriptions(global_object)); + } + + pub fn deinit(this: *Self, global_object: *jsc.JSGlobalObject) void { + // This check is necessary because crossing between Zig and C++ is necessary because Zig doesn't know that C++ + // is side-effect-free. + if (comptime bun.Environment.isDebug) { + bun.debugAssert(this.isDeletable(this._parent.globalObject) catch unreachable); + } + + if (this._parent.this_value.tryGet()) |parent_this| { + ParentJS.gc.set(.subscriptionCallbackMap, parent_this, global_object, .js_undefined); + } + } +}; + /// Valkey client wrapper for JavaScript pub const JSValkeyClient = struct { client: valkey.ValkeyClient, globalObject: *jsc.JSGlobalObject, this_value: jsc.JSRef = jsc.JSRef.empty(), poll_ref: bun.Async.KeepAlive = .{}, + + _subscription_ctx: ?SubscriptionCtx, + timer: Timer.EventLoopTimer = .{ .tag = .ValkeyConnectionTimeout, .next = .{ @@ -31,11 +255,13 @@ pub const JSValkeyClient = struct { pub const new = bun.TrivialNew(@This()); // Factory function to create a new Valkey client from JS - pub fn constructor(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!*JSValkeyClient { - return try create(globalObject, callframe.arguments()); + pub fn constructor(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame, js_this: JSValue) bun.JSError!*JSValkeyClient { + return try create(globalObject, callframe.arguments(), js_this); } - pub fn create(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue) bun.JSError!*JSValkeyClient { + pub fn createNoJs(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue) bun.JSError!*JSValkeyClient { + const this_allocator = bun.default_allocator; + const vm = globalObject.bunVM(); const url_str = if (arguments.len < 1 or arguments[0].isUndefined()) if (vm.transpiler.env.get("REDIS_URL") orelse vm.transpiler.env.get("VALKEY_URL")) |url| @@ -46,7 +272,7 @@ pub const JSValkeyClient = struct { try arguments[0].toBunString(globalObject); defer url_str.deref(); - const url_utf8 = url_str.toUTF8WithoutRef(bun.default_allocator); + const url_utf8 = url_str.toUTF8WithoutRef(this_allocator); defer url_utf8.deinit(); const url = bun.URL.parse(url_utf8.slice()); @@ -88,7 +314,7 @@ pub const JSValkeyClient = struct { var connection_strings: []u8 = &.{}; errdefer { - bun.default_allocator.free(connection_strings); + this_allocator.free(connection_strings); } if (url.username.len > 0 or url.password.len > 0 or hostname.len > 0) { @@ -96,19 +322,21 @@ pub const JSValkeyClient = struct { b.count(url.username); b.count(url.password); b.count(hostname); - try b.allocate(bun.default_allocator); + try b.allocate(this_allocator); + defer b.deinit(this_allocator); username = b.append(url.username); password = b.append(url.password); hostname = b.append(hostname); - connection_strings = b.allocatedSlice(); + b.moveToSlice(&connection_strings); } const database = if (url.pathname.len > 0) std.fmt.parseInt(u32, url.pathname[1..], 10) catch 0 else 0; bun.analytics.Features.valkey += 1; - return JSValkeyClient.new(.{ + const client = JSValkeyClient.new(.{ .ref_count = .init(), + ._subscription_ctx = null, .client = .{ .vm = vm, .address = switch (uri) { @@ -120,10 +348,11 @@ pub const JSValkeyClient = struct { }, }, }, + .protocol = uri, .username = username, .password = password, - .in_flight = .init(bun.default_allocator), - .queue = .init(bun.default_allocator), + .in_flight = .init(this_allocator), + .queue = .init(this_allocator), .status = .disconnected, .connection_strings = connection_strings, .socket = .{ @@ -134,7 +363,7 @@ pub const JSValkeyClient = struct { }, }, .database = database, - .allocator = bun.default_allocator, + .allocator = this_allocator, .flags = .{ .enable_auto_reconnect = options.enable_auto_reconnect, .enable_offline_queue = options.enable_offline_queue, @@ -146,6 +375,130 @@ pub const JSValkeyClient = struct { }, .globalObject = globalObject, }); + + return client; + } + + pub fn create(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue, js_this: JSValue) bun.JSError!*JSValkeyClient { + var new_client = try JSValkeyClient.createNoJs(globalObject, arguments); + + // Initially, we only need to hold a weak reference to the JS object. + new_client.this_value = jsc.JSRef.initWeak(js_this); + return new_client; + } + + /// Clone this client while remaining in the initial disconnected state. + /// + /// Note that this does not create an object with an associated this_value. + /// You may need to populate it yourself. + pub fn cloneWithoutConnecting( + this: *const JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + ) bun.OOM!*JSValkeyClient { + const vm = globalObject.bunVM(); + + // Make a copy of connection_strings to avoid double-free + const connection_strings_copy = try this.client.allocator.dupe(u8, this.client.connection_strings); + + // Note that there is no need to copy username, password and address since the copies live + // within the connection_strings buffer. + const base_ptr = this.client.connection_strings.ptr; + const new_base = connection_strings_copy.ptr; + const username = bun.memory.rebaseSlice(this.client.username, base_ptr, new_base); + const password = bun.memory.rebaseSlice(this.client.password, base_ptr, new_base); + const orig_hostname = this.client.address.hostname(); + const hostname = bun.memory.rebaseSlice(orig_hostname, base_ptr, new_base); + const new_alloc = this.client.allocator; + + return JSValkeyClient.new(.{ + .ref_count = .init(), + ._subscription_ctx = null, + .client = .{ + .vm = vm, + .address = switch (this.client.protocol) { + .standalone_unix, .standalone_tls_unix => .{ .unix = hostname }, + else => .{ + .host = .{ + .host = hostname, + .port = this.client.address.host.port, + }, + }, + }, + .protocol = this.client.protocol, + .username = username, + .password = password, + .in_flight = .init(new_alloc), + .queue = .init(new_alloc), + .status = .disconnected, + .connection_strings = connection_strings_copy, + .socket = .{ + .SocketTCP = .{ + .socket = .{ + .detached = {}, + }, + }, + }, + .database = this.client.database, + .allocator = new_alloc, + .flags = .{ + // Because this starts in the disconnected state, we need to reset some flags. + .is_authenticated = false, + // If the user manually closed the connection, then duplicating a closed client + // means the new client remains finalized. + .is_manually_closed = this.client.flags.is_manually_closed, + .enable_offline_queue = if (this._subscription_ctx) |*ctx| ctx.original_enable_offline_queue else this.client.flags.enable_offline_queue, + .needs_to_open_socket = true, + .enable_auto_reconnect = this.client.flags.enable_auto_reconnect, + .is_reconnecting = false, + .auto_pipelining = if (this._subscription_ctx) |*ctx| ctx.original_enable_auto_pipelining else this.client.flags.auto_pipelining, + // Duplicating a finalized client means it stays finalized. + .finalized = this.client.flags.finalized, + }, + .max_retries = this.client.max_retries, + .connection_timeout_ms = this.client.connection_timeout_ms, + .idle_timeout_interval_ms = this.client.idle_timeout_interval_ms, + }, + .globalObject = globalObject, + }); + } + + pub fn getOrCreateSubscriptionCtxEnteringSubscriptionMode( + this: *JSValkeyClient, + ) bun.JSError!*SubscriptionCtx { + if (this._subscription_ctx) |*ctx| { + // If we already have a subscription context, return it + return ctx; + } + + // Save the original flag values and create a new subscription context + this._subscription_ctx = try SubscriptionCtx.init( + this, + this.client.flags.enable_offline_queue, + this.client.flags.auto_pipelining, + ); + + // We need to make sure we disable the offline queue. + this.client.flags.enable_offline_queue = false; + this.client.flags.auto_pipelining = false; + + return &(this._subscription_ctx.?); + } + + pub fn deleteSubscriptionCtx(this: *JSValkeyClient) void { + if (this._subscription_ctx) |*ctx| { + // Restore the original flag values when leaving subscription mode + this.client.flags.enable_offline_queue = ctx.original_enable_offline_queue; + this.client.flags.auto_pipelining = ctx.original_enable_auto_pipelining; + + ctx.deinit(this.globalObject); + this._subscription_ctx = null; + } + + this._subscription_ctx = null; + } + + pub fn isSubscriber(this: *const JSValkeyClient) bool { + return this._subscription_ctx != null; } pub fn getConnected(this: *JSValkeyClient, _: *jsc.JSGlobalObject) JSValue { @@ -159,16 +512,22 @@ pub const JSValkeyClient = struct { return JSValue.jsNumber(len); } - pub fn doConnect(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, this_value: JSValue) bun.JSError!JSValue { + pub fn doConnect( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + this_value: JSValue, + ) bun.JSError!JSValue { this.ref(); defer this.deref(); // If already connected, resolve immediately if (this.client.status == .connected) { + debug("Connecting client is already connected.", .{}); return jsc.JSPromise.resolvedPromiseValue(globalObject, js.helloGetCached(this_value) orelse .js_undefined); } if (js.connectionPromiseGetCached(this_value)) |promise| { + debug("Connecting client is already connected.", .{}); return promise; } @@ -181,12 +540,16 @@ pub const JSValkeyClient = struct { this.this_value.setStrong(this_value, globalObject); if (this.client.flags.needs_to_open_socket) { + debug("Need to open socket, starting connection process.", .{}); this.poll_ref.ref(this.client.vm); this.connect() catch |err| { this.poll_ref.unref(this.client.vm); this.client.flags.needs_to_open_socket = true; const err_value = globalObject.ERR(.SOCKET_CLOSED_BEFORE_CONNECTION, " {s} connecting to Valkey", .{@errorName(err)}).toJS(); + const event_loop = this.client.vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); promise_ptr.reject(globalObject, err_value); return promise; }; @@ -274,14 +637,14 @@ pub const JSValkeyClient = struct { /// Safely remove a timer with proper reference counting and event loop keepalive fn removeTimer(this: *JSValkeyClient, timer: *Timer.EventLoopTimer) void { if (timer.state == .ACTIVE) { - // Store VM reference to use later const vm = this.client.vm; // Remove the timer from the event loop vm.timer.remove(timer); - // Balance the ref from addTimer + // this.addTimer() adds a reference to 'this' when the timer is + // alive which is balanced here. this.deref(); } } @@ -391,11 +754,7 @@ pub const JSValkeyClient = struct { // Callback for when Valkey client connects pub fn onValkeyConnect(this: *JSValkeyClient, value: *protocol.RESPValue) void { - // Safety check to ensure a valid connection state - if (this.client.status != .connected) { - debug("onValkeyConnect called but client status is not 'connected': {s}", .{@tagName(this.client.status)}); - return; - } + bun.debugAssert(this.client.status == .connected); const globalObject = this.globalObject; const event_loop = this.client.vm.eventLoop(); @@ -414,7 +773,16 @@ pub const JSValkeyClient = struct { if (js.connectionPromiseGetCached(this_value)) |promise| { js.connectionPromiseSetCached(this_value, globalObject, .zero); - promise.asPromise().?.resolve(globalObject, hello_value); + const js_promise = promise.asPromise().?; + if (this.client.flags.connection_promise_returns_client) { + debug("Resolving connection promise with client instance", .{}); + const this_js = this.toJS(globalObject); + js_promise.resolve(globalObject, this_js); + } else { + debug("Resolving connection promise with HELLO response", .{}); + js_promise.resolve(globalObject, hello_value); + } + this.client.flags.connection_promise_returns_client = false; } } @@ -422,6 +790,97 @@ pub const JSValkeyClient = struct { this.updatePollRef(); } + /// Invoked when the Valkey client receives a new listener. + /// + /// `SubscriptionCtx` will invoke this to communicate that it has added a new listener. + pub fn onNewSubscriptionCallbackInsert(this: *JSValkeyClient) void { + this.ref(); + defer this.deref(); + + this.client.onWritable(); + this.updatePollRef(); + } + + pub fn onValkeySubscribe(this: *JSValkeyClient, value: *protocol.RESPValue) void { + bun.debugAssert(this.isSubscriber()); + bun.debugAssert(this.this_value.isStrong()); + + this.ref(); + defer this.deref(); + + _ = value; + + this.client.onWritable(); + this.updatePollRef(); + } + + pub fn onValkeyUnsubscribe(this: *JSValkeyClient) bun.JSError!void { + bun.debugAssert(this.isSubscriber()); + bun.debugAssert(this.this_value.isStrong()); + + this.ref(); + defer this.deref(); + + var subscription_ctx = this._subscription_ctx.?; + + // Check if we have any remaining subscriptions + // If the callback map is empty, we can exit subscription mode + + // If fetching the subscription count fails, the best we can do is + // bubble the error up. + const has_subs = try subscription_ctx.hasSubscriptions(this.globalObject); + if (!has_subs) { + // No more subscriptions, exit subscription mode + this.deleteSubscriptionCtx(); + } + + this.client.onWritable(); + this.updatePollRef(); + } + + pub fn onValkeyMessage(this: *JSValkeyClient, value: []protocol.RESPValue) void { + if (!this.isSubscriber()) { + debug("onMessage called but client is not in subscriber mode", .{}); + return; + } + + const globalObject = this.globalObject; + const event_loop = this.client.vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + + // The message push should be an array with [channel, message] + if (value.len < 2) { + debug("Message array has insufficient elements: {}", .{value.len}); + return; + } + + // Extract channel and message + const channel_value = value[0].toJS(globalObject) catch { + debug("Failed to convert channel to JS", .{}); + return; + }; + const message_value = value[1].toJS(globalObject) catch { + debug("Failed to convert message to JS", .{}); + return; + }; + + // Get the subscription context + const subs_ctx = &this._subscription_ctx.?; + + // Invoke callbacks for this channel with message and channel as arguments + subs_ctx.invokeCallbacks( + globalObject, + channel_value, + &[_]JSValue{ message_value, channel_value }, + ) catch { + return; + }; + + this.client.onWritable(); + this.updatePollRef(); + } + // Callback for when Valkey client needs to reconnect pub fn onValkeyReconnect(this: *JSValkeyClient) void { // Schedule reconnection using our safe timer methods @@ -468,6 +927,9 @@ pub const JSValkeyClient = struct { &[_]JSValue{error_value}, ) catch |e| globalObject.reportActiveExceptionAsUnhandled(e); } + + // Update poll reference to allow garbage collection of disconnected clients + this.updatePollRef(); } // Callback for when Valkey client times out @@ -495,19 +957,20 @@ pub const JSValkeyClient = struct { } pub fn finalize(this: *JSValkeyClient) void { - // Since this.stopTimers impacts the reference count potentially, we - // need to ref/unref here as well. this.ref(); defer this.deref(); this.stopTimers(); - this.this_value.deinit(); - if (this.client.status == .connected or this.client.status == .connecting) { - this.client.flags.is_manually_closed = true; - } + this.this_value.finalize(); this.client.flags.finalized = true; this.client.close(); - this.deref(); + + // We do not need to free the subscription context here because we're + // guaranteed to have freed it by virtue of the fact that we are + // garbage collected now and the subscription context holds a reference + // to us. If we still had a subscription context, we would never be + // garbage collected. + bun.debugAssert(this._subscription_ctx == null); } pub fn stopTimers(this: *JSValkeyClient) void { @@ -521,6 +984,7 @@ pub const JSValkeyClient = struct { } fn connect(this: *JSValkeyClient) !void { + debug("Connecting to Redis.", .{}); this.client.flags.needs_to_open_socket = false; const vm = this.client.vm; @@ -578,6 +1042,9 @@ pub const JSValkeyClient = struct { this.client.flags.needs_to_open_socket = true; const err_value = globalThis.ERR(.SOCKET_CLOSED_BEFORE_CONNECTION, " {s} connecting to Valkey", .{@errorName(err)}).toJS(); const promise = jsc.JSPromise.create(globalThis); + const event_loop = this.client.vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); promise.reject(globalThis, err_value); return promise; }; @@ -612,25 +1079,77 @@ pub const JSValkeyClient = struct { this.client.deinit(null); this.poll_ref.disable(); this.stopTimers(); - this.this_value.deinit(); + this.this_value.finalize(); this.ref_count.assertNoRefs(); bun.destroy(this); } /// Keep the event loop alive, or don't keep it alive + /// + /// This requires this_value to be alive. pub fn updatePollRef(this: *JSValkeyClient) void { - if (!this.client.hasAnyPendingCommands() and this.client.status == .connected) { - this.poll_ref.unref(this.client.vm); - // If we don't have any pending commands and we're connected, we don't need to keep the object alive. - if (this.this_value.tryGet()) |value| { - this.this_value.setWeak(value); - } - } else if (this.client.hasAnyPendingCommands()) { + // TODO(markovejnovic): This function is such a crazy cop out. We really + // should be treating valkey as a state machine, with well-defined + // state and modes in which it tracks and manages its own lifecycle. + // This is a mess beyond belief and it is incredibly fragile. + + const has_pending_commands = this.client.hasAnyPendingCommands(); + + // isDeletable may throw an exception, and if it does, we have to assume + // that the object still has references. Best we can do is hope nothing + // catastrophic happens. + const subs_deletable: bool = if (this._subscription_ctx) |*ctx| + ctx.isDeletable(this.globalObject) catch false + else + true; + + const has_activity = has_pending_commands or !subs_deletable; + + // There's a couple cases to handle here: + if (has_activity) { + // If we currently have pending activity, we need to keep the event + // loop alive. this.poll_ref.ref(this.client.vm); - // If we have pending commands, we need to keep the object alive. - if (this.this_value == .weak) { + } else { + // There is no pending activity so it is safe to remove the event + // loop. + this.poll_ref.unref(this.client.vm); + } + + if (this.this_value.isEmpty()) { + return; + } + + // Orthogonal to this, we need to manage the strong reference to the JS + // object. + switch (this.client.status) { + .connecting, .connected => { + // Whenever we're connected, we need to keep the object alive. + // + // TODO(markovejnovic): This is a leak. + // Note this is an intentional leak. Unless the user manually + // closes the connection, the object will stay alive forever, + // even if it falls out of scope. This is kind of stupid, since + // if the object is out of scope, and isn't subscribed upon, + // how exactly is the user going to call anything on the object? + // + // It is 100% safe to drop the strong reference there and let + // the object be GC'd, but we're not doing that now. this.this_value.upgrade(this.globalObject); - } + }, + .disconnected, .failed => { + // If we're disconnected or failed, we need to check if we have + // any pending activity. + if (has_activity) { + // If we have pending activity, we need to keep the object + // alive. + this.this_value.upgrade(this.globalObject); + } else { + // If we don't have any pending activity, we can drop the + // strong reference. + this.this_value.downgrade(); + } + }, } } @@ -641,6 +1160,7 @@ pub const JSValkeyClient = struct { pub const decr = fns.decr; pub const del = fns.del; pub const dump = fns.dump; + pub const duplicate = fns.duplicate; pub const exists = fns.exists; pub const expire = fns.expire; pub const expiretime = fns.expiretime; @@ -756,18 +1276,21 @@ fn SocketHandler(comptime ssl: bool) type { pub const onHandshake = if (ssl) onHandshake_ else null; pub fn onClose(this: *JSValkeyClient, _: SocketType, _: i32, _: ?*anyopaque) void { + // No need to deref since this.client.onClose() invokes onValkeyClose which does the deref. + + debug("Socket closed.", .{}); + // Ensure the socket pointer is updated. this.client.socket = .{ .SocketTCP = .detached }; this.client.onClose(); + this.updatePollRef(); } pub fn onEnd(this: *JSValkeyClient, socket: SocketType) void { - // Ensure the socket pointer is updated before closing - this.client.socket = _socket(socket); - - // Do not allow half-open connections - socket.close(.normal); + _ = this; + _ = socket; + // Half-opened sockets are not allowed. } pub fn onConnectError(this: *JSValkeyClient, _: SocketType, _: i32) void { @@ -778,6 +1301,8 @@ fn SocketHandler(comptime ssl: bool) type { } pub fn onTimeout(this: *JSValkeyClient, socket: SocketType) void { + debug("Socket timed out.", .{}); + this.client.socket = _socket(socket); // Handle socket timeout } diff --git a/src/valkey/js_valkey_functions.zig b/src/valkey/js_valkey_functions.zig index 4b6d2daad6..9361e404cf 100644 --- a/src/valkey/js_valkey_functions.zig +++ b/src/valkey/js_valkey_functions.zig @@ -1,3 +1,19 @@ +fn requireNotSubscriber(this: *JSValkeyClient, function_name: []const u8) bun.JSError!void { + const fmt_string = "RedisClient.prototype.{s} cannot be called while in subscriber mode."; + + if (this.isSubscriber()) { + return this.globalObject.ERR(.REDIS_INVALID_STATE, fmt_string, .{function_name}).throw(); + } +} + +fn requireSubscriber(this: *JSValkeyClient, function_name: []const u8) bun.JSError!void { + const fmt_string = "RedisClient.prototype.{s} can only be called while in subscriber mode."; + + if (!this.isSubscriber()) { + return this.globalObject.ERR(.REDIS_INVALID_STATE, fmt_string, .{function_name}).throw(); + } +} + pub fn jsSend(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { const command = try callframe.argument(0).toBunString(globalObject); defer command.deref(); @@ -41,6 +57,8 @@ pub fn jsSend(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfram } pub fn get(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("get", "key", "string or buffer"); }; @@ -61,6 +79,8 @@ pub fn get(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn getBuffer(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("getBuffer", "key", "string or buffer"); }; @@ -81,6 +101,8 @@ pub fn getBuffer(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callf } pub fn set(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const args_view = callframe.arguments(); var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); var args = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), args_view.len); @@ -127,6 +149,8 @@ pub fn set(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn incr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("incr", "key", "string or buffer"); }; @@ -147,6 +171,8 @@ pub fn incr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn decr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("decr", "key", "string or buffer"); }; @@ -167,6 +193,8 @@ pub fn decr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn exists(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("exists", "key", "string or buffer"); }; @@ -188,6 +216,8 @@ pub fn exists(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfram } pub fn expire(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("expire", "key", "string or buffer"); }; @@ -219,6 +249,8 @@ pub fn expire(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfram } pub fn ttl(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("ttl", "key", "string or buffer"); }; @@ -240,6 +272,8 @@ pub fn ttl(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement srem (remove value from a set) pub fn srem(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("srem", "key", "string or buffer"); }; @@ -265,6 +299,8 @@ pub fn srem(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement srandmember (get random member from set) pub fn srandmember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("srandmember", "key", "string or buffer"); }; @@ -286,6 +322,8 @@ pub fn srandmember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, cal // Implement smembers (get all members of a set) pub fn smembers(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("smembers", "key", "string or buffer"); }; @@ -307,6 +345,8 @@ pub fn smembers(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfr // Implement spop (pop a random member from a set) pub fn spop(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("spop", "key", "string or buffer"); }; @@ -328,6 +368,8 @@ pub fn spop(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement sadd (add member to a set) pub fn sadd(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("sadd", "key", "string or buffer"); }; @@ -353,6 +395,8 @@ pub fn sadd(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement sismember (check if value is member of a set) pub fn sismember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("sismember", "key", "string or buffer"); }; @@ -379,6 +423,8 @@ pub fn sismember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callf // Implement hmget (get multiple values from hash) pub fn hmget(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("hmget", "key", "string or buffer"); }; @@ -426,6 +472,8 @@ pub fn hmget(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe // Implement hincrby (increment hash field by integer value) pub fn hincrby(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = try callframe.argument(0).toBunString(globalObject); defer key.deref(); const field = try callframe.argument(1).toBunString(globalObject); @@ -456,6 +504,8 @@ pub fn hincrby(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfra // Implement hincrbyfloat (increment hash field by float value) pub fn hincrbyfloat(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = try callframe.argument(0).toBunString(globalObject); defer key.deref(); const field = try callframe.argument(1).toBunString(globalObject); @@ -486,6 +536,8 @@ pub fn hincrbyfloat(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, ca // Implement hmset (set multiple values in hash) pub fn hmset(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = try callframe.argument(0).toBunString(globalObject); defer key.deref(); @@ -578,60 +630,346 @@ pub fn ping(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: return promise.toJS(); } -pub const bitcount = compile.@"(key: RedisKey)"("bitcount", "BITCOUNT", "key").call; -pub const dump = compile.@"(key: RedisKey)"("dump", "DUMP", "key").call; -pub const expiretime = compile.@"(key: RedisKey)"("expiretime", "EXPIRETIME", "key").call; -pub const getdel = compile.@"(key: RedisKey)"("getdel", "GETDEL", "key").call; -pub const getex = compile.@"(...strings: string[])"("getex", "GETEX").call; -pub const hgetall = compile.@"(key: RedisKey)"("hgetall", "HGETALL", "key").call; -pub const hkeys = compile.@"(key: RedisKey)"("hkeys", "HKEYS", "key").call; -pub const hlen = compile.@"(key: RedisKey)"("hlen", "HLEN", "key").call; -pub const hvals = compile.@"(key: RedisKey)"("hvals", "HVALS", "key").call; -pub const keys = compile.@"(key: RedisKey)"("keys", "KEYS", "key").call; -pub const llen = compile.@"(key: RedisKey)"("llen", "LLEN", "key").call; -pub const lpop = compile.@"(key: RedisKey)"("lpop", "LPOP", "key").call; -pub const persist = compile.@"(key: RedisKey)"("persist", "PERSIST", "key").call; -pub const pexpiretime = compile.@"(key: RedisKey)"("pexpiretime", "PEXPIRETIME", "key").call; -pub const pttl = compile.@"(key: RedisKey)"("pttl", "PTTL", "key").call; -pub const rpop = compile.@"(key: RedisKey)"("rpop", "RPOP", "key").call; -pub const scard = compile.@"(key: RedisKey)"("scard", "SCARD", "key").call; -pub const strlen = compile.@"(key: RedisKey)"("strlen", "STRLEN", "key").call; -pub const @"type" = compile.@"(key: RedisKey)"("type", "TYPE", "key").call; -pub const zcard = compile.@"(key: RedisKey)"("zcard", "ZCARD", "key").call; -pub const zpopmax = compile.@"(key: RedisKey)"("zpopmax", "ZPOPMAX", "key").call; -pub const zpopmin = compile.@"(key: RedisKey)"("zpopmin", "ZPOPMIN", "key").call; -pub const zrandmember = compile.@"(key: RedisKey)"("zrandmember", "ZRANDMEMBER", "key").call; +pub const bitcount = compile.@"(key: RedisKey)"("bitcount", "BITCOUNT", "key", .not_subscriber).call; +pub const dump = compile.@"(key: RedisKey)"("dump", "DUMP", "key", .not_subscriber).call; +pub const expiretime = compile.@"(key: RedisKey)"("expiretime", "EXPIRETIME", "key", .not_subscriber).call; +pub const getdel = compile.@"(key: RedisKey)"("getdel", "GETDEL", "key", .not_subscriber).call; +pub const getex = compile.@"(...strings: string[])"("getex", "GETEX", .not_subscriber).call; +pub const hgetall = compile.@"(key: RedisKey)"("hgetall", "HGETALL", "key", .not_subscriber).call; +pub const hkeys = compile.@"(key: RedisKey)"("hkeys", "HKEYS", "key", .not_subscriber).call; +pub const hlen = compile.@"(key: RedisKey)"("hlen", "HLEN", "key", .not_subscriber).call; +pub const hvals = compile.@"(key: RedisKey)"("hvals", "HVALS", "key", .not_subscriber).call; +pub const keys = compile.@"(key: RedisKey)"("keys", "KEYS", "key", .not_subscriber).call; +pub const llen = compile.@"(key: RedisKey)"("llen", "LLEN", "key", .not_subscriber).call; +pub const lpop = compile.@"(key: RedisKey)"("lpop", "LPOP", "key", .not_subscriber).call; +pub const persist = compile.@"(key: RedisKey)"("persist", "PERSIST", "key", .not_subscriber).call; +pub const pexpiretime = compile.@"(key: RedisKey)"("pexpiretime", "PEXPIRETIME", "key", .not_subscriber).call; +pub const pttl = compile.@"(key: RedisKey)"("pttl", "PTTL", "key", .not_subscriber).call; +pub const rpop = compile.@"(key: RedisKey)"("rpop", "RPOP", "key", .not_subscriber).call; +pub const scard = compile.@"(key: RedisKey)"("scard", "SCARD", "key", .not_subscriber).call; +pub const strlen = compile.@"(key: RedisKey)"("strlen", "STRLEN", "key", .not_subscriber).call; +pub const @"type" = compile.@"(key: RedisKey)"("type", "TYPE", "key", .not_subscriber).call; +pub const zcard = compile.@"(key: RedisKey)"("zcard", "ZCARD", "key", .not_subscriber).call; +pub const zpopmax = compile.@"(key: RedisKey)"("zpopmax", "ZPOPMAX", "key", .not_subscriber).call; +pub const zpopmin = compile.@"(key: RedisKey)"("zpopmin", "ZPOPMIN", "key", .not_subscriber).call; +pub const zrandmember = compile.@"(key: RedisKey)"("zrandmember", "ZRANDMEMBER", "key", .not_subscriber).call; -pub const append = compile.@"(key: RedisKey, value: RedisValue)"("append", "APPEND", "key", "value").call; -pub const getset = compile.@"(key: RedisKey, value: RedisValue)"("getset", "GETSET", "key", "value").call; -pub const hget = compile.@"(key: RedisKey, value: RedisValue)"("hget", "HGET", "key", "field").call; -pub const lpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpush", "LPUSH").call; -pub const lpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpushx", "LPUSHX").call; -pub const pfadd = compile.@"(key: RedisKey, value: RedisValue)"("pfadd", "PFADD", "key", "value").call; -pub const rpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpush", "RPUSH").call; -pub const rpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpushx", "RPUSHX").call; -pub const setnx = compile.@"(key: RedisKey, value: RedisValue)"("setnx", "SETNX", "key", "value").call; -pub const zscore = compile.@"(key: RedisKey, value: RedisValue)"("zscore", "ZSCORE", "key", "value").call; +pub const append = compile.@"(key: RedisKey, value: RedisValue)"("append", "APPEND", "key", "value", .not_subscriber).call; +pub const getset = compile.@"(key: RedisKey, value: RedisValue)"("getset", "GETSET", "key", "value", .not_subscriber).call; +pub const hget = compile.@"(key: RedisKey, value: RedisValue)"("hget", "HGET", "key", "field", .not_subscriber).call; +pub const lpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpush", "LPUSH", .not_subscriber).call; +pub const lpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpushx", "LPUSHX", .not_subscriber).call; +pub const pfadd = compile.@"(key: RedisKey, value: RedisValue)"("pfadd", "PFADD", "key", "value", .not_subscriber).call; +pub const rpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpush", "RPUSH", .not_subscriber).call; +pub const rpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpushx", "RPUSHX", .not_subscriber).call; +pub const setnx = compile.@"(key: RedisKey, value: RedisValue)"("setnx", "SETNX", "key", "value", .not_subscriber).call; +pub const zscore = compile.@"(key: RedisKey, value: RedisValue)"("zscore", "ZSCORE", "key", "value", .not_subscriber).call; -pub const del = compile.@"(key: RedisKey, ...args: RedisKey[])"("del", "DEL", "key").call; -pub const mget = compile.@"(key: RedisKey, ...args: RedisKey[])"("mget", "MGET", "key").call; +pub const del = compile.@"(key: RedisKey, ...args: RedisKey[])"("del", "DEL", "key", .not_subscriber).call; +pub const mget = compile.@"(key: RedisKey, ...args: RedisKey[])"("mget", "MGET", "key", .not_subscriber).call; -pub const publish = compile.@"(...strings: string[])"("publish", "PUBLISH").call; -pub const script = compile.@"(...strings: string[])"("script", "SCRIPT").call; -pub const select = compile.@"(...strings: string[])"("select", "SELECT").call; -pub const spublish = compile.@"(...strings: string[])"("spublish", "SPUBLISH").call; -pub const smove = compile.@"(...strings: string[])"("smove", "SMOVE").call; -pub const substr = compile.@"(...strings: string[])"("substr", "SUBSTR").call; -pub const hstrlen = compile.@"(...strings: string[])"("hstrlen", "HSTRLEN").call; -pub const zrank = compile.@"(...strings: string[])"("zrank", "ZRANK").call; -pub const zrevrank = compile.@"(...strings: string[])"("zrevrank", "ZREVRANK").call; -pub const subscribe = compile.@"(...strings: string[])"("subscribe", "SUBSCRIBE").call; -pub const psubscribe = compile.@"(...strings: string[])"("psubscribe", "PSUBSCRIBE").call; -pub const unsubscribe = compile.@"(...strings: string[])"("unsubscribe", "UNSUBSCRIBE").call; -pub const punsubscribe = compile.@"(...strings: string[])"("punsubscribe", "PUNSUBSCRIBE").call; -pub const pubsub = compile.@"(...strings: string[])"("pubsub", "PUBSUB").call; +pub const script = compile.@"(...strings: string[])"("script", "SCRIPT", .not_subscriber).call; +pub const select = compile.@"(...strings: string[])"("select", "SELECT", .not_subscriber).call; +pub const spublish = compile.@"(...strings: string[])"("spublish", "SPUBLISH", .not_subscriber).call; +pub const smove = compile.@"(...strings: string[])"("smove", "SMOVE", .not_subscriber).call; +pub const substr = compile.@"(...strings: string[])"("substr", "SUBSTR", .not_subscriber).call; +pub const hstrlen = compile.@"(...strings: string[])"("hstrlen", "HSTRLEN", .not_subscriber).call; +pub const zrank = compile.@"(...strings: string[])"("zrank", "ZRANK", .not_subscriber).call; +pub const zrevrank = compile.@"(...strings: string[])"("zrevrank", "ZREVRANK", .not_subscriber).call; +pub const psubscribe = compile.@"(...strings: string[])"("psubscribe", "PSUBSCRIBE", .dont_care).call; +pub const punsubscribe = compile.@"(...strings: string[])"("punsubscribe", "PUNSUBSCRIBE", .dont_care).call; +pub const pubsub = compile.@"(...strings: string[])"("pubsub", "PUBSUB", .dont_care).call; + +pub fn publish( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + callframe: *jsc.CallFrame, +) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + + const args_view = callframe.arguments(); + var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); + var args = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), args_view.len); + defer { + for (args.items) |*item| { + item.deinit(); + } + args.deinit(); + } + + const arg0 = callframe.argument(0); + if (!arg0.isString()) { + return globalObject.throwInvalidArgumentType("publish", "channel", "string"); + } + const channel = (try fromJS(globalObject, arg0)) orelse unreachable; + + args.appendAssumeCapacity(channel); + + const arg1 = callframe.argument(1); + if (!arg1.isString()) { + return globalObject.throwInvalidArgumentType("publish", "message", "string"); + } + const message = (try fromJS(globalObject, arg1)) orelse unreachable; + args.appendAssumeCapacity(message); + + const promise = this.send( + globalObject, + callframe.this(), + &.{ + .command = "PUBLISH", + .args = .{ .args = args.items }, + }, + ) catch |err| { + return protocol.valkeyErrorToJS(globalObject, "Failed to send PUBLISH command", err); + }; + + return promise.toJS(); +} + +pub fn subscribe( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + callframe: *jsc.CallFrame, +) bun.JSError!JSValue { + const channel_or_many, const handler_callback = callframe.argumentsAsArray(2); + var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); + var redis_channels = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), 1); + defer { + for (redis_channels.items) |*item| { + item.deinit(); + } + redis_channels.deinit(); + } + + if (!handler_callback.isCallable()) { + return globalObject.throwInvalidArgumentType("subscribe", "listener", "function"); + } + + // We now need to register the callback with our subscription context, which may or may not exist. + var subscription_ctx = try this.getOrCreateSubscriptionCtxEnteringSubscriptionMode(); + + // The first argument given is the channel or may be an array of channels. + if (channel_or_many.isArray()) { + if ((try channel_or_many.getLength(globalObject)) == 0) { + return globalObject.throwInvalidArguments("subscribe requires at least one channel", .{}); + } + try redis_channels.ensureTotalCapacity(try channel_or_many.getLength(globalObject)); + + var array_iter = try channel_or_many.arrayIterator(globalObject); + while (try array_iter.next()) |channel_arg| { + const channel = (try fromJS(globalObject, channel_arg)) orelse { + return globalObject.throwInvalidArgumentType("subscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + + try subscription_ctx.upsertReceiveHandler(globalObject, channel_arg, handler_callback); + } + } else if (channel_or_many.isString()) { + // It is a single string channel + const channel = (try fromJS(globalObject, channel_or_many)) orelse { + return globalObject.throwInvalidArgumentType("subscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + + try subscription_ctx.upsertReceiveHandler(globalObject, channel_or_many, handler_callback); + } else { + return globalObject.throwInvalidArgumentType("subscribe", "channel", "string or array"); + } + + const command: valkey.Command = .{ + .command = "SUBSCRIBE", + .args = .{ .args = redis_channels.items }, + }; + const promise = this.send( + globalObject, + callframe.this(), + &command, + ) catch |err| { + // If we find an error, we need to clean up the subscription context. + this.deleteSubscriptionCtx(); + return protocol.valkeyErrorToJS(globalObject, "Failed to send SUBSCRIBE command", err); + }; + + return promise.toJS(); +} + +/// Send redis the UNSUBSCRIBE RESP command and clean up anything necessary after the unsubscribe commoand. +/// +/// The subscription context must exist when calling this function. +fn sendUnsubscribeRequestAndCleanup( + this: *JSValkeyClient, + this_js: jsc.JSValue, + globalObject: *jsc.JSGlobalObject, + redis_channels: []JSArgument, +) !jsc.JSValue { + // Send UNSUBSCRIBE command + const command: valkey.Command = .{ + .command = "UNSUBSCRIBE", + .args = .{ .args = redis_channels }, + }; + const promise = this.send( + globalObject, + this_js, + &command, + ) catch |err| { + return protocol.valkeyErrorToJS(globalObject, "Failed to send UNSUBSCRIBE command", err); + }; + + // We do not delete the subscription context here, but rather when the + // onValkeyUnsubscribe callback is invoked. + + return promise.toJS(); +} + +pub fn unsubscribe( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + callframe: *jsc.CallFrame, +) bun.JSError!JSValue { + // Check if we're in subscription mode + try requireSubscriber(this, @src().fn_name); + + const args_view = callframe.arguments(); + + var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); + var redis_channels = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), 1); + defer { + for (redis_channels.items) |*item| { + item.deinit(); + } + redis_channels.deinit(); + } + + // If no arguments, unsubscribe from all channels + if (args_view.len == 0) { + return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items); + } + + // The first argument can be a channel or an array of channels + const channel_or_many = callframe.argument(0); + + // Get the subscription context + var subscription_ctx = this._subscription_ctx orelse { + return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined); + }; + + // Two arguments means .unsubscribe(channel, listener) is invoked. + if (callframe.arguments().len == 2) { + // In this case, the first argument is a channel string and the second + // argument is the handler to remove. + if (!channel_or_many.isString()) { + return globalObject.throwInvalidArgumentType( + "unsubscribe", + "channel", + "string", + ); + } + + const channel = channel_or_many; + const listener_cb = callframe.argument(1); + + if (!listener_cb.isCallable()) { + return globalObject.throwInvalidArgumentType( + "unsubscribe", + "listener", + "function", + ); + } + + // Populate the redis_channels list with the single channel to + // unsubscribe from. This s important since this list is used to send + // the UNSUBSCRIBE command to redis. Without this, we would end up + // unsubscribing from all channels. + redis_channels.appendAssumeCapacity((try fromJS(globalObject, channel)) orelse { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string"); + }); + + const remaining_listeners = subscription_ctx.removeReceiveHandler( + globalObject, + channel, + listener_cb, + ) catch { + return globalObject.throw( + "Failed to remove handler for channel {}", + .{channel.asString().getZigString(globalObject)}, + ); + } orelse { + // Listeners weren't present in the first place, so we can return a + // resolved promise. + return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined); + }; + + // In this case, we only want to send the unsubscribe command to redis if there are no more listeners for this + // channel. + if (remaining_listeners == 0) { + return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items); + } + + // Otherwise, in order to keep the API consistent, we need to return a resolved promise. + return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined); + } + + if (channel_or_many.isArray()) { + if ((try channel_or_many.getLength(globalObject)) == 0) { + return globalObject.throwInvalidArguments( + "unsubscribe requires at least one channel", + .{}, + ); + } + + try redis_channels.ensureTotalCapacity(try channel_or_many.getLength(globalObject)); + // It is an array, so let's iterate over it + var array_iter = try channel_or_many.arrayIterator(globalObject); + while (try array_iter.next()) |channel_arg| { + const channel = (try fromJS(globalObject, channel_arg)) orelse { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + // Clear the handlers for this channel + try subscription_ctx.clearReceiveHandlers(globalObject, channel_arg); + } + } else if (channel_or_many.isString()) { + // It is a single string channel + const channel = (try fromJS(globalObject, channel_or_many)) orelse { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + // Clear the handlers for this channel + try subscription_ctx.clearReceiveHandlers(globalObject, channel_or_many); + } else { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string or array"); + } + + // Now send the unsubscribe command and clean up if necessary + return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items); +} + +pub fn duplicate( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + callframe: *jsc.CallFrame, +) bun.JSError!JSValue { + _ = callframe; + + var new_client: *JSValkeyClient = try this.cloneWithoutConnecting(globalObject); + + const new_client_js = new_client.toJS(globalObject); + new_client.this_value = + if (this.client.status == .connected and !this.client.flags.is_manually_closed) + jsc.JSRef.initStrong(new_client_js, globalObject) + else + jsc.JSRef.initWeak(new_client_js); + + // If the original client is already connected and not manually closed, start connecting the new client. + if (this.client.status == .connected and !this.client.flags.is_manually_closed) { + // Use strong reference during connection to prevent premature GC + new_client.client.flags.connection_promise_returns_client = true; + return try new_client.doConnect(globalObject, new_client_js); + } + + return jsc.JSPromise.resolvedPromiseValue(globalObject, new_client_js); +} -// publish(channel: RedisValue, message: RedisValue) // script(subcommand: "LOAD", script: RedisValue) // select(index: number | string) // spublish(shardchannel: RedisValue, message: RedisValue) @@ -645,13 +983,37 @@ pub const pubsub = compile.@"(...strings: string[])"("pubsub", "PUBSUB").call; // cluster(subcommand: "KEYSLOT", key: RedisKey) const compile = struct { + pub const ClientStateRequirement = enum { + /// The client must be a subscriber (in subscription mode). + subscriber, + /// The client must not be a subscriber (not in subscription mode). + not_subscriber, + /// We don't care about the client state (subscriber or not). + dont_care, + }; + + fn testCorrectState( + this: *JSValkeyClient, + js_client_prototype_function_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, + ) bun.JSError!void { + return switch (client_state_requirement) { + .subscriber => requireSubscriber(this, js_client_prototype_function_name), + .not_subscriber => requireNotSubscriber(this, js_client_prototype_function_name), + .dont_care => {}, + }; + } + pub fn @"(key: RedisKey)"( comptime name: []const u8, comptime command: []const u8, comptime arg0_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType(name, arg0_name, "string or buffer"); }; @@ -676,9 +1038,12 @@ const compile = struct { comptime name: []const u8, comptime command: []const u8, comptime arg0_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + if (callframe.argument(0).isUndefinedOrNull()) { return globalObject.throwMissingArgumentsValue(&.{arg0_name}); } @@ -722,9 +1087,12 @@ const compile = struct { comptime command: []const u8, comptime arg0_name: []const u8, comptime arg1_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType(name, arg0_name, "string or buffer"); }; @@ -752,9 +1120,12 @@ const compile = struct { pub fn @"(...strings: string[])"( comptime name: []const u8, comptime command: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + var args = try std.ArrayList(JSArgument).initCapacity(bun.default_allocator, callframe.arguments().len); defer { for (args.items) |*item| { @@ -788,9 +1159,12 @@ const compile = struct { pub fn @"(key: RedisKey, value: RedisValue, ...args: RedisValue)"( comptime name: []const u8, comptime command: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + var args = try std.ArrayList(JSArgument).initCapacity(bun.default_allocator, callframe.arguments().len); defer { for (args.items) |*item| { diff --git a/src/valkey/valkey.zig b/src/valkey/valkey.zig index 87b9cea495..f2c22bb2c6 100644 --- a/src/valkey/valkey.zig +++ b/src/valkey/valkey.zig @@ -18,6 +18,16 @@ pub const ConnectionFlags = struct { is_reconnecting: bool = false, auto_pipelining: bool = true, finalized: bool = false, + // This flag is a slight hack to allow returning the client instance in the + // promise which resolves when the connection is established. There are two + // modes through which a client may connect: + // 1. Connect through `client.connect()` which has the semantics of + // resolving the promise with the connection information. + // 2. Through `client.duplicate()` which creates a promise through + // `onConnect()` which resolves with the client instance itself. + // This flag is set to true in the latter case to indicate to the promise + // resolution delegation to resolve the promise with the client. + connection_promise_returns_client: bool = false, }; /// Valkey connection status @@ -28,6 +38,13 @@ pub const Status = enum { failed, }; +pub fn isActive(this: *const Status) bool { + return switch (this.*) { + .connected, .connecting => true, + else => false, + }; +} + pub const Command = @import("./ValkeyCommand.zig"); /// Valkey protocol types (standalone, TLS, Unix socket) @@ -106,6 +123,13 @@ pub const Address = union(enum) { port: u16, }, + pub fn hostname(this: *const Address) []const u8 { + return switch (this.*) { + .unix => |unix_addr| unix_addr, + .host => |h| h.host, + }; + } + pub fn connect(this: *const Address, client: *ValkeyClient, ctx: *bun.uws.SocketContext, is_tls: bool) !uws.AnySocket { switch (is_tls) { inline else => |tls| { @@ -155,6 +179,7 @@ pub const ValkeyClient = struct { username: []const u8 = "", database: u32 = 0, address: Address, + protocol: Protocol, connection_strings: []u8 = &.{}, @@ -211,6 +236,8 @@ pub const ValkeyClient = struct { } this.allocator.free(this.connection_strings); + // Note there is no need to deallocate username, password and hostname since they are + // within the this.connection_strings buffer. this.write_buffer.deinit(this.allocator); this.read_buffer.deinit(this.allocator); this.tls.deinit(); @@ -383,14 +410,14 @@ pub const ValkeyClient = struct { /// Mark the connection as failed with error message pub fn fail(this: *ValkeyClient, message: []const u8, err: protocol.RedisError) void { - debug("failed: {s}: {s}", .{ message, @errorName(err) }); + debug("failed: {s}: {}", .{ message, err }); if (this.status == .failed) return; if (this.flags.finalized) { // We can't run promises inside finalizers. if (this.queue.count + this.in_flight.count > 0) { const vm = this.vm; - const deferred_failrue = bun.new(DeferredFailure, .{ + const deferred_failure = bun.new(DeferredFailure, .{ // This memory is not owned by us. .message = bun.handleOom(bun.default_allocator.dupe(u8, message)), @@ -401,7 +428,7 @@ pub const ValkeyClient = struct { }); this.in_flight = .init(this.allocator); this.queue = .init(this.allocator); - deferred_failrue.enqueue(); + deferred_failure.enqueue(); } // Allow the finalizer to call .close() @@ -504,9 +531,9 @@ pub const ValkeyClient = struct { } /// Process data received from socket + /// + /// Caller refs / derefs. pub fn onData(this: *ValkeyClient, data: []const u8) void { - // Caller refs / derefs. - // Path 1: Buffer already has data, append and process from buffer if (this.read_buffer.remaining().len > 0) { this.read_buffer.write(this.allocator, data) catch @panic("failed to write to read buffer"); @@ -613,6 +640,69 @@ pub const ValkeyClient = struct { // If the loop finishes, the entire 'data' was processed without needing the buffer. } + /// Try handling this response as a subscriber-state response. + /// Returns `handled` if we handled it, `fallthrough` if we did not. + fn handleSubscribeResponse(this: *ValkeyClient, value: *protocol.RESPValue, pair: *ValkeyCommand.PromisePair) bun.JSError!enum { handled, fallthrough } { + // Resolve the promise with the potentially transformed value + var promise_ptr = &pair.promise; + const globalThis = this.globalObject(); + const loop = this.vm.eventLoop(); + + debug("Handling a subscribe response: {any}", .{value.*}); + loop.enter(); + defer loop.exit(); + + return switch (value.*) { + .Error => { + promise_ptr.reject(globalThis, value.toJS(globalThis)); + return .handled; + }, + .Push => |push| { + const p = this.parent(); + const subs_ctx = try p.getOrCreateSubscriptionCtxEnteringSubscriptionMode(); + const sub_count = try subs_ctx.channelsSubscribedToCount(globalThis); + + if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| { + switch (msg_type) { + .subscribe => { + this.onValkeySubscribe(value); + promise_ptr.promise.resolve(globalThis, .jsNumber(sub_count)); + return .handled; + }, + .unsubscribe => { + try this.onValkeyUnsubscribe(); + promise_ptr.promise.resolve(globalThis, .js_undefined); + return .handled; + }, + else => { + // Other push messages (message, pmessage, etc) are not handled here + @branchHint(.cold); + this.fail( + "Push message is not a subscription message.", + protocol.RedisError.InvalidResponseType, + ); + return .fallthrough; + }, + } + } else { + // We should rarely reach this point. If we're guaranteed to be handling a subscribe/unsubscribe, + // then this is an unexpected path. + @branchHint(.cold); + this.fail( + "Push message is not a subscription message.", + protocol.RedisError.InvalidResponseType, + ); + return .handled; + } + }, + else => { + // This may be a regular command response. Let's pass it down + // to the next handler. + return .fallthrough; + }, + }; + } + fn handleHelloResponse(this: *ValkeyClient, value: *protocol.RESPValue) void { debug("Processing HELLO response", .{}); @@ -705,9 +795,68 @@ pub const ValkeyClient = struct { }, }; } + // Let's load the promise pair. + var pair_maybe = this.in_flight.readItem(); + + // We handle subscriptions specially because they are not regular + // commands and their failure will potentially cause the client to drop + // out of subscriber mode. + if (this.parent().isSubscriber()) { + debug("This client is a subscriber. Handling as subscriber...", .{}); + + // There are multiple different commands we may receive in + // subscriber mode. One is from a client.subscribe() call which + // requires that a promise is in-flight, but otherwise, we may also + // receive push messages from the server that do not have an + // associated promise. + if (pair_maybe) |*pair| { + debug("There is a request in flight. Handling as a subscribe request...", .{}); + if ((try this.handleSubscribeResponse(value, pair)) == .handled) { + return; + } + } + + switch (value.*) { + .Error => |err| { + this.fail(err, protocol.RedisError.InvalidResponse); + return; + }, + .Push => |push| { + if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| { + switch (msg_type) { + .message => { + @branchHint(.likely); + debug("Received a message.", .{}); + this.onValkeyMessage(push.data); + return; + }, + else => { + @branchHint(.cold); + debug("Received non-message push without promise: {any}", .{push.data}); + return; + }, + } + } else { + @branchHint(.cold); + this.fail( + "Unexpected push message kind without promise", + protocol.RedisError.InvalidResponseType, + ); + return; + } + }, + else => { + // In the else case, we fall through to the regular + // handler. Subscribers can send .Push commands which have + // the same semantics as regular commands. + }, + } + + debug("Treating subscriber response as a regular command...", .{}); + } // For regular commands, get the next command+promise pair from the queue - var pair = this.in_flight.readItem() orelse { + var pair = pair_maybe orelse { debug("Received response but no promise in queue", .{}); return; }; @@ -932,7 +1081,14 @@ pub const ValkeyClient = struct { if (this.flags.enable_offline_queue) { try this.enqueue(command, &promise); } else { - promise.reject(globalThis, globalThis.ERR(.REDIS_CONNECTION_CLOSED, "Connection is closed and offline queue is disabled", .{}).toJS()); + promise.reject( + globalThis, + globalThis.ERR( + .REDIS_CONNECTION_CLOSED, + "Connection is closed and offline queue is disabled", + .{}, + ).toJS(), + ); } }, .failed => { @@ -985,6 +1141,18 @@ pub const ValkeyClient = struct { this.parent().onValkeyConnect(value); } + pub fn onValkeySubscribe(this: *ValkeyClient, value: *protocol.RESPValue) void { + this.parent().onValkeySubscribe(value); + } + + pub fn onValkeyUnsubscribe(this: *ValkeyClient) bun.JSError!void { + return this.parent().onValkeyUnsubscribe(); + } + + pub fn onValkeyMessage(this: *ValkeyClient, value: []protocol.RESPValue) void { + this.parent().onValkeyMessage(value); + } + pub fn onValkeyReconnect(this: *ValkeyClient) void { this.parent().onValkeyReconnect(); } @@ -999,9 +1167,9 @@ pub const ValkeyClient = struct { }; // Auto-pipelining - const debug = bun.Output.scoped(.Redis, .visible); +const ValkeyCommand = @import("./ValkeyCommand.zig"); const protocol = @import("./valkey_protocol.zig"); const std = @import("std"); diff --git a/src/valkey/valkey_protocol.zig b/src/valkey/valkey_protocol.zig index a02514f7c6..ed63377925 100644 --- a/src/valkey/valkey_protocol.zig +++ b/src/valkey/valkey_protocol.zig @@ -656,6 +656,18 @@ pub const Attribute = struct { } }; +pub const SubscriptionPushMessage = enum(u2) { + message, + subscribe, + unsubscribe, + + pub const map = bun.ComptimeStringMap(SubscriptionPushMessage, .{ + .{ "message", .message }, + .{ "subscribe", .subscribe }, + .{ "unsubscribe", .unsubscribe }, + }); +}; + const std = @import("std"); const bun = @import("bun"); diff --git a/test/_util/numeric.ts b/test/_util/numeric.ts new file mode 100644 index 0000000000..2b4729acec --- /dev/null +++ b/test/_util/numeric.ts @@ -0,0 +1,192 @@ +/** + * Parameter accepted by some of the algorithms in this namespace which + * controls the input/output format of numbers. + * + * @example + * ```ts + * // Returns an integer + * numeric.random.between(0, 10, { domain: "integral" }); + * // Returns a floating point number + * numeric.random.between(0, 10, { domain: "floating" }); + * ``` + */ +export type FormatSpecifier = { + domain: "floating" | "integral"; +}; + +const DefaultFormatSpecifier: FormatSpecifier = { + domain: "floating", +}; + +/** + * Generate an array of evenly-spaced numbers in a range. + * + * The name iota comes from https://aplwiki.com/wiki/Index_Generator. It is + * commonly used across programming languages and libraries. + * + * @param count The total number of points to generate. + * @param step The step size between each value. + * @returns An array of evenly-spaced numbers. + */ +export function iota(count: number, step: number = 1) { + return Array.from({ length: count }, (_, i) => i * step); +} + +/** + * Create an array of linearly spaced numbers. + * + * @param start The starting value of the sequence. + * @param end The end value of the sequence. + * @param numPoints The number of points to generate. + * + * @returns An array of numbers, spaced evenly in the linear space. + */ +export function linSpace(start: number, end: number, numPoints: number): number[] { + if (numPoints <= 0) return []; + if (numPoints === 1) return [start]; + if (numPoints === 2) return [start, end]; + const step = (end - start) / (numPoints - 1); + + return iota(numPoints).map(i => start + i * step); +} + +/** + * Create an array of exponentially spaced numbers. + * + * @param start The starting value of the sequence. + * @param end The end value of the sequence. + * @param numPoints The number of points to generate. + * @param base The exponential base + * + * @returns An array of numbers, spaced evenly in the exponential space. + */ +export function expSpace(start: number, end: number, numPoints: number, base: number): number[] { + if (numPoints <= 0) return []; + if (numPoints === 1) return [start]; + + if (!Number.isFinite(base) || base <= 0 || base === 1) { + throw new Error('expSpace: "base" must be > 0 and !== 1'); + } + + // Generate exponentially spaced values from 0 to 1 + const exponentialValues = Array.from( + { length: numPoints }, + (_, i) => (Math.pow(base, i / (numPoints - 1)) - 1) / (base - 1), + ); + + // Scale and shift to fit the [start, end] range + return exponentialValues.map(t => start + t * (end - start)); +} + +export namespace stats { + /** + * Computes the Pearson correlation coefficient between two arrays of numbers. + * + * The Pearson correlation coefficient, also known as Pearson's r, is a + * statistical measure that quantifies the strength and direction of a linear + * relationship between two variables. + * + * @param xs The first array of numbers. + * @param ys The second array of numbers. + * @returns The Pearson correlation coefficient, or 0 if there is no correlation. + */ + export function computePearsonCorrelation(xs: number[], ys: number[]): number { + if (xs.length !== ys.length || xs.length === 0) { + throw new Error("Input arrays must have the same non-zero length"); + } + + const n = xs.length; + const sumX = xs.reduce((a, b) => a + b, 0); + const sumY = ys.reduce((a, b) => a + b, 0); + const sumXY = xs.reduce((sum, x, i) => sum + x * ys[i], 0); + const sumX2 = xs.reduce((sum, x) => sum + x * x, 0); + const sumY2 = ys.reduce((sum, y) => sum + y * y, 0); + + // Compute the Pearson correlation coefficient (r) using the formula: + // r = (n * Σ(xy) - Σx * Σy) / sqrt[(n * Σ(x^2) - (Σx)^2) * (n * Σ(y^2) - (Σy)^2)] + const numerator = n * sumXY - sumX * sumY; + const denominator = Math.sqrt((n * sumX2 - sumX * sumX) * (n * sumY2 - sumY * sumY)); + + if (denominator === 0) { + return 0; // Avoid division by zero; implies no correlation + } + + return numerator / denominator; + } + + /** + * Compute the slope of the best-fit line using linear regression. + * + * @param xs The random variable. + * @param ys The dependent variable. + * @returns The slope of the best-fit line. + */ + export function computeLinearSlope(xs: number[], ys: number[]): number { + if (xs.length !== ys.length || xs.length === 0) { + throw new Error("Input arrays must have the same non-zero length"); + } + + const n = xs.length; + const sumX = xs.reduce((a, b) => a + b, 0); + const sumY = ys.reduce((a, b) => a + b, 0); + const sumXY = xs.reduce((sum, x, i) => sum + x * ys[i], 0); + const sumX2 = xs.reduce((sum, x) => sum + x * x, 0); + + // Compute the slope (m) using the formula: + // m = (n * Σ(xy) - Σx * Σy) / (n * Σ(x^2) - (Σx)^2) + const slope = (n * sumXY - sumX * sumY) / (n * sumX2 - sumX * sumX); + return slope; + } + + /** + * Compute euclidean the mean (average) of an array of numbers. + * + * @param xs An array of numbers. + * @returns The mean of the numbers. + */ + export function computeMean(xs: number[]): number { + return xs.reduce((a, b) => a + b, 0) / xs.length; + } + + /** + * Compute the average absolute deviation of an array of numbers. + * + * The average absolute deviation (AAD) of a data set is the average of the + * absolute deviations from a central point. + * + * @param xs An array of numbers. + * @returns The average absolute deviation of the numbers. + */ + export function computeAverageAbsoluteDeviation(xs: number[]): number { + const mean = computeMean(xs); + return xs.reduce((sum, x) => sum + Math.abs(x - mean), 0) / xs.length; + } +} + +/** + * Utilities for numeric randomness. + * + * @todo Perhaps this does not belong in the numeric namespace. + */ +export namespace random { + /** + * Generate a random number between the specified range. + * + * @param min The minimum value (inclusive for integrals). + * @param max The maximum value (inclusive for integrals). + * @param format The format specifier for the random number. + * @returns A random number between min and max, formatted according to the specifier. + */ + export function between(min: number, max: number, format: FormatSpecifier = DefaultFormatSpecifier): number { + if (!Number.isFinite(min) || !Number.isFinite(max)) throw new Error("min/max must be finite"); + if (max < min) throw new Error("max must be >= min"); + + if (format.domain === "floating") { + return Math.random() * (max - min) + min; + } + + const lo = Math.ceil(min); + const hi = Math.floor(max); + return Math.floor(Math.random() * (hi - lo + 1)) + lo; + } +} diff --git a/test/harness.ts b/test/harness.ts index 9b46505c57..f54cf878a3 100644 --- a/test/harness.ts +++ b/test/harness.ts @@ -7,12 +7,13 @@ import { gc as bunGC, sleepSync, spawnSync, unsafe, which, write } from "bun"; import { heapStats } from "bun:jsc"; -import { afterAll, beforeAll, describe, expect, test } from "bun:test"; +import { beforeAll, describe, expect } from "bun:test"; import { ChildProcess, execSync, fork } from "child_process"; import { readdir, readFile, readlink, rm, writeFile } from "fs/promises"; import fs, { closeSync, openSync, rmSync } from "node:fs"; import os from "node:os"; import { dirname, isAbsolute, join } from "path"; +import * as numeric from "_util/numeric.ts"; type Awaitable = T | Promise; @@ -357,7 +358,7 @@ export function bunRunAsScript( } export function randomLoneSurrogate() { - const n = randomRange(0, 2); + const n = numeric.random.between(0, 2, { domain: "integral" }); if (n === 0) return randomLoneHighSurrogate(); return randomLoneLowSurrogate(); } @@ -370,16 +371,12 @@ export function randomInvalidSurrogatePair() { // Generates a random lone high surrogate (from the range D800-DBFF) export function randomLoneHighSurrogate() { - return String.fromCharCode(randomRange(0xd800, 0xdbff)); + return String.fromCharCode(numeric.random.between(0xd800, 0xdbff, { domain: "integral" })); } // Generates a random lone high surrogate (from the range DC00-DFFF) export function randomLoneLowSurrogate() { - return String.fromCharCode(randomRange(0xdc00, 0xdfff)); -} - -function randomRange(low: number, high: number): number { - return low + Math.floor(Math.random() * (high - low)); + return String.fromCharCode(numeric.random.between(0xdc00, 0xdfff, { domain: "integral" })); } export function runWithError(cb: () => unknown): Error | undefined { diff --git a/test/integration/bun-types/fixture/redis.ts b/test/integration/bun-types/fixture/redis.ts new file mode 100644 index 0000000000..f7e536200d --- /dev/null +++ b/test/integration/bun-types/fixture/redis.ts @@ -0,0 +1,31 @@ +import { expectType } from "./utilities"; + +expectType(Bun.redis.publish("hello", "world")).is>(); + +const copy = await Bun.redis.duplicate(); +expectType(copy.connected).is(); +expectType(copy).is(); + +const listener: Bun.RedisClient.StringPubSubListener = (message, channel) => { + expectType(message).is(); + expectType(channel).is(); +}; +Bun.redis.subscribe("hello", listener); + +// Buffer subscriptions are not yet implemented +// const bufferListener: Bun.RedisClient.BufferPubSubListener = (message, channel) => { +// expectType(message).is>(); +// expectType(channel).is(); +// }; +// Bun.redis.subscribe("hello", bufferListener); + +expectType( + copy.subscribe("hello", message => { + expectType(message).is(); + }), +).is>(); + +await copy.unsubscribe(); +await copy.unsubscribe("hello"); + +expectType(copy.unsubscribe("hello", () => {})).is>(); diff --git a/test/js/valkey/test-utils.ts b/test/js/valkey/test-utils.ts index 34ce3347cc..04deb19046 100644 --- a/test/js/valkey/test-utils.ts +++ b/test/js/valkey/test-utils.ts @@ -5,8 +5,6 @@ import path from "path"; import * as dockerCompose from "../../docker/index.ts"; import { UnixDomainSocketProxy } from "../../unix-domain-socket-proxy.ts"; -import * as fs from "node:fs"; -import * as os from "node:os"; const dockerCLI = dockerExe() as string; export const isEnabled = @@ -165,7 +163,7 @@ async function startContainer(): Promise { REDIS_PORT = port; REDIS_TLS_PORT = tlsPort; REDIS_HOST = redisInfo.host; - REDIS_UNIX_SOCKET = unixSocketProxy.path; // Use the proxy socket + REDIS_UNIX_SOCKET = unixSocketProxy.path; // Use the proxy socket DEFAULT_REDIS_URL = `redis://${REDIS_HOST}:${REDIS_PORT}`; TLS_REDIS_URL = `rediss://${REDIS_HOST}:${REDIS_TLS_PORT}`; UNIX_REDIS_URL = `redis+unix://${REDIS_UNIX_SOCKET}`; @@ -223,12 +221,12 @@ import { tmpdir } from "os"; * Create a new client with specific connection type */ export function createClient( - connectionType: ConnectionType = ConnectionType.TCP, - customOptions = {}, - dbId: number | undefined = undefined, + connectionType: ConnectionType = ConnectionType.TCP, + customOptions = {}, + dbId: number | undefined = undefined, ) { let url: string; - const mkUrl = (baseUrl: string) => dbId ? `${baseUrl}/${dbId}`: baseUrl; + const mkUrl = (baseUrl: string) => (dbId ? `${baseUrl}/${dbId}` : baseUrl); let options: any = {}; context.id++; @@ -314,6 +312,9 @@ export interface TestContext { redisWriteOnly?: RedisClient; id: number; restartServer: () => Promise; + __subscriberClientPool: RedisClient[]; + newSubscriberClient: (connectionType: ConnectionType) => Promise; + cleanupSubscribers: () => Promise; } // Create a singleton promise for Docker initialization @@ -336,10 +337,30 @@ export const context: TestContext = { redisWriteOnly: undefined, id, restartServer: restartRedisContainer, + __subscriberClientPool: [], + newSubscriberClient: async function (connectionType: ConnectionType) { + const client = createClient(connectionType); + this.__subscriberClientPool.push(client); + await client.connect(); + return client; + }, + cleanupSubscribers: async function () { + for (const client of this.__subscriberClientPool) { + try { + await client.unsubscribe(); + } catch {} + + if (client.connected) { + client.close(); + } + } + + this.__subscriberClientPool = []; + }, }; export { context as ctx }; -if (isEnabled) +if (isEnabled) { beforeAll(async () => { // Initialize Docker container once for all tests if (!dockerInitPromise) { @@ -405,8 +426,9 @@ if (isEnabled) // console.warn("Test initialization failed - tests may be skipped"); // } }); +} -if (isEnabled) +if (isEnabled) { afterAll(async () => { console.log("Cleaning up Redis container"); if (!context.redis?.connected) { @@ -426,26 +448,26 @@ if (isEnabled) } // Disconnect all clients - await context.redis.close(); + context.redis.close(); if (context.redisTLS) { - await context.redisTLS.close(); + context.redisTLS.close(); } if (context.redisUnix) { - await context.redisUnix.close(); + context.redisUnix.close(); } if (context.redisAuth) { - await context.redisAuth.close(); + context.redisAuth.close(); } if (context.redisReadOnly) { - await context.redisReadOnly.close(); + context.redisReadOnly.close(); } if (context.redisWriteOnly) { - await context.redisWriteOnly.close(); + context.redisWriteOnly.close(); } // Clean up Unix socket proxy if it exists @@ -456,6 +478,7 @@ if (isEnabled) console.error("Error during test cleanup:", err); } }); +} if (!isEnabled) { console.warn("Redis is not enabled, skipping tests"); @@ -545,6 +568,14 @@ async function getRedisContainerName(): Promise { /** * Restart the Redis container to simulate connection drop + * + * Restarts the container identified by the test harness and waits briefly for it + * to come back online (approximately 2 seconds). Use this to simulate a server + * restart or connection drop during tests. + * + * @returns A promise that resolves when the restart and short wait complete. + * @throws If the Docker restart command exits with a non-zero code; the error + * message includes the container's stderr output. */ export async function restartRedisContainer(): Promise { // If using docker-compose, get the actual container name @@ -611,3 +642,50 @@ export async function restartRedisContainer(): Promise { } } } + +/** + * @returns true or false with approximately equal probability + */ +export function randomCoinFlip(): boolean { + return Math.floor(Math.random() * 2) == 0; +} + +/** + * Utility for creating a counter that can be awaited until it reaches a target value. + */ +export function awaitableCounter(timeoutMs: number = 1000) { + let activeResolvers: [number, NodeJS.Timeout, (value: number) => void][] = []; + let currentCount = 0; + + return { + increment: () => { + currentCount++; + + for (const [value, alarm, resolve] of activeResolvers) { + alarm.close(); + + if (currentCount >= value) { + resolve(currentCount); + } + } + + // Remove resolved promises + activeResolvers = activeResolvers.filter(([value]) => currentCount < value); + }, + count: () => currentCount, + + untilValue: (value: number) => + new Promise((resolve, reject) => { + if (currentCount >= value) { + resolve(currentCount); + return; + } + + const alarm = setTimeout(() => { + reject(new Error(`Timeout waiting for counter to reach ${value}, current is ${currentCount}.`)); + }, timeoutMs); + + activeResolvers.push([value, alarm, resolve]); + }), + }; +} diff --git a/test/js/valkey/valkey.failing-subscriber.ts b/test/js/valkey/valkey.failing-subscriber.ts new file mode 100644 index 0000000000..872ea5a072 --- /dev/null +++ b/test/js/valkey/valkey.failing-subscriber.ts @@ -0,0 +1,48 @@ +// Program which sets up a subscriber outside the scope of the main Jest process. +// Used within valkey.test.ts. +// +// DO NOT IMPORT FROM test-utils.ts. That import is janky and will have different state at different from different +// importers. +import {RedisClient} from "bun"; + +function trySend(msg: any) { + if (process === undefined || process.send === undefined) { + throw new Error("process is undefined"); + } + + process.send(msg); +} + +let redisUrlResolver: (url: string) => void; +const redisUrl = new Promise((resolve) => { + redisUrlResolver = resolve; +}); + +process.on("message", (msg: any) => { + if (msg.event === "start") { + redisUrlResolver(msg.url); + } else { + throw new Error("Unknown event " + msg.event); + } +}); + +const CHANNEL = "error-callback-channel"; + +// We will wait for the parent process to tell us to start. +const url = await redisUrl; +const subscriber = new RedisClient(url); +await subscriber.connect(); +trySend({ event: "ready" }); + +let counter = 0; +await subscriber.subscribe(CHANNEL, () => { + if ((counter++) === 1) { + throw new Error("Intentional callback error"); + } + + trySend({ event: "message", index: counter }); +}); + +process.on("uncaughtException", (e) => { + trySend({ event: "exception", exMsg: e.message }); +}); diff --git a/test/js/valkey/valkey.test.ts b/test/js/valkey/valkey.test.ts index a70c9659b8..11b1902fbd 100644 --- a/test/js/valkey/valkey.test.ts +++ b/test/js/valkey/valkey.test.ts @@ -1,12 +1,14 @@ -import { randomUUIDv7, RedisClient } from "bun"; +import { randomUUIDv7, RedisClient, spawn } from "bun"; import { beforeAll, beforeEach, describe, expect, test } from "bun:test"; import { + awaitableCounter, ConnectionType, createClient, ctx, DEFAULT_REDIS_URL, expectType, isEnabled, + randomCoinFlip, setupDockerContainer, } from "./test-utils"; @@ -26,6 +28,7 @@ describe.skipIf(!isEnabled)("Valkey Redis Client", () => { } // Flush all data for clean test state + await ctx.redis.connect(); await ctx.redis.send("FLUSHALL", ["SYNC"]); }); @@ -40,11 +43,11 @@ describe.skipIf(!isEnabled)("Valkey Redis Client", () => { expect(setResult).toMatchInlineSnapshot(`"OK"`); const setResult2 = await redis.set(testKey, testValue, "GET"); - expect(setResult2).toMatchInlineSnapshot(`"Hello from Bun Redis!"`); + expect(setResult2).toMatchInlineSnapshot(`"${testValue}"`); // GET should return the value we set const getValue = await redis.get(testKey); - expect(getValue).toMatchInlineSnapshot(`"Hello from Bun Redis!"`); + expect(getValue).toMatchInlineSnapshot(`"${testValue}"`); }); test("should test key existence", async () => { @@ -235,4 +238,572 @@ describe.skipIf(!isEnabled)("Valkey Redis Client", () => { expect(valueAfterStop).toBe(TEST_VALUE); }); }); + + describe("PUB/SUB", () => { + var i = 0; + const testChannel = () => { + return `test-channel-${i++}`; + }; + const testKey = () => { + return `test-key-${i++}`; + }; + const testValue = () => { + return `test-value-${i++}`; + }; + const testMessage = () => { + return `test-message-${i++}`; + }; + + beforeEach(async () => { + // The PUB/SUB tests expect that ctx.redis is connected but not in subscriber mode. + await ctx.cleanupSubscribers(); + }); + + test("publishing to a channel does not fail", async () => { + expect(await ctx.redis.publish(testChannel(), testMessage())).toBe(0); + }); + + test("setting in subscriber mode gracefully fails", async () => { + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + + await subscriber.subscribe(testChannel(), () => {}); + + expect(() => subscriber.set(testKey(), testValue())).toThrow( + "RedisClient.prototype.set cannot be called while in subscriber mode", + ); + + await subscriber.unsubscribe(testChannel()); + }); + + test("setting after unsubscribing works", async () => { + const channel = testChannel(); + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + await subscriber.subscribe(channel, () => {}); + await subscriber.unsubscribe(channel); + expect(ctx.redis.set(testKey(), testValue())).resolves.toEqual("OK"); + }); + + test("subscribing to a channel receives messages", async () => { + const TEST_MESSAGE_COUNT = 128; + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + const channel = testChannel(); + const message = testMessage(); + + const counter = awaitableCounter(); + await subscriber.subscribe(channel, (message, channel) => { + counter.increment(); + expect(channel).toBe(channel); + expect(message).toBe(message); + }); + + Array.from({ length: TEST_MESSAGE_COUNT }).forEach(async () => { + expect(await ctx.redis.publish(channel, message)).toBe(1); + }); + + await counter.untilValue(TEST_MESSAGE_COUNT); + expect(counter.count()).toBe(TEST_MESSAGE_COUNT); + }); + + test("messages are received in order", async () => { + const channel = testChannel(); + + await ctx.redis.set("START-TEST", "1"); + const TEST_MESSAGE_COUNT = 1024; + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + + const counter = awaitableCounter(); + var receivedMessages: string[] = []; + await subscriber.subscribe(channel, message => { + receivedMessages.push(message); + counter.increment(); + }); + + const sentMessages = Array.from({ length: TEST_MESSAGE_COUNT }).map(() => { + return randomUUIDv7(); + }); + await Promise.all( + sentMessages.map(async message => { + expect(await ctx.redis.publish(channel, message)).toBe(1); + }), + ); + + await counter.untilValue(TEST_MESSAGE_COUNT); + expect(receivedMessages.length).toBe(sentMessages.length); + expect(receivedMessages).toEqual(sentMessages); + + await subscriber.unsubscribe(channel); + + await ctx.redis.set("STOP-TEST", "1"); + }); + + test("subscribing to multiple channels receives messages", async () => { + const TEST_MESSAGE_COUNT = 128; + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + + const channels = [testChannel(), testChannel()]; + const counter = awaitableCounter(); + + var receivedMessages: { [channel: string]: string[] } = {}; + await subscriber.subscribe(channels, (message, channel) => { + receivedMessages[channel] = receivedMessages[channel] || []; + receivedMessages[channel].push(message); + counter.increment(); + }); + + var sentMessages: { [channel: string]: string[] } = {}; + for (let i = 0; i < TEST_MESSAGE_COUNT; i++) { + const channel = channels[randomCoinFlip() ? 0 : 1]; + const message = randomUUIDv7(); + + expect(await ctx.redis.publish(channel, message)).toBe(1); + + sentMessages[channel] = sentMessages[channel] || []; + sentMessages[channel].push(message); + } + + await counter.untilValue(TEST_MESSAGE_COUNT); + + // Check that we received messages on both channels + expect(Object.keys(receivedMessages).sort()).toEqual(Object.keys(sentMessages).sort()); + + // Check messages match for each channel + for (const channel of channels) { + if (sentMessages[channel]) { + expect(receivedMessages[channel]).toEqual(sentMessages[channel]); + } + } + + await subscriber.unsubscribe(channels); + }); + + test("unsubscribing from specific channels while remaining subscribed to others", async () => { + const channel1 = "channel-1"; + const channel2 = "channel-2"; + const channel3 = "channel-3"; + + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + let receivedMessages: { [channel: string]: string[] } = {}; + + // Total counter for all messages we expect to receive: 3 initial + 2 after unsubscribe = 5 total + const counter = awaitableCounter(); + + // Subscribe to three channels + await subscriber.subscribe([channel1, channel2, channel3], (message, channel) => { + receivedMessages[channel] = receivedMessages[channel] || []; + receivedMessages[channel].push(message); + counter.increment(); + }); + + // Send initial messages to all channels + expect(await ctx.redis.publish(channel1, "msg1-before")).toBe(1); + expect(await ctx.redis.publish(channel2, "msg2-before")).toBe(1); + expect(await ctx.redis.publish(channel3, "msg3-before")).toBe(1); + + // Wait for initial messages, then unsubscribe from channel2 + await counter.untilValue(3); + await subscriber.unsubscribe(channel2); + + // Send messages after unsubscribing from channel2 + expect(await ctx.redis.publish(channel1, "msg1-after")).toBe(1); + expect(await ctx.redis.publish(channel2, "msg2-after")).toBe(0); + expect(await ctx.redis.publish(channel3, "msg3-after")).toBe(1); + + await counter.untilValue(5); + + // Check we received messages only on subscribed channels + expect(receivedMessages[channel1]).toEqual(["msg1-before", "msg1-after"]); + expect(receivedMessages[channel2]).toEqual(["msg2-before"]); // No "msg2-after" + expect(receivedMessages[channel3]).toEqual(["msg3-before", "msg3-after"]); + + await subscriber.unsubscribe([channel1, channel3]); + }); + + test("subscribing to the same channel multiple times", async () => { + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + const channel = testChannel(); + + const counter = awaitableCounter(); + + let callCount = 0; + const listener = () => { + callCount++; + counter.increment(); + }; + + let callCount2 = 0; + const listener2 = () => { + callCount2++; + counter.increment(); + }; + + // Subscribe to the same channel twice + await subscriber.subscribe(channel, listener); + await subscriber.subscribe(channel, listener2); + + // Publish a single message + expect(await ctx.redis.publish(channel, "test-message")).toBe(1); + + await counter.untilValue(2); + + // Both listeners should have been called once. + expect(callCount).toBe(1); + expect(callCount2).toBe(1); + + await subscriber.unsubscribe(channel); + }); + + test("empty string messages", async () => { + const channel = "empty-message-channel"; + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + const counter = awaitableCounter(); + let receivedMessage: string | undefined = undefined; + await subscriber.subscribe(channel, message => { + receivedMessage = message; + counter.increment(); + }); + + expect(await ctx.redis.publish(channel, "")).toBe(1); + await counter.untilValue(1); + + expect(receivedMessage).not.toBeUndefined(); + expect(receivedMessage!).toBe(""); + + await subscriber.unsubscribe(channel); + }); + + test("special characters in channel names", async () => { + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + const specialChannels = [ + "channel:with:colons", + "channel with spaces", + "channel-with-unicode-😀", + "channel[with]brackets", + "channel@with#special$chars", + ]; + + for (const channel of specialChannels) { + const counter = awaitableCounter(); + let received = false; + await subscriber.subscribe(channel, () => { + received = true; + counter.increment(); + }); + + expect(await ctx.redis.publish(channel, "test")).toBe(1); + await counter.untilValue(1); + + expect(received).toBe(true); + await subscriber.unsubscribe(channel); + } + }); + + test("ping works in subscription mode", async () => { + const channel = "ping-test-channel"; + + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + await subscriber.subscribe(channel, () => {}); + + // Ping should work in subscription mode + const pong = await subscriber.ping(); + expect(pong).toBe("PONG"); + + const customPing = await subscriber.ping("hello"); + expect(customPing).toBe("hello"); + }); + + test("publish does not work from a subscribed client", async () => { + const channel = "self-publish-channel"; + + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + await subscriber.subscribe(channel, () => {}); + + // Publishing from the same client should work + expect(async () => subscriber.publish(channel, "self-published")).toThrow(); + }); + + test("complete unsubscribe restores normal command mode", async () => { + const channel = "restore-test-channel"; + const testKey = "restore-test-key"; + + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + await subscriber.subscribe(channel, () => {}); + + // Should fail in subscription mode + expect(() => subscriber.set(testKey, testValue())).toThrow( + "RedisClient.prototype.set cannot be called while in subscriber mode.", + ); + + // Unsubscribe from all channels + await subscriber.unsubscribe(); + + // Should work after unsubscribing + const result = await ctx.redis.set(testKey, "value"); + expect(result).toBe("OK"); + + const value = await ctx.redis.get(testKey); + expect(value).toBe("value"); + }); + + test("publishing without subscribers succeeds", async () => { + const channel = "no-subscribers-channel"; + + // Publishing without subscribers should not throw + expect(await ctx.redis.publish(channel, "message")).toBe(0); + }); + + test("unsubscribing from non-subscribed channels", async () => { + const channel = "never-subscribed-channel"; + + expect(() => ctx.redis.unsubscribe(channel)).toThrow( + "RedisClient.prototype.unsubscribe can only be called while in subscriber mode.", + ); + }); + + test("callback errors don't crash the client", async () => { + const channel = "error-callback-channel"; + + const STEP_SUBSCRIBED = 1; + const STEP_FIRST_MESSAGE = 2; + const STEP_SECOND_MESSAGE = 3; + const STEP_THIRD_MESSAGE = 4; + + // stepCounter is a slight hack to track the progress of the subprocess. + const stepCounter = awaitableCounter(); + let currentMessage: any = {}; + + const subscriberProc = spawn({ + cmd: [self.process.execPath, "run", `${__dirname}/valkey.failing-subscriber.ts`], + stdout: "inherit", + stderr: "inherit", + ipc: msg => { + currentMessage = msg; + stepCounter.increment(); + }, + env: { + ...process.env, + NODE_ENV: "development", + }, + }); + + subscriberProc.send({ event: "start", url: DEFAULT_REDIS_URL }); + + try { + await stepCounter.untilValue(STEP_SUBSCRIBED); + expect(currentMessage.event).toBe("ready"); + + // Send multiple messages + expect(await ctx.redis.publish(channel, "message1")).toBe(1); + await stepCounter.untilValue(STEP_FIRST_MESSAGE); + expect(currentMessage.event).toBe("message"); + expect(currentMessage.index).toBe(1); + + // Now, the subscriber process will crash + expect(await ctx.redis.publish(channel, "message2")).toBe(1); + await stepCounter.untilValue(STEP_SECOND_MESSAGE); + expect(currentMessage.event).toBe("exception"); + //expect(currentMessage.index).toBe(2); + + // But it should recover and continue receiving messages + expect(await ctx.redis.publish(channel, "message3")).toBe(1); + await stepCounter.untilValue(STEP_THIRD_MESSAGE); + expect(currentMessage.event).toBe("message"); + expect(currentMessage.index).toBe(3); + } finally { + subscriberProc.kill(); + } + }); + + test("subscriptions return correct counts", async () => { + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + expect(await subscriber.subscribe("chan1", () => {})).toBe(1); + expect(await subscriber.subscribe("chan2", () => {})).toBe(2); + }); + + test("unsubscribing from listeners", async () => { + const channel = "error-callback-channel"; + + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + // First phase: both listeners should receive 1 message each (2 total) + const counter = awaitableCounter(); + let messageCount1 = 0; + const listener1 = () => { + messageCount1++; + counter.increment(); + }; + await subscriber.subscribe(channel, listener1); + + let messageCount2 = 0; + const listener2 = () => { + messageCount2++; + counter.increment(); + }; + await subscriber.subscribe(channel, listener2); + + await ctx.redis.publish(channel, "message1"); + await counter.untilValue(2); + + expect(messageCount1).toBe(1); + expect(messageCount2).toBe(1); + + console.log("Unsubscribing listener2"); + await subscriber.unsubscribe(channel, listener2); + + await ctx.redis.publish(channel, "message1"); + await counter.untilValue(3); + + expect(messageCount1).toBe(2); + expect(messageCount2).toBe(1); + }); + }); + + describe("duplicate()", () => { + test("should create duplicate of connected client that gets connected", async () => { + const duplicate = await ctx.redis.duplicate(); + + expect(duplicate.connected).toBe(true); + expect(duplicate).not.toBe(ctx.redis); + + // Both should work independently + await ctx.redis.set("test-original", "original-value"); + await duplicate.set("test-duplicate", "duplicate-value"); + + expect(await ctx.redis.get("test-duplicate")).toBe("duplicate-value"); + expect(await duplicate.get("test-original")).toBe("original-value"); + + duplicate.close(); + }); + + test("should preserve connection configuration in duplicate", async () => { + await ctx.redis.connect(); + + const duplicate = await ctx.redis.duplicate(); + + // Both clients should be able to perform the same operations + const testKey = `duplicate-config-test-${randomUUIDv7().substring(0, 8)}`; + const testValue = "test-value"; + + await ctx.redis.set(testKey, testValue); + const retrievedValue = await duplicate.get(testKey); + + expect(retrievedValue).toBe(testValue); + + duplicate.close(); + }); + + test("should allow duplicate to work independently from original", async () => { + const duplicate = await ctx.redis.duplicate(); + + // Close original, duplicate should still work + duplicate.close(); + + const testKey = `independent-test-${randomUUIDv7().substring(0, 8)}`; + const testValue = "independent-value"; + + await ctx.redis.set(testKey, testValue); + const retrievedValue = await ctx.redis.get(testKey); + + expect(retrievedValue).toBe(testValue); + }); + + test("should handle duplicate of client in subscriber mode", async () => { + const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); + + const testChannel = "test-subscriber-duplicate"; + + // Put original client in subscriber mode + await subscriber.subscribe(testChannel, () => {}); + + const duplicate = await subscriber.duplicate(); + + // Duplicate should not be in subscriber mode + expect(() => duplicate.set("test-key", "test-value")).not.toThrow(); + + await subscriber.unsubscribe(testChannel); + }); + + test("should create multiple duplicates from same client", async () => { + await ctx.redis.connect(); + + const duplicate1 = await ctx.redis.duplicate(); + const duplicate2 = await ctx.redis.duplicate(); + const duplicate3 = await ctx.redis.duplicate(); + + // All should be connected + expect(duplicate1.connected).toBe(true); + expect(duplicate2.connected).toBe(true); + expect(duplicate3.connected).toBe(true); + + // All should work independently + const testKey = `multi-duplicate-test-${randomUUIDv7().substring(0, 8)}`; + await duplicate1.set(`${testKey}-1`, "value-1"); + await duplicate2.set(`${testKey}-2`, "value-2"); + await duplicate3.set(`${testKey}-3`, "value-3"); + + expect(await duplicate1.get(`${testKey}-1`)).toBe("value-1"); + expect(await duplicate2.get(`${testKey}-2`)).toBe("value-2"); + expect(await duplicate3.get(`${testKey}-3`)).toBe("value-3"); + + // Cross-check: each duplicate can read what others wrote + expect(await duplicate1.get(`${testKey}-2`)).toBe("value-2"); + expect(await duplicate2.get(`${testKey}-3`)).toBe("value-3"); + expect(await duplicate3.get(`${testKey}-1`)).toBe("value-1"); + + duplicate1.close(); + duplicate2.close(); + duplicate3.close(); + }); + + test("should duplicate client that failed to connect", async () => { + // Create client with invalid credentials to force connection failure + const url = new URL(DEFAULT_REDIS_URL); + url.username = "invaliduser"; + url.password = "invalidpassword"; + const failedRedis = new RedisClient(url.toString()); + + // Try to connect and expect it to fail + let connectionFailed = false; + try { + await failedRedis.connect(); + } catch { + connectionFailed = true; + } + + expect(connectionFailed).toBe(true); + expect(failedRedis.connected).toBe(false); + + // Duplicate should also remain unconnected + const duplicate = await failedRedis.duplicate(); + expect(duplicate.connected).toBe(false); + }); + + test("should handle duplicate timing with concurrent operations", async () => { + await ctx.redis.connect(); + + // Start some operations on the original client + const testKey = `concurrent-test-${randomUUIDv7().substring(0, 8)}`; + const originalOperation = ctx.redis.set(testKey, "original-value"); + + // Create duplicate while operation is in flight + const duplicate = await ctx.redis.duplicate(); + + // Wait for original operation to complete + await originalOperation; + + // Duplicate should be able to read the value + expect(await duplicate.get(testKey)).toBe("original-value"); + + duplicate.close(); + }); + }); });