add subscription api

This commit is contained in:
Marko Vejnovic
2025-11-03 10:37:11 -08:00
parent 046a682ba5
commit a32d8a5c4e
4 changed files with 244 additions and 153 deletions

View File

@@ -15,6 +15,9 @@ pub const JsValkey = struct {
/// The context object passed with each request. Keep it small.
const RequestContext = union(enum) {
/// The JS user requested this command and an associated promise is present.
///
/// TODO(markovejnovic):
/// Remove this type, it's stupid. All requests are user requests in our use-case
pub const UserRequest = struct {
_promise: bun.jsc.JSPromise.Strong,
// TODO(markovejnovic): This gives array-of-struct vibes instead of struct-of-array.
@@ -90,6 +93,7 @@ pub const JsValkey = struct {
ctx: *RequestContext,
value: *protocol.RESPValue,
) !void {
Self.debug("{*}.onResponse(...)", .{self});
const go = self.parent()._global_obj;
switch (ctx.*) {
@@ -127,6 +131,10 @@ pub const JsValkey = struct {
_ = self;
}
pub fn onDisconnect(self: *@This()) void {
_ = self;
}
pub fn onDeinit(self: *@This()) void {
_ = self;
}
@@ -902,6 +910,72 @@ pub const JsValkey = struct {
return promise.toJS();
}
pub fn subscribe(
self: *Self,
go: *bun.jsc.JSGlobalObject,
cf: *bun.jsc.CallFrame,
) bun.JSError!bun.jsc.JSValue {
const channel_or_many, const handler_callback = cf.argumentsAsArray(2);
var stack_fallback = std.heap.stackFallback(512, bun.default_allocator);
var redis_channels = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), 1);
defer {
for (redis_channels.items) |*item| {
item.deinit();
}
redis_channels.deinit();
}
if (!handler_callback.isCallable()) {
return go.throwInvalidArgumentType("subscribe", "listener", "function");
}
if (channel_or_many.isArray()) {
if ((try channel_or_many.getLength(go)) == 0) {
return go.throwInvalidArguments("subscribe requires at least one channel", .{});
}
try redis_channels.ensureTotalCapacity(try channel_or_many.getLength(go));
var array_iter = try channel_or_many.arrayIterator(go);
while (try array_iter.next()) |channel_arg| {
const channel = (try jsValueToJsArgument(go, channel_arg)) orelse {
return go.throwInvalidArgumentType("subscribe", "channel", "string");
};
redis_channels.appendAssumeCapacity(channel);
}
} else if (channel_or_many.isString()) {
// It is a single string channel
const channel = (try jsValueToJsArgument(go, channel_or_many)) orelse {
return go.throwInvalidArgumentType("subscribe", "channel", "string");
};
redis_channels.appendAssumeCapacity(channel);
} else {
return go.throwInvalidArgumentType("subscribe", "channel", "string or array");
}
var channel_slices = try std.ArrayList([]const u8).initCapacity(
stack_fallback.get(),
redis_channels.items.len,
);
defer channel_slices.deinit();
for (redis_channels.items) |*channel_arg| {
channel_slices.appendAssumeCapacity(channel_arg.slice());
}
var ctx: RequestContext = .{ .user_request = .init(go, false) };
self._client.subscribe(
channel_slices.items,
handler_callback.asPtrAddress(),
&ctx,
) catch |err| {
// Synchronous error: swap() gives us the promise and destroys the Strong
const promise = ctx.user_request._promise.swap();
const error_value = protocol.valkeyErrorToJS(go, err, null, .{});
promise.reject(go, error_value);
return promise.toJS();
};
return ctx.user_request.promise().toJS();
}
pub fn unsubscribe(
self: *Self,
go: *bun.jsc.JSGlobalObject,
@@ -1085,7 +1159,6 @@ pub const JsValkey = struct {
pub const scard = MetFactory.@"(key: RedisKey)"("scard", .SCARD, "key").fxn;
pub const script = MetFactory.@"(...strings: string[])"("script", .SCRIPT).fxn;
pub const sdiff = MetFactory.@"(...strings: string[])"("sdiff", .SDIFF).fxn;
pub const subscribe = MetFactory.@"(...strings: string[])"("subscribe", .SUBSCRIBE).fxn;
pub const sdiffstore = MetFactory.@"(...strings: string[])"("sdiffstore", .SDIFFSTORE).fxn;
pub const select = MetFactory.@"(...strings: string[])"("select", .SELECT).fxn;
pub const setbit = MetFactory.@"(key: RedisKey, value: RedisValue, value2: RedisValue)"("setbit", .SETBIT, "key", "offset", "value").fxn;
@@ -1273,7 +1346,7 @@ const MetFactory = struct {
) type {
return struct {
pub fn fxn(
this: *JsValkey,
self: *JsValkey,
go: *bun.jsc.JSGlobalObject,
cf: *bun.jsc.CallFrame,
) bun.JSError!bun.jsc.JSValue {
@@ -1299,7 +1372,7 @@ const MetFactory = struct {
try args.append(another);
}
const promise = this.request(
const promise = self.request(
go,
cf.this(),
Command.initById(command, .{ .args = args.items }),

View File

@@ -20,36 +20,25 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
const SubscriptionHandlerId = u64;
const SubscriptionChannelId = u64;
const _SubscribePushHandler = *const fn (
/// The unique identifier of the listener.
handler_id: SubscriptionHandlerId,
/// The listener instance.
listener: *ValkeyListener,
/// The channel on which the message was received.
channel: []const u8,
/// The payload of the message.
payload: protocol.RESPValue,
) void;
const SubscriptionTracker = struct {
const Self = @This();
/// Object stored to track an active Pub/Sub subscription.
const SubscriptionVectorEntry = struct {
/// The listener ID associated with this subscription.
handler_id: SubscriptionHandlerId,
/// The listener function associated with this subscription.
listener: _SubscribePushHandler,
};
map: std.AutoHashMap(SubscriptionChannelId, SubscriptionMapEntry),
channel_map: std.StringHashMap(SubscriptionChannelId),
channel_map: std.StringArrayHashMap(SubscriptionChannelId),
_next_channel_id: SubscriptionChannelId,
_allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator) Self {
return Self{
.map = .init(allocator),
.channel_map = .init(allocator),
._next_channel_id = 0,
._allocator = allocator,
};
}
@@ -59,13 +48,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
self: *Self,
channel: []const u8,
entry: SubscriptionVectorEntry,
) !void {
) !SubscriptionChannelId {
const channel_id = self.channel_map.get(channel) orelse get_id_blk: {
const new_id = self._next_channel_id;
try self.channel_map.put(channel, new_id);
self._next_channel_id += 1;
while (self.channel_map.get(self._next_channel_id)) {
while (self.map.get(self._next_channel_id) != null) {
@branchHint(.cold);
self._next_channel_id += 1;
}
@@ -73,37 +62,29 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
break :get_id_blk new_id;
};
const map_vecs = self.map.getPtr(channel_id) orelse {
const map_vecs = self.map.getPtr(channel_id) orelse insert_blk: {
try self.map.put(
channel_id,
.{ .active = .init(), .pending = .init() },
.{ .active = .init(self._allocator), .pending = .init(self._allocator) },
);
break :insert_blk self.map.getPtr(channel_id).?;
};
try map_vecs.pending.append(entry);
return channel_id;
}
// TODO(markovejnovic): It's kind of weird that channel is passed in again here. The
// handler_id is unique after all.
pub fn promotePendingListenerToActive(
self: *Self,
channel: []const u8,
handler_id: u64,
channel_id: SubscriptionChannelId,
handler_id: SubscriptionHandlerId,
) !void {
const channel_id = self.channel_map.get(channel) orelse {
bun.Output.debugPanic(
"SubscriptionTracker.promotePendingListenerToActive did not find a " ++
"channel: {s}",
.{channel},
);
return;
};
const map_entry = self.map.getPtr(channel_id) orelse {
bun.Output.debugPanic(
"SubscriptionTracker.promotePendingListenerToActive did not find an " ++
"entry: {s}",
.{channel},
"entry for channel_id {}",
.{channel_id},
);
return;
};
@@ -117,34 +98,25 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
if (pending_idx == null) {
bun.Output.debugPanic(
"SubscriptionTracker.promotePendingListenerToActive did not find pending " ++
"handler_id {d} on channel {s}",
.{ handler_id, channel },
"handler_id {d} on channel_id {}",
.{ handler_id, channel_id },
);
return;
}
const listener = map_entry.pending.swapRemove(pending_idx.?);
map_entry.active.append(listener);
try map_entry.active.append(listener);
}
fn removeActiveHandler(
self: *Self,
channel: []const u8,
handler_id: u64,
channel_id: SubscriptionChannelId,
handler_id: SubscriptionHandlerId,
) !void {
const channel_id = self.channel_map.get(channel) orelse {
bun.Output.debugPanic(
"SubscriptionTracker.removeActiveHandler did not find a " ++
"channel: {s}",
.{channel},
);
return;
};
const map_entry = self.map.getPtr(channel_id) orelse {
bun.Output.debugPanic(
"SubscriptionTracker.removeActiveHandler did not find a channel {s}",
.{channel},
"SubscriptionTracker.removeActiveHandler did not find a channel_id {}",
.{channel_id},
);
return;
};
@@ -157,9 +129,9 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
if (active_idx == null) {
bun.Output.debugPanic(
"SubscriptionTracker.removeActiveHandler did not find pending handler_id " ++
"on channel {s}",
.{ handler_id, channel },
"SubscriptionTracker.removeActiveHandler did not find pending handler_id {}" ++
"on channel_id {}",
.{ handler_id, channel_id },
);
return;
}
@@ -171,7 +143,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
map_entry.pending.deinit();
_ = self.map.remove(channel_id);
_ = self.channel_map.remove(channel);
var it = self.channel_map.iterator();
while (it.next()) |entry| {
if (entry.value_ptr.* == channel_id) {
_ = self.channel_map.swapRemove(entry.key_ptr.*);
break;
}
}
}
}
@@ -251,10 +229,16 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
};
const SubscriptionUpdateRequestContext = struct {
handler_id: u64,
channel_id: u64,
user_context: UserRequestContext,
handler_id: u64,
channel_ids: std.ArrayList(SubscriptionChannelId),
request_type: enum { subscribe, unsubscribe },
pub fn deinit(self: SubscriptionUpdateRequestContext) void {
self.channel_ids.deinit();
}
};
/// Each message sent or received by the client has a context which is private to this
@@ -278,6 +262,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
},
.subscription_context => |*ctx| {
ctx.user_context.failOom(listener);
ctx.deinit();
},
}
}
@@ -289,6 +274,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
},
.subscription_context => |*ctx| {
callbacks.onRequestDropped(&ctx.user_context, reason);
ctx.deinit();
},
}
}
@@ -300,7 +286,10 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
) !void {
return switch (self.*) {
.user_context => |*ctx| callbacks.onResponse(ctx, value),
.subscription_context => |*ctx| callbacks.onResponse(&ctx.user_context, value),
.subscription_context => |*ctx| {
try callbacks.onResponse(&ctx.user_context, value);
ctx.deinit();
},
};
}
};
@@ -312,9 +301,6 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
/// All responses are paired with the original request context.
pub const ResponseType = Response(UserRequestContext);
/// Type of function invoked whenever a subscription message is received.
pub const SubscribePushHandler = _SubscribePushHandler;
/// Types internal to the ValkeyClient implementation which encode a request pending to be
/// sent to the server.
const QueuedRequestType = QueuedRequest(RequestContext);
@@ -645,16 +631,18 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
const l_state = &self._state.linked;
switch (l_state.state) {
.normal => {
// The reason we couple both states together is because we don't really know if
// what we're looking at is a SUBSCRIBE/UNSUBSCRIBE response or whether it is a
// normal response we want to pass to the user.
//
// Consequently, onNormalPacket will handle both normal responses and subscription
// responses, and will delegate behavior.
.normal, .subscriber => {
try self.onNormalPacket(value);
},
.authenticating => {
try self.onAuthenticatingPacket(value);
},
.subscriber => {
// TODO(markovejnovic): Enable this
//try self.onSubscriberPacket(value);
},
}
}
@@ -667,8 +655,30 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
return;
};
if (req.returns_bool and value.* == .Integer) {
value.* = .{ .Boolean = value.Integer > 0 };
switch (req.context) {
.user_context => {
if (req.returns_bool and value.* == .Integer) {
value.* = .{ .Boolean = value.Integer > 0 };
}
},
.subscription_context => |*ctx| {
for (ctx.channel_ids.items) |channel_id| {
switch (ctx.request_type) {
.subscribe => {
try self._subscriptions.promotePendingListenerToActive(
channel_id,
ctx.handler_id,
);
},
.unsubscribe => {
try self._subscriptions.removeActiveHandler(
channel_id,
ctx.handler_id,
);
},
}
}
},
}
try req.context.succeed(value, self._callbacks);
@@ -728,68 +738,74 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
///
/// Args:
///
/// - `channel`: The channel to subscribe to.
/// - `channels`: The channels to subscribe to.
/// - `handler_id`: Unique identifier for this listener. This identifier can be used to
/// remove the listener later. See `unsubscribeListener`.
/// - `handler`: The function to invoke whenever a message is received on this channel.
pub fn subscribe(
self: *Self,
channel: []const u8,
handler_id: u64,
handler: SubscribePushHandler,
channels: []const []const u8,
handler_id: SubscriptionHandlerId,
ctx: *const UserRequestContext,
) !void {
Self.debug("{*}.subscribe({s}, handler_id={}, ...)", .{ self, channel, handler_id });
Self.debug("{*}.subscribe({s}, handler_id={}, ...)", .{ self, channels, handler_id });
// Before we do anything, we populate the active subscriptions vector. The
// subscriptions vector is what tracks all active subscriptions. Note that this vector
// may contain subscriptions which are not yet confirmed by the server. When the
// server confirms them, we will promote them from "pending" to "active".
try self._subscriptions.addPendingHandler(channel, .{
.handler = handler,
.handler_id = handler_id,
});
// may contain subscriptions which are not yet confirmed by the server. When the server
// confirms them, we will promote them from "pending" to "active".
var channel_ids = try std.ArrayList(SubscriptionChannelId).initCapacity(
self._allocator,
channels.len,
);
for (channels) |channel| {
const channel_id = try self._subscriptions.addPendingHandler(channel, .{
.handler_id = handler_id,
});
// Now we need to send out the request.
channel_ids.appendAssumeCapacity(channel_id);
}
//switch (self._state) {
// .linked => |*l_state| {
// switch (l_state.state) {
// .subscriber, .authenticating, .normal => {
// // Great, this state can send requests.
// //
// // Note that subscribers CAN, in-fact, send requests, since RESP3
// // permits that.
// try self.enqueueRequest(req);
// },
// }
// },
// .closed => {
// // We're closed, we can't send requests until the user asks for us to
// // reconnect, explicitly.
// return error.ConnectionClosed;
// },
// else => {
// Self.debug(
// "{*} Received an unexpected request in {s} state.",
// .{ self, @tagName(self._state) },
// );
var req: InternalRequestType = .{
.command = .initById(.SUBSCRIBE, .{ .raw = channels }),
.context = .{
.subscription_context = .{
.user_context = ctx.*,
.handler_id = handler_id,
.channel_ids = channel_ids,
.request_type = .subscribe,
},
},
};
// // Okay, we're not currently in the linked state. What we can do is enqueue the
// // request and attempt to start the connection.
// try self.enqueueRequest(req);
// self.startConnecting() catch |err| {
// switch (err) {
// Error.InvalidState => {
// self._state.warnIllegalState("request-start-connecting");
// // No-op, we're already connecting or connected.
// },
// Error.FailedToOpenSocket => {
// return error.ConnectionClosed;
// },
// }
// };
// },
//}
return self.submitInternalRequest(&req);
}
fn submitInternalRequest(self: *Self, req_internal: *InternalRequestType) !void {
switch (self._state) {
.closed => {
// We're closed, we can't send requests until the user asks for us to
// reconnect, explicitly.
return error.ConnectionClosed;
},
else => {
try self.enqueueRequest(req_internal);
if (self._state != .linked) {
self.startConnecting() catch |err| {
switch (err) {
Error.InvalidState => {
self._state.warnIllegalState("request-start-connecting");
// No-op, we're already connecting or connected.
},
Error.FailedToOpenSocket => {
return error.ConnectionClosed;
},
}
};
}
},
}
}
/// Unsubscribe a previously registered handler.
@@ -1051,6 +1067,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
/// Invoked when transitioning out of the linked state to any other state.
fn onStateLinkedToAny(self: *Self, from_state: *State) !void {
self.unregisterAutoFlusher();
from_state.*.linked._egress_buffer.deinit(self._allocator);
from_state.*.linked._ingress_buffer.deinit(self._allocator);
}
@@ -1066,8 +1083,6 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
self.unregisterAutoFlusher();
self._socket_io.close();
// We also want to fail all the pending requests.
self.dropAllQueuedMessages(.{ .closing = {} });
self._callbacks.onClose();
@@ -1195,6 +1210,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
/// This will drop all queued messages and transition the client into the disconnected
/// state.
fn failIrrecoverably(self: *Self, reason: IrrecoverableFailureReason) void {
self.unregisterAutoFlusher();
self.dropAllQueuedMessages(.{ .irrecoverable_failure = reason });
}
@@ -1203,6 +1219,7 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
/// If the state transition fails, this will forcibly set the state to the disconnected
/// mode.
fn dangerouslyForceClientIntoDisconnectedState(self: *Self) void {
self.unregisterAutoFlusher();
self.dropAllQueuedMessages(.{ .closing = {} });
self._state.transition(.{ .disconnected = {} }) catch |err| {
// In the case that we observe an error to transition to closed, that's a bug in
@@ -1255,36 +1272,13 @@ pub fn ValkeyClient(comptime ValkeyListener: type, comptime UserRequestContext:
req.command.args.len(),
});
switch (self._state) {
.closed => {
// We're closed, we can't send requests until the user asks for us to
// reconnect, explicitly.
return error.ConnectionClosed;
var req_internal = Self.InternalRequestType{
.command = req.command,
.context = .{
.user_context = req.context,
},
else => {
var req_internal = Self.InternalRequestType{
.command = req.command,
.context = .{
.user_context = req.context,
},
};
try self.enqueueRequest(&req_internal);
if (self._state != .linked) {
self.startConnecting() catch |err| {
switch (err) {
Error.InvalidState => {
self._state.warnIllegalState("request-start-connecting");
// No-op, we're already connecting or connected.
},
Error.FailedToOpenSocket => {
return error.ConnectionClosed;
},
}
};
}
},
}
};
try self.submitInternalRequest(&req_internal);
}
/// Attempt to enqueue a request for sending to the server. This may choose to skip the

View File

@@ -170,4 +170,16 @@ export namespace ValkeyFaker {
// Use 1 KB max size for regular values to keep tests fast. 1kB is still a reasonably large value.
return Array.from({ length: count }, () => value(randomEngine, 1024));
}
export function channel(randomEngine: random.RandomEngine, maxSize: number = 256): string {
return random.dirtyLatin1String(randomEngine, maxSize);
}
export function channels(randomEngine: random.RandomEngine, count: number): string[] {
return Array.from({ length: count }, () => channel(randomEngine, 256));
}
export function publishMessage(randomEngine: random.RandomEngine, maxSize: number = 1024 * 1024): string {
return ValkeyFaker.value(randomEngine, maxSize);
}
}

View File

@@ -6589,7 +6589,7 @@ describeValkey(
const testKeyUniquePerDb = crypto.randomUUID();
// TODO(markovejnovic): Don't skip this.
test.skip.each([...Array(16).keys()])("Connecting to database with url $url succeeds", async (dbId: number) => {
test.skip.each([...Array(16).keys()])("Connecting to database with url %s succeeds", async (dbId: number) => {
const redis = createClient(connectionType, {}, dbId);
const testValue = await redis.get(testKeyUniquePerDb);
@@ -6614,11 +6614,23 @@ describeValkey(
await ctx.restartServer();
const valueAfterStop = await ctx.connectedClient().get(TEST_KEY);
const valueAfterStop = (await ctx.connectedClient()).get(TEST_KEY);
expect(valueAfterStop).toBe(TEST_VALUE);
});
});
describe("Pub/Sub", () => {
test.each(ValkeyFaker.channels(randomEngine, 4))("publishing to channel %s does not fail", async (channel: string) => {
const client = await ctx.connectedClient();
expect(await client.publish(channel, ValkeyFaker.publishMessage(randomEngine))).toBe(0);
});
test.each(ValkeyFaker.channels(randomEngine, 4))("subscribing to %s does not fail", async (channel: string) => {
const client = await ctx.connectedClient();
await client.subscribe(channel);
await client.unsubscribe(channel);
});
});
},
{ server: "redis://localhost:6379" },
);