From a32d8a5c4ef781f2d4f64a41c65b71a6cd8e4db0 Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Mon, 3 Nov 2025 10:37:11 -0800 Subject: [PATCH] add subscription api --- src/valkey2/js_valkey.zig | 79 ++++++++- src/valkey2/valkey.zig | 290 ++++++++++++++++----------------- test/js/valkey2/test-utils.ts | 12 ++ test/js/valkey2/valkey.test.ts | 16 +- 4 files changed, 244 insertions(+), 153 deletions(-) diff --git a/src/valkey2/js_valkey.zig b/src/valkey2/js_valkey.zig index aac7e7e56e..273dab0446 100644 --- a/src/valkey2/js_valkey.zig +++ b/src/valkey2/js_valkey.zig @@ -15,6 +15,9 @@ pub const JsValkey = struct { /// The context object passed with each request. Keep it small. const RequestContext = union(enum) { /// The JS user requested this command and an associated promise is present. + /// + /// TODO(markovejnovic): + /// Remove this type, it's stupid. All requests are user requests in our use-case pub const UserRequest = struct { _promise: bun.jsc.JSPromise.Strong, // TODO(markovejnovic): This gives array-of-struct vibes instead of struct-of-array. @@ -90,6 +93,7 @@ pub const JsValkey = struct { ctx: *RequestContext, value: *protocol.RESPValue, ) !void { + Self.debug("{*}.onResponse(...)", .{self}); const go = self.parent()._global_obj; switch (ctx.*) { @@ -127,6 +131,10 @@ pub const JsValkey = struct { _ = self; } + pub fn onDisconnect(self: *@This()) void { + _ = self; + } + pub fn onDeinit(self: *@This()) void { _ = self; } @@ -902,6 +910,72 @@ pub const JsValkey = struct { return promise.toJS(); } + pub fn subscribe( + self: *Self, + go: *bun.jsc.JSGlobalObject, + cf: *bun.jsc.CallFrame, + ) bun.JSError!bun.jsc.JSValue { + const channel_or_many, const handler_callback = cf.argumentsAsArray(2); + var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); + var redis_channels = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), 1); + defer { + for (redis_channels.items) |*item| { + item.deinit(); + } + redis_channels.deinit(); + } + + if (!handler_callback.isCallable()) { + return go.throwInvalidArgumentType("subscribe", "listener", "function"); + } + + if (channel_or_many.isArray()) { + if ((try channel_or_many.getLength(go)) == 0) { + return go.throwInvalidArguments("subscribe requires at least one channel", .{}); + } + try redis_channels.ensureTotalCapacity(try channel_or_many.getLength(go)); + + var array_iter = try channel_or_many.arrayIterator(go); + while (try array_iter.next()) |channel_arg| { + const channel = (try jsValueToJsArgument(go, channel_arg)) orelse { + return go.throwInvalidArgumentType("subscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + } + } else if (channel_or_many.isString()) { + // It is a single string channel + const channel = (try jsValueToJsArgument(go, channel_or_many)) orelse { + return go.throwInvalidArgumentType("subscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + } else { + return go.throwInvalidArgumentType("subscribe", "channel", "string or array"); + } + + var channel_slices = try std.ArrayList([]const u8).initCapacity( + stack_fallback.get(), + redis_channels.items.len, + ); + defer channel_slices.deinit(); + for (redis_channels.items) |*channel_arg| { + channel_slices.appendAssumeCapacity(channel_arg.slice()); + } + + var ctx: RequestContext = .{ .user_request = .init(go, false) }; + self._client.subscribe( + channel_slices.items, + handler_callback.asPtrAddress(), + &ctx, + ) catch |err| { + // Synchronous error: swap() gives us the promise and destroys the Strong + const promise = ctx.user_request._promise.swap(); + const error_value = protocol.valkeyErrorToJS(go, err, null, .{}); + promise.reject(go, error_value); + return promise.toJS(); + }; + return ctx.user_request.promise().toJS(); + } + pub fn unsubscribe( self: *Self, go: *bun.jsc.JSGlobalObject, @@ -1085,7 +1159,6 @@ pub const JsValkey = struct { pub const scard = MetFactory.@"(key: RedisKey)"("scard", .SCARD, "key").fxn; pub const script = MetFactory.@"(...strings: string[])"("script", .SCRIPT).fxn; pub const sdiff = MetFactory.@"(...strings: string[])"("sdiff", .SDIFF).fxn; - pub const subscribe = MetFactory.@"(...strings: string[])"("subscribe", .SUBSCRIBE).fxn; pub const sdiffstore = MetFactory.@"(...strings: string[])"("sdiffstore", .SDIFFSTORE).fxn; pub const select = MetFactory.@"(...strings: string[])"("select", .SELECT).fxn; pub const setbit = MetFactory.@"(key: RedisKey, value: RedisValue, value2: RedisValue)"("setbit", .SETBIT, "key", "offset", "value").fxn; @@ -1273,7 +1346,7 @@ const MetFactory = struct { ) type { return struct { pub fn fxn( - this: *JsValkey, + self: *JsValkey, go: *bun.jsc.JSGlobalObject, cf: *bun.jsc.CallFrame, ) bun.JSError!bun.jsc.JSValue { @@ -1299,7 +1372,7 @@ const MetFactory = struct { try args.append(another); } - const promise = this.request( + const promise = self.request( go, cf.this(), Command.initById(command, .{ .args = args.items }), diff --git a/src/valkey2/valkey.zig b/src/valkey2/valkey.zig index c2ae4a0cc9..a12e68f33f 100644 --- a/src/valkey2/valkey.zig +++ b/src/valkey2/valkey.zig @@ -20,36 +20,25 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: const SubscriptionHandlerId = u64; const SubscriptionChannelId = u64; - const _SubscribePushHandler = *const fn ( - /// The unique identifier of the listener. - handler_id: SubscriptionHandlerId, - /// The listener instance. - listener: *ValkeyListener, - /// The channel on which the message was received. - channel: []const u8, - /// The payload of the message. - payload: protocol.RESPValue, - ) void; - const SubscriptionTracker = struct { const Self = @This(); /// Object stored to track an active Pub/Sub subscription. const SubscriptionVectorEntry = struct { /// The listener ID associated with this subscription. handler_id: SubscriptionHandlerId, - /// The listener function associated with this subscription. - listener: _SubscribePushHandler, }; map: std.AutoHashMap(SubscriptionChannelId, SubscriptionMapEntry), - channel_map: std.StringHashMap(SubscriptionChannelId), + channel_map: std.StringArrayHashMap(SubscriptionChannelId), _next_channel_id: SubscriptionChannelId, + _allocator: std.mem.Allocator, pub fn init(allocator: std.mem.Allocator) Self { return Self{ .map = .init(allocator), .channel_map = .init(allocator), ._next_channel_id = 0, + ._allocator = allocator, }; } @@ -59,13 +48,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: self: *Self, channel: []const u8, entry: SubscriptionVectorEntry, - ) !void { + ) !SubscriptionChannelId { const channel_id = self.channel_map.get(channel) orelse get_id_blk: { const new_id = self._next_channel_id; try self.channel_map.put(channel, new_id); self._next_channel_id += 1; - while (self.channel_map.get(self._next_channel_id)) { + while (self.map.get(self._next_channel_id) != null) { @branchHint(.cold); self._next_channel_id += 1; } @@ -73,37 +62,29 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: break :get_id_blk new_id; }; - const map_vecs = self.map.getPtr(channel_id) orelse { + const map_vecs = self.map.getPtr(channel_id) orelse insert_blk: { try self.map.put( channel_id, - .{ .active = .init(), .pending = .init() }, + .{ .active = .init(self._allocator), .pending = .init(self._allocator) }, ); + break :insert_blk self.map.getPtr(channel_id).?; }; try map_vecs.pending.append(entry); + + return channel_id; } - // TODO(markovejnovic): It's kind of weird that channel is passed in again here. The - // handler_id is unique after all. pub fn promotePendingListenerToActive( self: *Self, - channel: []const u8, - handler_id: u64, + channel_id: SubscriptionChannelId, + handler_id: SubscriptionHandlerId, ) !void { - const channel_id = self.channel_map.get(channel) orelse { - bun.Output.debugPanic( - "SubscriptionTracker.promotePendingListenerToActive did not find a " ++ - "channel: {s}", - .{channel}, - ); - return; - }; - const map_entry = self.map.getPtr(channel_id) orelse { bun.Output.debugPanic( "SubscriptionTracker.promotePendingListenerToActive did not find an " ++ - "entry: {s}", - .{channel}, + "entry for channel_id {}", + .{channel_id}, ); return; }; @@ -117,34 +98,25 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: if (pending_idx == null) { bun.Output.debugPanic( "SubscriptionTracker.promotePendingListenerToActive did not find pending " ++ - "handler_id {d} on channel {s}", - .{ handler_id, channel }, + "handler_id {d} on channel_id {}", + .{ handler_id, channel_id }, ); return; } const listener = map_entry.pending.swapRemove(pending_idx.?); - map_entry.active.append(listener); + try map_entry.active.append(listener); } fn removeActiveHandler( self: *Self, - channel: []const u8, - handler_id: u64, + channel_id: SubscriptionChannelId, + handler_id: SubscriptionHandlerId, ) !void { - const channel_id = self.channel_map.get(channel) orelse { - bun.Output.debugPanic( - "SubscriptionTracker.removeActiveHandler did not find a " ++ - "channel: {s}", - .{channel}, - ); - return; - }; - const map_entry = self.map.getPtr(channel_id) orelse { bun.Output.debugPanic( - "SubscriptionTracker.removeActiveHandler did not find a channel {s}", - .{channel}, + "SubscriptionTracker.removeActiveHandler did not find a channel_id {}", + .{channel_id}, ); return; }; @@ -157,9 +129,9 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: if (active_idx == null) { bun.Output.debugPanic( - "SubscriptionTracker.removeActiveHandler did not find pending handler_id " ++ - "on channel {s}", - .{ handler_id, channel }, + "SubscriptionTracker.removeActiveHandler did not find pending handler_id {}" ++ + "on channel_id {}", + .{ handler_id, channel_id }, ); return; } @@ -171,7 +143,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: map_entry.pending.deinit(); _ = self.map.remove(channel_id); - _ = self.channel_map.remove(channel); + var it = self.channel_map.iterator(); + while (it.next()) |entry| { + if (entry.value_ptr.* == channel_id) { + _ = self.channel_map.swapRemove(entry.key_ptr.*); + break; + } + } } } @@ -251,10 +229,16 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: }; const SubscriptionUpdateRequestContext = struct { - handler_id: u64, - channel_id: u64, - user_context: UserRequestContext, + + handler_id: u64, + channel_ids: std.ArrayList(SubscriptionChannelId), + + request_type: enum { subscribe, unsubscribe }, + + pub fn deinit(self: SubscriptionUpdateRequestContext) void { + self.channel_ids.deinit(); + } }; /// Each message sent or received by the client has a context which is private to this @@ -278,6 +262,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: }, .subscription_context => |*ctx| { ctx.user_context.failOom(listener); + ctx.deinit(); }, } } @@ -289,6 +274,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: }, .subscription_context => |*ctx| { callbacks.onRequestDropped(&ctx.user_context, reason); + ctx.deinit(); }, } } @@ -300,7 +286,10 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: ) !void { return switch (self.*) { .user_context => |*ctx| callbacks.onResponse(ctx, value), - .subscription_context => |*ctx| callbacks.onResponse(&ctx.user_context, value), + .subscription_context => |*ctx| { + try callbacks.onResponse(&ctx.user_context, value); + ctx.deinit(); + }, }; } }; @@ -312,9 +301,6 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: /// All responses are paired with the original request context. pub const ResponseType = Response(UserRequestContext); - /// Type of function invoked whenever a subscription message is received. - pub const SubscribePushHandler = _SubscribePushHandler; - /// Types internal to the ValkeyClient implementation which encode a request pending to be /// sent to the server. const QueuedRequestType = QueuedRequest(RequestContext); @@ -645,16 +631,18 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: const l_state = &self._state.linked; switch (l_state.state) { - .normal => { + // The reason we couple both states together is because we don't really know if + // what we're looking at is a SUBSCRIBE/UNSUBSCRIBE response or whether it is a + // normal response we want to pass to the user. + // + // Consequently, onNormalPacket will handle both normal responses and subscription + // responses, and will delegate behavior. + .normal, .subscriber => { try self.onNormalPacket(value); }, .authenticating => { try self.onAuthenticatingPacket(value); }, - .subscriber => { - // TODO(markovejnovic): Enable this - //try self.onSubscriberPacket(value); - }, } } @@ -667,8 +655,30 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: return; }; - if (req.returns_bool and value.* == .Integer) { - value.* = .{ .Boolean = value.Integer > 0 }; + switch (req.context) { + .user_context => { + if (req.returns_bool and value.* == .Integer) { + value.* = .{ .Boolean = value.Integer > 0 }; + } + }, + .subscription_context => |*ctx| { + for (ctx.channel_ids.items) |channel_id| { + switch (ctx.request_type) { + .subscribe => { + try self._subscriptions.promotePendingListenerToActive( + channel_id, + ctx.handler_id, + ); + }, + .unsubscribe => { + try self._subscriptions.removeActiveHandler( + channel_id, + ctx.handler_id, + ); + }, + } + } + }, } try req.context.succeed(value, self._callbacks); @@ -728,68 +738,74 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: /// /// Args: /// - /// - `channel`: The channel to subscribe to. + /// - `channels`: The channels to subscribe to. /// - `handler_id`: Unique identifier for this listener. This identifier can be used to /// remove the listener later. See `unsubscribeListener`. /// - `handler`: The function to invoke whenever a message is received on this channel. pub fn subscribe( self: *Self, - channel: []const u8, - handler_id: u64, - handler: SubscribePushHandler, + channels: []const []const u8, + handler_id: SubscriptionHandlerId, + ctx: *const UserRequestContext, ) !void { - Self.debug("{*}.subscribe({s}, handler_id={}, ...)", .{ self, channel, handler_id }); + Self.debug("{*}.subscribe({s}, handler_id={}, ...)", .{ self, channels, handler_id }); // Before we do anything, we populate the active subscriptions vector. The // subscriptions vector is what tracks all active subscriptions. Note that this vector - // may contain subscriptions which are not yet confirmed by the server. When the - // server confirms them, we will promote them from "pending" to "active". - try self._subscriptions.addPendingHandler(channel, .{ - .handler = handler, - .handler_id = handler_id, - }); + // may contain subscriptions which are not yet confirmed by the server. When the server + // confirms them, we will promote them from "pending" to "active". + var channel_ids = try std.ArrayList(SubscriptionChannelId).initCapacity( + self._allocator, + channels.len, + ); + for (channels) |channel| { + const channel_id = try self._subscriptions.addPendingHandler(channel, .{ + .handler_id = handler_id, + }); - // Now we need to send out the request. + channel_ids.appendAssumeCapacity(channel_id); + } - //switch (self._state) { - // .linked => |*l_state| { - // switch (l_state.state) { - // .subscriber, .authenticating, .normal => { - // // Great, this state can send requests. - // // - // // Note that subscribers CAN, in-fact, send requests, since RESP3 - // // permits that. - // try self.enqueueRequest(req); - // }, - // } - // }, - // .closed => { - // // We're closed, we can't send requests until the user asks for us to - // // reconnect, explicitly. - // return error.ConnectionClosed; - // }, - // else => { - // Self.debug( - // "{*} Received an unexpected request in {s} state.", - // .{ self, @tagName(self._state) }, - // ); + var req: InternalRequestType = .{ + .command = .initById(.SUBSCRIBE, .{ .raw = channels }), + .context = .{ + .subscription_context = .{ + .user_context = ctx.*, + .handler_id = handler_id, + .channel_ids = channel_ids, + .request_type = .subscribe, + }, + }, + }; - // // Okay, we're not currently in the linked state. What we can do is enqueue the - // // request and attempt to start the connection. - // try self.enqueueRequest(req); - // self.startConnecting() catch |err| { - // switch (err) { - // Error.InvalidState => { - // self._state.warnIllegalState("request-start-connecting"); - // // No-op, we're already connecting or connected. - // }, - // Error.FailedToOpenSocket => { - // return error.ConnectionClosed; - // }, - // } - // }; - // }, - //} + return self.submitInternalRequest(&req); + } + + fn submitInternalRequest(self: *Self, req_internal: *InternalRequestType) !void { + switch (self._state) { + .closed => { + // We're closed, we can't send requests until the user asks for us to + // reconnect, explicitly. + return error.ConnectionClosed; + }, + else => { + try self.enqueueRequest(req_internal); + + if (self._state != .linked) { + self.startConnecting() catch |err| { + switch (err) { + Error.InvalidState => { + self._state.warnIllegalState("request-start-connecting"); + // No-op, we're already connecting or connected. + }, + Error.FailedToOpenSocket => { + return error.ConnectionClosed; + }, + } + }; + } + }, + } } /// Unsubscribe a previously registered handler. @@ -1051,6 +1067,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: /// Invoked when transitioning out of the linked state to any other state. fn onStateLinkedToAny(self: *Self, from_state: *State) !void { + self.unregisterAutoFlusher(); from_state.*.linked._egress_buffer.deinit(self._allocator); from_state.*.linked._ingress_buffer.deinit(self._allocator); } @@ -1066,8 +1083,6 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: self.unregisterAutoFlusher(); self._socket_io.close(); - - // We also want to fail all the pending requests. self.dropAllQueuedMessages(.{ .closing = {} }); self._callbacks.onClose(); @@ -1195,6 +1210,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: /// This will drop all queued messages and transition the client into the disconnected /// state. fn failIrrecoverably(self: *Self, reason: IrrecoverableFailureReason) void { + self.unregisterAutoFlusher(); self.dropAllQueuedMessages(.{ .irrecoverable_failure = reason }); } @@ -1203,6 +1219,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: /// If the state transition fails, this will forcibly set the state to the disconnected /// mode. fn dangerouslyForceClientIntoDisconnectedState(self: *Self) void { + self.unregisterAutoFlusher(); self.dropAllQueuedMessages(.{ .closing = {} }); self._state.transition(.{ .disconnected = {} }) catch |err| { // In the case that we observe an error to transition to closed, that's a bug in @@ -1255,36 +1272,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: req.command.args.len(), }); - switch (self._state) { - .closed => { - // We're closed, we can't send requests until the user asks for us to - // reconnect, explicitly. - return error.ConnectionClosed; + var req_internal = Self.InternalRequestType{ + .command = req.command, + .context = .{ + .user_context = req.context, }, - else => { - var req_internal = Self.InternalRequestType{ - .command = req.command, - .context = .{ - .user_context = req.context, - }, - }; - try self.enqueueRequest(&req_internal); - - if (self._state != .linked) { - self.startConnecting() catch |err| { - switch (err) { - Error.InvalidState => { - self._state.warnIllegalState("request-start-connecting"); - // No-op, we're already connecting or connected. - }, - Error.FailedToOpenSocket => { - return error.ConnectionClosed; - }, - } - }; - } - }, - } + }; + try self.submitInternalRequest(&req_internal); } /// Attempt to enqueue a request for sending to the server. This may choose to skip the diff --git a/test/js/valkey2/test-utils.ts b/test/js/valkey2/test-utils.ts index a484915e9c..3219becd69 100644 --- a/test/js/valkey2/test-utils.ts +++ b/test/js/valkey2/test-utils.ts @@ -170,4 +170,16 @@ export namespace ValkeyFaker { // Use 1 KB max size for regular values to keep tests fast. 1kB is still a reasonably large value. return Array.from({ length: count }, () => value(randomEngine, 1024)); } + + export function channel(randomEngine: random.RandomEngine, maxSize: number = 256): string { + return random.dirtyLatin1String(randomEngine, maxSize); + } + + export function channels(randomEngine: random.RandomEngine, count: number): string[] { + return Array.from({ length: count }, () => channel(randomEngine, 256)); + } + + export function publishMessage(randomEngine: random.RandomEngine, maxSize: number = 1024 * 1024): string { + return ValkeyFaker.value(randomEngine, maxSize); + } } diff --git a/test/js/valkey2/valkey.test.ts b/test/js/valkey2/valkey.test.ts index 85b6a1c7ee..bd68fb8e9c 100644 --- a/test/js/valkey2/valkey.test.ts +++ b/test/js/valkey2/valkey.test.ts @@ -6589,7 +6589,7 @@ describeValkey( const testKeyUniquePerDb = crypto.randomUUID(); // TODO(markovejnovic): Don't skip this. - test.skip.each([...Array(16).keys()])("Connecting to database with url $url succeeds", async (dbId: number) => { + test.skip.each([...Array(16).keys()])("Connecting to database with url %s succeeds", async (dbId: number) => { const redis = createClient(connectionType, {}, dbId); const testValue = await redis.get(testKeyUniquePerDb); @@ -6614,11 +6614,23 @@ describeValkey( await ctx.restartServer(); - const valueAfterStop = await ctx.connectedClient().get(TEST_KEY); + const valueAfterStop = (await ctx.connectedClient()).get(TEST_KEY); expect(valueAfterStop).toBe(TEST_VALUE); }); }); + describe("Pub/Sub", () => { + test.each(ValkeyFaker.channels(randomEngine, 4))("publishing to channel %s does not fail", async (channel: string) => { + const client = await ctx.connectedClient(); + expect(await client.publish(channel, ValkeyFaker.publishMessage(randomEngine))).toBe(0); + }); + + test.each(ValkeyFaker.channels(randomEngine, 4))("subscribing to %s does not fail", async (channel: string) => { + const client = await ctx.connectedClient(); + await client.subscribe(channel); + await client.unsubscribe(channel); + }); + }); }, { server: "redis://localhost:6379" }, );