From 2580d199a482c822abc66b422c68c764f1f4e979 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Tue, 4 Jun 2024 23:29:21 -0700 Subject: [PATCH] fix some worker-related stability issues (#11494) --- src/bun.js/api/bun/socket.zig | 84 ++++++++++++++--- src/bun.js/api/server.zig | 35 ++++++-- src/bun.js/event_loop.zig | 4 +- src/bun.js/javascript.zig | 22 ++++- src/bun.js/web_worker.zig | 39 +++++--- src/bun.js/webcore/streams.zig | 17 +++- src/http/websocket_http_client.zig | 7 +- .../worker_threads/worker_destruction.test.ts | 11 +++ .../worker_threads/worker_thread_check.ts | 89 +++++++++++++++++++ 9 files changed, 267 insertions(+), 41 deletions(-) create mode 100644 test/js/node/worker_threads/worker_destruction.test.ts create mode 100644 test/js/node/worker_threads/worker_thread_check.ts diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 0e7483182e..b4f69dea5d 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -197,12 +197,22 @@ const Handlers = struct { // corker: Corker = .{}, pub fn resolvePromise(this: *Handlers, value: JSValue) void { + const vm = this.vm; + if (vm.isShuttingDown()) { + return; + } + const promise = this.promise.trySwap() orelse return; const anyPromise = promise.asAnyPromise() orelse return; anyPromise.resolve(this.globalObject, value); } pub fn rejectPromise(this: *Handlers, value: JSValue) bool { + const vm = this.vm; + if (vm.isShuttingDown()) { + return true; + } + const promise = this.promise.trySwap() orelse return false; const anyPromise = promise.asAnyPromise() orelse return false; anyPromise.reject(this.globalObject, value); @@ -233,17 +243,24 @@ const Handlers = struct { } pub fn callErrorHandler(this: *Handlers, thisValue: JSValue, err: []const JSValue) bool { + const vm = this.vm; + if (vm.isShuttingDown()) { + return false; + } + + const globalObject = this.globalObject; const onError = this.onError; + if (onError == .zero) { if (err.len > 0) - _ = this.vm.uncaughtException(this.globalObject, err[0], false); + _ = vm.uncaughtException(globalObject, err[0], false); return false; } - const result = onError.callWithThis(this.globalObject, thisValue, err); + const result = onError.callWithThis(globalObject, thisValue, err); if (result.isAnyError()) { - _ = this.vm.uncaughtException(this.globalObject, result, false); + _ = vm.uncaughtException(globalObject, result, false); } return true; @@ -303,6 +320,10 @@ const Handlers = struct { } pub fn unprotect(this: *Handlers) void { + if (this.vm.isShuttingDown()) { + return; + } + if (comptime Environment.allow_assert) { bun.assert(this.protection_count > 0); this.protection_count -= 1; @@ -882,6 +903,11 @@ pub const Listener = struct { pub fn finalize(this: *Listener) callconv(.C) void { log("Finalize", .{}); + if (this.listener) |listener| { + this.listener = null; + listener.close(this.ssl); + } + this.deinit(); } @@ -890,11 +916,17 @@ pub const Listener = struct { this.strong_data.deinit(); this.poll_ref.unref(this.handlers.vm); bun.assert(this.listener == null); - bun.assert(this.handlers.active_connections == 0); this.handlers.unprotect(); - if (this.socket_context) |ctx| { - ctx.deinit(this.ssl); + if (this.handlers.active_connections > 0) { + if (this.socket_context) |ctx| { + ctx.close(this.ssl); + } + // TODO: fix this leak. + } else { + if (this.socket_context) |ctx| { + ctx.deinit(this.ssl); + } } this.connection.deinit(); @@ -1144,6 +1176,10 @@ fn NewSocket(comptime ssl: bool) type { pub const Socket = uws.NewSocketHandler(ssl); socket: Socket, detached: bool = false, + + /// Prevent onClose from calling into JavaScript while we are finalizing + finalizing: bool = false, + wrapped: WrappedType = .none, handlers: *Handlers, this_value: JSC.JSValue = .zero, @@ -1227,6 +1263,9 @@ fn NewSocket(comptime ssl: bool) type { if (callback == .zero) return; var vm = handlers.vm; + if (vm.isShuttingDown()) { + return; + } vm.eventLoop().enter(); defer vm.eventLoop().exit(); @@ -1250,7 +1289,10 @@ fn NewSocket(comptime ssl: bool) type { const handlers = this.handlers; const callback = handlers.onTimeout; - if (callback == .zero) return; + if (callback == .zero or this.finalizing) return; + if (handlers.vm.isShuttingDown()) { + return; + } // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can @@ -1276,6 +1318,9 @@ fn NewSocket(comptime ssl: bool) type { const handlers = this.handlers; const vm = handlers.vm; this.poll_ref.unrefOnNextTick(vm); + if (vm.isShuttingDown()) { + return; + } const callback = handlers.onConnectError; const globalObject = handlers.globalObject; @@ -1351,7 +1396,6 @@ fn NewSocket(comptime ssl: bool) type { } this.is_active = false; const vm = this.handlers.vm; - this.handlers.markInactive(ssl, this.socket.context(), this.wrapped); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .Release); @@ -1461,7 +1505,7 @@ fn NewSocket(comptime ssl: bool) type { const handlers = this.handlers; const callback = handlers.onEnd; - if (callback == .zero) { + if (callback == .zero or handlers.vm.isShuttingDown()) { this.poll_ref.unref(handlers.vm); // If you don't handle TCP fin, we assume you're done. @@ -1497,6 +1541,10 @@ fn NewSocket(comptime ssl: bool) type { var callback = handlers.onHandshake; var is_open = false; + if (handlers.vm.isShuttingDown()) { + return; + } + // Use open callback when handshake is not provided if (callback == .zero) { callback = handlers.onOpen; @@ -1563,13 +1611,23 @@ fn NewSocket(comptime ssl: bool) type { this.detached = true; defer this.markInactive(); + if (this.finalizing) { + return; + } + const handlers = this.handlers; - this.poll_ref.unref(handlers.vm); + const vm = handlers.vm; + this.poll_ref.unref(vm); const callback = handlers.onClose; + if (callback == .zero) return; + if (vm.isShuttingDown()) { + return; + } + // the handlers must be kept alive for the duration of the function call // that way if we need to call the error handler, we can var scope = handlers.enter(socket.context()); @@ -1594,7 +1652,10 @@ fn NewSocket(comptime ssl: bool) type { const handlers = this.handlers; const callback = handlers.onData; - if (callback == .zero) return; + if (callback == .zero or this.finalizing) return; + if (handlers.vm.isShuttingDown()) { + return; + } const globalObject = handlers.globalObject; const this_value = this.getThisValue(globalObject); @@ -2034,6 +2095,7 @@ fn NewSocket(comptime ssl: bool) type { pub fn finalize(this: *This) callconv(.C) void { log("finalize() {d}", .{@intFromPtr(this)}); + this.finalizing = true; if (!this.detached) { this.detached = true; if (!this.socket.isClosed()) { diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 988b416610..a5b531022b 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -3639,6 +3639,10 @@ pub const WebSocketServer = struct { } pub fn unprotect(this: Handler) void { + if (this.vm.isShuttingDown()) { + return; + } + this.onOpen.unprotect(); this.onMessage.unprotect(); this.onClose.unprotect(); @@ -3905,10 +3909,16 @@ pub const ServerWebSocket = struct { const value_to_cache = this.this_value; var handler = this.handler; + const vm = this.handler.vm; handler.active_connections +|= 1; const globalObject = handler.globalObject; - const onOpenHandler = handler.onOpen; + if (vm.isShuttingDown()) { + log("onOpen called after script execution", .{}); + ws.close(); + return; + } + this.this_value = .zero; this.flags.opened = false; if (value_to_cache != .zero) { @@ -3919,7 +3929,7 @@ pub const ServerWebSocket = struct { if (onOpenHandler.isEmptyOrUndefinedOrNull()) return; const this_value = this.getThisValue(); var args = [_]JSValue{this_value}; - const vm = this.handler.vm; + const loop = vm.eventLoop(); loop.enter(); defer loop.exit(); @@ -3974,6 +3984,12 @@ pub const ServerWebSocket = struct { var globalObject = this.handler.globalObject; // This is the start of a task. const vm = this.handler.vm; + if (vm.isShuttingDown()) { + log("onMessage called after script execution", .{}); + ws.close(); + return; + } + const loop = vm.eventLoop(); loop.enter(); defer loop.exit(); @@ -4027,7 +4043,8 @@ pub const ServerWebSocket = struct { log("onDrain", .{}); const handler = this.handler; - if (this.isClosed()) + const vm = handler.vm; + if (this.isClosed() or vm.isShuttingDown()) return; if (handler.onDrain != .zero) { @@ -4038,7 +4055,6 @@ pub const ServerWebSocket = struct { .globalObject = globalObject, .callback = handler.onDrain, }; - const vm = JSC.VirtualMachine.get(); const loop = vm.eventLoop(); loop.enter(); defer loop.exit(); @@ -4075,11 +4091,10 @@ pub const ServerWebSocket = struct { const handler = this.handler; var cb = handler.onPing; - if (cb.isEmptyOrUndefinedOrNull()) return; + const vm = handler.vm; + if (cb.isEmptyOrUndefinedOrNull() or vm.isShuttingDown()) return; const globalThis = handler.globalObject; - const vm = JSC.VirtualMachine.get(); - // This is the start of a task. const loop = vm.eventLoop(); loop.enter(); @@ -4106,6 +4121,8 @@ pub const ServerWebSocket = struct { const globalThis = handler.globalObject; const vm = handler.vm; + if (vm.isShuttingDown()) return; + // This is the start of a task. const loop = vm.eventLoop(); loop.enter(); @@ -4133,10 +4150,12 @@ pub const ServerWebSocket = struct { } } + const vm = handler.vm; + if (vm.isShuttingDown()) return; + if (!handler.onClose.isEmptyOrUndefinedOrNull()) { var str = ZigString.init(message); const globalObject = handler.globalObject; - const vm = handler.vm; const loop = vm.eventLoop(); loop.enter(); defer loop.exit(); diff --git a/src/bun.js/event_loop.zig b/src/bun.js/event_loop.zig index 4d49202cce..fd7db0e2bb 100644 --- a/src/bun.js/event_loop.zig +++ b/src/bun.js/event_loop.zig @@ -1507,10 +1507,10 @@ pub const EventLoop = struct { const worker = this.virtual_machine.worker orelse @panic("EventLoop.waitForPromiseWithTermination: worker is not initialized"); switch (promise.status(this.virtual_machine.jsc)) { JSC.JSPromise.Status.Pending => { - while (!worker.requested_terminate and promise.status(this.virtual_machine.jsc) == .Pending) { + while (!worker.hasRequestedTerminate() and promise.status(this.virtual_machine.jsc) == .Pending) { this.tick(); - if (!worker.requested_terminate and promise.status(this.virtual_machine.jsc) == .Pending) { + if (!worker.hasRequestedTerminate() and promise.status(this.virtual_machine.jsc) == .Pending) { this.autoTick(); } } diff --git a/src/bun.js/javascript.zig b/src/bun.js/javascript.zig index e163bcca40..1b8f11c4e3 100644 --- a/src/bun.js/javascript.zig +++ b/src/bun.js/javascript.zig @@ -752,7 +752,7 @@ pub const VirtualMachine = struct { return this.debugger != null; } - pub inline fn isShuttingDown(this: *const VirtualMachine) bool { + pub fn isShuttingDown(this: *const VirtualMachine) bool { return this.is_shutting_down; } @@ -924,6 +924,11 @@ pub const VirtualMachine = struct { extern fn Bun__Process__exit(*JSC.JSGlobalObject, code: c_int) noreturn; pub fn unhandledRejection(this: *JSC.VirtualMachine, globalObject: *JSC.JSGlobalObject, reason: JSC.JSValue, promise: JSC.JSValue) bool { + if (this.isShuttingDown()) { + Output.debugWarn("unhandledRejection during shutdown.", .{}); + return true; + } + if (isBunTest) { this.unhandled_error_counter += 1; this.onUnhandledRejection(this, globalObject, reason); @@ -939,6 +944,11 @@ pub const VirtualMachine = struct { } pub fn uncaughtException(this: *JSC.VirtualMachine, globalObject: *JSC.JSGlobalObject, err: JSC.JSValue, is_rejection: bool) bool { + if (this.isShuttingDown()) { + Output.debugWarn("uncaughtException during shutdown.", .{}); + return true; + } + if (isBunTest) { this.unhandled_error_counter += 1; this.onUnhandledRejection(this, globalObject, err); @@ -1066,9 +1076,13 @@ pub const VirtualMachine = struct { } } - pub fn scriptExecutionStatus(this: *VirtualMachine) callconv(.C) JSC.ScriptExecutionStatus { + pub fn scriptExecutionStatus(this: *const VirtualMachine) callconv(.C) JSC.ScriptExecutionStatus { + if (this.is_shutting_down) { + return .stopped; + } + if (this.worker) |worker| { - if (worker.requested_terminate) { + if (worker.hasRequestedTerminate()) { return .stopped; } } @@ -2463,7 +2477,7 @@ pub const VirtualMachine = struct { .Internal = promise, }); if (this.worker) |worker| { - if (worker.requested_terminate) { + if (worker.hasRequestedTerminate()) { return error.WorkerTerminated; } } diff --git a/src/bun.js/web_worker.zig b/src/bun.js/web_worker.zig index bc84b9f758..0cae38391f 100644 --- a/src/bun.js/web_worker.zig +++ b/src/bun.js/web_worker.zig @@ -7,13 +7,15 @@ const JSValue = JSC.JSValue; const Async = bun.Async; const WTFStringImpl = @import("../string.zig").WTFStringImpl; +const Bool = std.atomic.Value(bool); + /// Shared implementation of Web and Node `Worker` pub const WebWorker = struct { /// null when haven't started yet vm: ?*JSC.VirtualMachine = null, status: std.atomic.Value(Status) = std.atomic.Value(Status).init(.start), /// To prevent UAF, the `spin` function (aka the worker's event loop) will call deinit once this is set and properly exit the loop. - requested_terminate: bool = false, + requested_terminate: Bool = Bool.init(false), execution_context_id: u32 = 0, parent_context_id: u32 = 0, parent: *JSC.VirtualMachine, @@ -51,6 +53,14 @@ pub const WebWorker = struct { return worker.cpp_worker; } + pub fn hasRequestedTerminate(this: *const WebWorker) bool { + return this.requested_terminate.load(.Monotonic); + } + + pub fn setRequestedTerminate(this: *WebWorker) bool { + return this.requested_terminate.swap(true, .Release); + } + export fn WebWorker__updatePtr(worker: *WebWorker, ptr: *anyopaque) bool { worker.cpp_worker = ptr; @@ -145,7 +155,7 @@ pub const WebWorker = struct { Output.Source.configureNamedThread("Worker"); } - if (this.requested_terminate) { + if (this.hasRequestedTerminate()) { this.deinit(); return; } @@ -242,7 +252,7 @@ pub const WebWorker = struct { JSC.markBinding(@src()); WebWorker__dispatchError(globalObject, worker.cpp_worker, bun.String.createUTF8(array.toOwnedSliceLeaky()), error_instance); if (vm.worker) |worker_| { - worker.requested_terminate = true; + _ = worker.setRequestedTerminate(); worker.parent_poll_ref.unrefConcurrently(worker.parent); worker_.exitAndDeinit(); } @@ -302,15 +312,15 @@ pub const WebWorker = struct { while (vm.isEventLoopAlive()) { vm.tick(); - if (this.requested_terminate) break; + if (this.hasRequestedTerminate()) break; vm.eventLoop().autoTickActive(); - if (this.requested_terminate) break; + if (this.hasRequestedTerminate()) break; } - log("[{d}] before exit {s}", .{ this.execution_context_id, if (this.requested_terminate) "(terminated)" else "(event loop dead)" }); + log("[{d}] before exit {s}", .{ this.execution_context_id, if (this.hasRequestedTerminate()) "(terminated)" else "(event loop dead)" }); // Only call "beforeExit" if we weren't from a .terminate - if (!this.requested_terminate) { + if (!this.hasRequestedTerminate()) { // TODO: is this able to allow the event loop to continue? vm.onBeforeExit(); } @@ -322,9 +332,14 @@ pub const WebWorker = struct { /// This is worker.ref()/.unref() from JS (Caller thread) pub fn setRef(this: *WebWorker, value: bool) callconv(.C) void { - if (this.requested_terminate) { + if (this.hasRequestedTerminate()) { return; } + + this.setRefInternal(value); + } + + pub fn setRefInternal(this: *WebWorker, value: bool) void { if (value) { this.parent_poll_ref.ref(this.parent); } else { @@ -338,16 +353,16 @@ pub const WebWorker = struct { if (this.status.load(.Acquire) == .terminated) { return; } - if (this.requested_terminate) { + if (this.setRequestedTerminate()) { return; } log("[{d}] requestTerminate", .{this.execution_context_id}); - this.setRef(false); - this.requested_terminate = true; + if (this.vm) |vm| { - vm.jsc.notifyNeedTermination(); vm.eventLoop().wakeup(); } + + this.setRefInternal(false); } /// This handles cleanup, emitting the "close" event, and deinit. diff --git a/src/bun.js/webcore/streams.zig b/src/bun.js/webcore/streams.zig index ac40717733..ed788debf4 100644 --- a/src/bun.js/webcore/streams.zig +++ b/src/bun.js/webcore/streams.zig @@ -669,6 +669,14 @@ pub const StreamResult = union(Tag) { into_array: IntoArray, into_array_and_done: IntoArray, + pub fn deinit(this: *StreamResult) void { + switch (this.*) { + .owned => |*owned| owned.deinitWithAllocator(bun.default_allocator), + .owned_and_done => |*owned_and_done| owned_and_done.deinitWithAllocator(bun.default_allocator), + else => {}, + } + } + pub const Err = enum { Error, JSValue, @@ -921,7 +929,8 @@ pub const StreamResult = union(Tag) { } pub fn fulfillPromise(result: *StreamResult, promise: *JSC.JSPromise, globalThis: *JSC.JSGlobalObject) void { - const loop = globalThis.bunVM().eventLoop(); + const vm = globalThis.bunVM(); + const loop = vm.eventLoop(); const promise_value = promise.asValue(globalThis); defer promise_value.unprotect(); @@ -954,6 +963,12 @@ pub const StreamResult = union(Tag) { } pub fn toJS(this: *const StreamResult, globalThis: *JSGlobalObject) JSValue { + if (JSC.VirtualMachine.get().isShuttingDown()) { + var that = this.*; + that.deinit(); + return .zero; + } + switch (this.*) { .owned => |list| { return JSC.ArrayBuffer.fromBytes(list.slice(), .Uint8Array).toJS(globalThis, null); diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index a496455582..deaa695ba6 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -276,11 +276,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { header_values: ?[*]const JSC.ZigString, header_count: usize, ) callconv(.C) ?*HTTPClient { - bun.assert(global.bunVM().event_loop_handle != null); + const vm = global.bunVM(); + + bun.assert(vm.event_loop_handle != null); var client_protocol_hash: u64 = 0; const body = buildRequestBody( - global.bunVM(), + vm, pathname, ssl, host, @@ -289,7 +291,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { &client_protocol_hash, NonUTF8Headers.init(header_names, header_values, header_count), ) catch return null; - const vm = global.bunVM(); var client = HTTPClient.new(.{ .tcp = null, diff --git a/test/js/node/worker_threads/worker_destruction.test.ts b/test/js/node/worker_threads/worker_destruction.test.ts new file mode 100644 index 0000000000..0f2a1afbdc --- /dev/null +++ b/test/js/node/worker_threads/worker_destruction.test.ts @@ -0,0 +1,11 @@ +import { test, describe, expect } from "bun:test"; +import { $ } from "bun"; +import { join } from "path"; +import "harness"; + +describe("Worker destruction", () => { + const method = ["Bun.connect", "Bun.listen"]; + test.each(method)("bun closes cleanly when %s is used in a Worker that is terminating", method => { + expect([join(import.meta.dir, "worker_thread_check.ts"), method]).toRun(); + }); +}); diff --git a/test/js/node/worker_threads/worker_thread_check.ts b/test/js/node/worker_threads/worker_thread_check.ts new file mode 100644 index 0000000000..48ae9de7a1 --- /dev/null +++ b/test/js/node/worker_threads/worker_thread_check.ts @@ -0,0 +1,89 @@ +const CONCURRENCY = 10; +const RUN_COUNT = 5; + +import { Worker, isMainThread, workerData } from "worker_threads"; + +const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + +if (isMainThread) { + let action = process.argv.at(-1); + if (process.argv.length === 2) { + action = "Bun.connect"; + } + + const server = Bun.serve({ + port: 0, + fetch() { + return new Response(); + }, + }); + let remaining = RUN_COUNT; + + while (remaining--) { + const promises = []; + + for (let i = 0; i < CONCURRENCY; i++) { + const worker = new Worker(import.meta.url, { + workerData: { + action, + port: server.port, + }, + }); + worker.ref(); + const { promise, resolve } = Promise.withResolvers(); + promises.push(promise); + + worker.on("online", () => { + sleep(1) + .then(() => { + return worker.terminate(); + }) + .finally(resolve); + }); + } + + await Promise.all(promises); + console.log(`Spawned ${CONCURRENCY} workers`, "RSS", (process.memoryUsage().rss / 1024 / 1024) | 0, "MB"); + Bun.gc(true); + } + server.stop(true); +} else { + Bun.gc(true); + const { action, port } = workerData; + + switch (action) { + case "Bun.connect": { + await Bun.connect({ + hostname: "localhost", + port, + socket: { + open() {}, + error() {}, + data() {}, + drain() {}, + close() {}, + }, + }); + break; + } + case "Bun.listen": { + const server = Bun.listen({ + hostname: "localhost", + port: 0, + socket: { + open() {}, + error() {}, + data() {}, + drain() {}, + close() {}, + }, + }); + break; + } + case "fetch": { + const resp = await fetch("http://localhost:" + port); + await resp.blob(); + break; + } + } +}