From d0acbffdbd0bf8867431fa9aa067bd891837b472 Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Mon, 3 Nov 2025 14:38:14 -0800 Subject: [PATCH] implement basic unsubscribe --- src/valkey2/js_valkey.zig | 16 ++- src/valkey2/valkey.zig | 241 ++++++++++++++++++++++++++++++-------- 2 files changed, 209 insertions(+), 48 deletions(-) diff --git a/src/valkey2/js_valkey.zig b/src/valkey2/js_valkey.zig index b577517e12..b67672b316 100644 --- a/src/valkey2/js_valkey.zig +++ b/src/valkey2/js_valkey.zig @@ -1055,7 +1055,13 @@ pub const JsValkey = struct { } var ctx: RequestContext = .{ .user_request = .init(go, false) }; - self._client.unsubscribeChannels(channel_slices.items, ctx); + self._client.unsubscribeChannels(channel_slices.items, ctx) catch |err| { + // Synchronous error: swap() gives us the promise and destroys the Strong + const promise = ctx.user_request._promise.swap(); + const error_value = protocol.valkeyErrorToJS(go, err, null, .{}); + promise.reject(go, error_value); + return promise.toJS(); + }; return ctx.user_request.promise().toJS(); } @@ -1064,7 +1070,13 @@ pub const JsValkey = struct { }; var ctx: RequestContext = .{ .user_request = .init(go, false) }; const channels = [_][]const u8{channel.slice()}; - self._client.unsubscribeChannels(&channels, ctx); + self._client.unsubscribeChannels(&channels, ctx) catch |err| { + // Synchronous error: swap() gives us the promise and destroys the Strong + const promise = ctx.user_request._promise.swap(); + const error_value = protocol.valkeyErrorToJS(go, err, null, .{}); + promise.reject(go, error_value); + return promise.toJS(); + }; return ctx.user_request.promise().toJS(); } diff --git a/src/valkey2/valkey.zig b/src/valkey2/valkey.zig index 0dfbb25a9f..078a0e48ee 100644 --- a/src/valkey2/valkey.zig +++ b/src/valkey2/valkey.zig @@ -45,6 +45,22 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: }; } + /// Get an existing channel ID for a given channel name. + pub fn existingChannelId(self: *Self, channel: []const u8) ?SubscriptionChannelId { + return self.channel_map.get(channel); + } + + /// Generate a new unique channel ID. + fn newChannelId(self: *Self) SubscriptionChannelId { + const new_id = self._next_channel_id; + self._next_channel_id += 1; + while (self.map.get(self._next_channel_id) != null) { + @branchHint(.cold); + self._next_channel_id += 1; + } + return new_id; + } + /// Register a listener as pending for a given channel. This listener will not be invoked /// for messages until it is promoted to active. pub fn addPendingHandler( @@ -58,17 +74,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: channel, }); - const channel_id = self.channel_map.get(channel) orelse get_id_blk: { - const new_id = self._next_channel_id; - try self.channel_map.put(channel, new_id); - - self._next_channel_id += 1; - while (self.map.get(self._next_channel_id) != null) { - @branchHint(.cold); - self._next_channel_id += 1; - } - - break :get_id_blk new_id; + const channel_id = self.existingChannelId(channel) orelse new_channel_blk: { + const new_id = self.newChannelId(); + // StringArrayHashMap doesn't own the key, so we must allocate and copy it + const channel_copy = try self._allocator.dupe(u8, channel); + errdefer self._allocator.free(channel_copy); + try self.channel_map.put(channel_copy, new_id); + break :new_channel_blk new_id; }; const map_vecs = self.map.getPtr(channel_id) orelse insert_blk: { @@ -123,6 +135,32 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: try map_entry.active.append(listener); } + fn removeActiveForChannel(self: *Self, channel_id: SubscriptionChannelId) !void { + Self.debug("{*}.removeActiveForChannel({})", .{ self, channel_id }); + + const map_entry = self.map.getPtr(channel_id) orelse { + return; + }; + + map_entry.active.shrinkAndFree(map_entry.pending.items.len); + + if (map_entry.pending.items.len == 0) { + map_entry.active.deinit(); + map_entry.pending.deinit(); + _ = self.map.remove(channel_id); + + var it = self.channel_map.iterator(); + while (it.next()) |entry| { + if (entry.value_ptr.* == channel_id) { + const key_to_free = entry.key_ptr.*; + _ = self.channel_map.swapRemove(key_to_free); + self._allocator.free(key_to_free); + break; + } + } + } + } + fn removeActiveHandler( self: *Self, channel_id: SubscriptionChannelId, @@ -142,7 +180,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: return; }; - const active_idx: ?usize = for (map_entry.pending.items, 0..) |entry, idx| { + const active_idx: ?usize = for (map_entry.active.items, 0..) |entry, idx| { if (entry.handler_id == handler_id) { break idx; } @@ -150,7 +188,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: if (active_idx == null) { bun.Output.debugPanic( - "SubscriptionTracker.removeActiveHandler did not find pending handler_id {}" ++ + "SubscriptionTracker.removeActiveHandler did not find active handler_id {}" ++ "on channel_id {}", .{ handler_id, channel_id }, ); @@ -167,7 +205,9 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: var it = self.channel_map.iterator(); while (it.next()) |entry| { if (entry.value_ptr.* == channel_id) { - _ = self.channel_map.swapRemove(entry.key_ptr.*); + const key_to_free = entry.key_ptr.*; + _ = self.channel_map.swapRemove(key_to_free); + self._allocator.free(key_to_free); break; } } @@ -228,6 +268,23 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: // ot only check the channel map. return self.channel_map.count() > 0; } + + pub fn deinit(self: *Self) void { + // Free all allocated channel names + var it = self.channel_map.iterator(); + while (it.next()) |entry| { + self._allocator.free(entry.key_ptr.*); + } + self.channel_map.deinit(); + + // Free all subscription entries + var map_it = self.map.iterator(); + while (map_it.next()) |entry| { + entry.value_ptr.active.deinit(); + entry.value_ptr.pending.deinit(); + } + self.map.deinit(); + } }; return struct { @@ -244,19 +301,6 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: FailedToOpenSocket, }; - const SubscriptionUpdateRequestContext = struct { - user_context: UserRequestContext, - - handler_id: u64, - channel_ids: std.ArrayList(SubscriptionChannelId), - - request_type: enum { subscribe, unsubscribe }, - - pub fn deinit(self: SubscriptionUpdateRequestContext) void { - self.channel_ids.deinit(); - } - }; - /// Each message sent or received by the client has a context which is private to this /// client. User contexts are contained within this context too. /// @@ -267,16 +311,46 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: const RequestContext = union(enum) { user_context: UserRequestContext, - /// Context for SUBSCRIBE and UNSUBSCRIBE requests. Used to track listener and channel - /// pair. - subscription_context: SubscriptionUpdateRequestContext, + /// Context for SUBSCRIBE requests. Used to track listener and channel pair. + subscribe_context: struct { + user_context: UserRequestContext, + + handler_id: u64, + channel_ids: std.ArrayList(SubscriptionChannelId), + + request_type: enum { subscribe, unsubscribe }, + + pub fn deinit(self: *@This()) void { + self.channel_ids.deinit(); + } + }, + + unsubscribe_context: struct { + user_context: UserRequestContext, + + handler_id: union(enum) { + id: u64, + all: void, + }, + channel_ids: std.ArrayList(SubscriptionChannelId), + + request_type: enum { subscribe, unsubscribe }, + + pub fn deinit(self: *@This()) void { + self.channel_ids.deinit(); + } + }, pub fn failOom(self: *RequestContext, listener: *ValkeyListener) void { switch (self.*) { .user_context => |*ctx| { ctx.failOom(listener); }, - .subscription_context => |*ctx| { + .subscribe_context => |*ctx| { + ctx.user_context.failOom(listener); + ctx.deinit(); + }, + .unsubscribe_context => |*ctx| { ctx.user_context.failOom(listener); ctx.deinit(); }, @@ -292,7 +366,11 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: .user_context => |*ctx| { callbacks.onRequestDropped(ctx, reason); }, - .subscription_context => |*ctx| { + .subscribe_context => |*ctx| { + callbacks.onRequestDropped(&ctx.user_context, reason); + ctx.deinit(); + }, + .unsubscribe_context => |*ctx| { callbacks.onRequestDropped(&ctx.user_context, reason); ctx.deinit(); }, @@ -306,7 +384,11 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: ) !void { return switch (self.*) { .user_context => |*ctx| callbacks.onResponse(ctx, value), - .subscription_context => |*ctx| { + .subscribe_context => |*ctx| { + try callbacks.onResponse(&ctx.user_context, value); + ctx.deinit(); + }, + .unsubscribe_context => |*ctx| { try callbacks.onResponse(&ctx.user_context, value); ctx.deinit(); }, @@ -440,6 +522,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: self._inflight_queue.deinit(); self._outbound_queue.deinit(); self._connection_params.deinit(); + self._subscriptions.deinit(); self._callbacks.onDeinit(); } @@ -681,19 +764,47 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: value.* = .{ .Boolean = value.Integer > 0 }; } }, - .subscription_context => |*ctx| { - for (ctx.channel_ids.items) |channel_id| { - switch (ctx.request_type) { - .subscribe => { + .subscribe_context => |*ctx| { + switch (ctx.request_type) { + .subscribe => { + for (ctx.channel_ids.items) |channel_id| { try self._subscriptions.promotePendingListenerToActive( channel_id, ctx.handler_id, ); - }, + } + }, + else => { + bun.Output.debugPanic( + "Unexpected request type {} in subscribe_context", + .{ctx.request_type}, + ); + }, + } + }, + .unsubscribe_context => |*ctx| { + for (ctx.channel_ids.items) |channel_id| { + switch (ctx.request_type) { .unsubscribe => { - try self._subscriptions.removeActiveHandler( - channel_id, - ctx.handler_id, + switch (ctx.handler_id) { + .all => { + // They requested we drop all handlers for this channel. + for (ctx.channel_ids.items) |ch_id| { + try self._subscriptions.removeActiveForChannel(ch_id); + } + }, + .id => |handler_id| { + try self._subscriptions.removeActiveHandler( + channel_id, + handler_id, + ); + }, + } + }, + else => { + bun.Output.debugPanic( + "Unexpected request type {} in unsubscribe_context", + .{ctx.request_type}, ); }, } @@ -789,7 +900,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: var req: InternalRequestType = .{ .command = .initById(.SUBSCRIBE, .{ .raw = channels }), .context = .{ - .subscription_context = .{ + .subscribe_context = .{ .user_context = ctx.*, .handler_id = handler_id, .channel_ids = channel_ids, @@ -844,14 +955,52 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext: /// Unsubscribe from multiple channels. /// If any subscriptions are in-flight, they will be cancelled. + /// + /// TODO(markovejnovic): This implementation is not ideal. One of the really annoying + /// things about our implementation compared to ioredis is the fact + /// that subscribe and unsubscribe are async functions, which only + /// perform their action once the server confirms the action. + /// + /// We could improve upon this by making these functions sync, eagerly + /// registering the handlers and then, only upon failure, + /// deregistering the handler. This feels kind of silent and bad, but + /// it does seem to be what ioredis does. pub fn unsubscribeChannels( self: *Self, channels: []const []const u8, user_ctx: UserRequestContext, - ) void { - _ = self; - _ = channels; - _ = user_ctx; + ) !void { + Self.debug("{*}.unsubscribeChannels({s})", .{ self, channels }); + + // TODO(markovejnovic): The user experience might be better if we eagerly cancelled our + // listeners. The way it works now is that we cancel the listener + // only when the server confirms the unsubscription, which feels + // undesirable. + var channel_ids = try std.ArrayList(SubscriptionChannelId).initCapacity( + self._allocator, + channels.len, + ); + for (channels) |channel| { + channel_ids.appendAssumeCapacity(self._subscriptions.existingChannelId( + channel, + ) orelse { + continue; + }); + } + + var req: InternalRequestType = .{ + .command = .initById(.UNSUBSCRIBE, .{ .raw = channels }), + .context = .{ + .unsubscribe_context = .{ + .user_context = user_ctx, + .handler_id = .{ .all = {} }, + .channel_ids = channel_ids, + .request_type = .unsubscribe, + }, + }, + }; + + return self.submitInternalRequest(&req); } /// Unsubscribe from all current subscriptions.