From bdb70d5bc29763bc5ece019366ec566c588f4f37 Mon Sep 17 00:00:00 2001 From: Cameron Haley <42698419+camero2734@users.noreply.github.com> Date: Wed, 21 Feb 2024 23:19:43 +0100 Subject: [PATCH] Account for initial_thread_count in napi threadsafe_function logic (#9035) --- src/napi/napi.zig | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/napi/napi.zig b/src/napi/napi.zig index a3eb87860e..6af48627e4 100644 --- a/src/napi/napi.zig +++ b/src/napi/napi.zig @@ -1279,7 +1279,7 @@ pub const ThreadSafeFunction = struct { /// prevent it from being destroyed. poll_ref: Async.KeepAlive, - owning_threads: std.AutoArrayHashMapUnmanaged(u64, void) = .{}, + thread_count: usize = 0, owning_thread_lock: Lock = Lock.init(), event_loop: *JSC.EventLoop, @@ -1422,24 +1422,33 @@ pub const ThreadSafeFunction = struct { defer this.owning_thread_lock.unlock(); if (this.channel.isClosed()) return error.Closed; - _ = this.owning_threads.getOrPut(bun.default_allocator, std.Thread.getCurrentId()) catch unreachable; + this.thread_count += 1; } - pub fn release(this: *ThreadSafeFunction, mode: napi_threadsafe_function_release_mode) void { + pub fn release(this: *ThreadSafeFunction, mode: napi_threadsafe_function_release_mode) napi_status { this.owning_thread_lock.lock(); defer this.owning_thread_lock.unlock(); - if (!this.owning_threads.swapRemove(std.Thread.getCurrentId())) - return; + + if (this.thread_count == 0) { + return invalidArg(); + } + + this.thread_count -= 1; + + if (this.channel.isClosed()) { + return .ok; + } if (mode == .abort) { this.channel.close(); } - if (this.owning_threads.count() == 0) { + if (mode == .abort or this.thread_count == 0) { this.finalizer_task = JSC.AnyTask{ .ctx = this, .callback = finalize }; this.event_loop.enqueueTaskConcurrent(JSC.ConcurrentTask.fromCallback(this, finalize)); - return; } + + return .ok; } }; @@ -1479,10 +1488,10 @@ pub export fn napi_create_threadsafe_function( }, .ctx = context, .channel = ThreadSafeFunction.Queue.init(max_queue_size, bun.default_allocator), - .owning_threads = .{}, + .thread_count = initial_thread_count, .poll_ref = Async.KeepAlive.init(), }; - function.owning_threads.ensureTotalCapacity(bun.default_allocator, initial_thread_count) catch return genericFailure(); + function.finalizer = .{ .ctx = thread_finalize_data, .fun = thread_finalize_cb }; result.* = function; return .ok; @@ -1512,8 +1521,7 @@ pub export fn napi_acquire_threadsafe_function(func: napi_threadsafe_function) n } pub export fn napi_release_threadsafe_function(func: napi_threadsafe_function, mode: napi_threadsafe_function_release_mode) napi_status { log("napi_release_threadsafe_function", .{}); - func.release(mode); - return .ok; + return func.release(mode); } pub export fn napi_unref_threadsafe_function(env: napi_env, func: napi_threadsafe_function) napi_status { log("napi_unref_threadsafe_function", .{});