diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 6a72b6b785..8b29319968 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -1407,7 +1407,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp /// this prevents an extra pthread_getspecific() call which shows up in profiling allocator: std.mem.Allocator, req: *uws.Request, - signal: ?*JSC.WebCore.AbortSignal = null, + signal: JSC.Strong = .{}, method: HTTP.Method, flags: NewFlags(debug_mode) = .{}, @@ -1446,6 +1446,13 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp // TODO: support builtin compression const can_sendfile = !ssl_enabled and !Environment.isWindows; + pub fn getSignal(this: *const RequestContext) ?*JSC.WebCore.AbortSignal { + if (this.signal.get()) |js_signal| { + return js_signal.as(JSC.WebCore.AbortSignal); + } + return null; + } + pub inline fn isAsync(this: *const RequestContext) bool { return this.defer_deinit_until_callback_completes == null; } @@ -1837,6 +1844,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp return true; } + /// Check if we are aborted and signal the abort if we are + /// The signal ref will be cleaned up + /// Returns true if we are aborted + fn checkAndCleanAbortSignal(this: *RequestContext) bool { + if (this.getSignal()) |signal| { + var _signal = this.signal; + this.signal = .{}; + defer _signal.deinit(); + if (this.flags.aborted and !signal.aborted()) { + const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); + reason.ensureStillAlive(); + _ = signal.signal(reason); + return true; + } + } + return false; + } + pub fn onAbort(this: *RequestContext, resp: *App.Response) void { assert(this.resp == resp); assert(!this.flags.aborted); @@ -1853,17 +1878,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } // if signal is not aborted, abort the signal - if (this.signal) |signal| { - this.signal = null; - if (!signal.aborted()) { - const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); - reason.ensureStillAlive(); - _ = signal.signal(reason); - any_js_calls = true; - } - _ = signal.unref(); + if (this.checkAndCleanAbortSignal()) { + any_js_calls = true; } - //if have sink, call onAborted on sink if (this.sink) |wrapper| { wrapper.sink.abort(); @@ -1943,15 +1960,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } // if signal is not aborted, abort the signal - if (this.signal) |signal| { - this.signal = null; - if (this.flags.aborted and !signal.aborted()) { - const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis); - reason.ensureStillAlive(); - _ = signal.signal(reason); - } - _ = signal.unref(); - } + _ = this.checkAndCleanAbortSignal(); if (this.request_body) |body| { ctxLog("finalizeWithoutDeinit: request_body != null", .{}); @@ -6227,14 +6236,11 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable; ctx.request_body = body; - var signal = JSC.WebCore.AbortSignal.new(this.globalThis); - ctx.signal = signal; - request_object.* = .{ .method = ctx.method, .request_context = AnyRequestContext.init(ctx), .https = ssl_enabled, - .signal = signal.ref(), + .signal = null, .body = body.ref(), }; @@ -6297,6 +6303,8 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp const request_value = args[0]; request_value.ensureStillAlive(); + // keep a strong ref so we can signal when the request is aborted (We need to keep JS alive not only the native part) + ctx.signal = JSC.Strong.create(Request.getSignalFromJS(request_value, this.globalThis), this.globalThis); const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); defer { @@ -6345,15 +6353,12 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable; ctx.request_body = body; - var signal = JSC.WebCore.AbortSignal.new(this.globalThis); - ctx.signal = signal; - request_object.* = .{ .method = ctx.method, .request_context = AnyRequestContext.init(ctx), .upgrader = ctx, .https = ssl_enabled, - .signal = signal.ref(), + .signal = null, .body = body.ref(), }; ctx.upgrade_context = upgrade_ctx; @@ -6365,6 +6370,9 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp }; const request_value = args[0]; request_value.ensureStillAlive(); + // keep a strong ref so we can signal when the request is aborted (We need to keep JS alive not only the native part) + ctx.signal = JSC.Strong.create(Request.getSignalFromJS(request_value, this.globalThis), this.globalThis); + const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args); defer { if (!ctx.didUpgradeWebSocket()) {} diff --git a/src/bun.js/webcore/request.zig b/src/bun.js/webcore/request.zig index 2a5b06a17c..f44c972379 100644 --- a/src/bun.js/webcore/request.zig +++ b/src/bun.js/webcore/request.zig @@ -98,6 +98,22 @@ pub const Request = struct { } } + /// Returns cached signal or generate a new JS signal and cache it. + pub fn getSignalFromJS( + jsRequest: JSC.JSValue, + globalThis: *JSC.JSGlobalObject, + ) JSC.JSValue { + if(jsRequest.as(Request)) |request| { + if(Request.signalGetCached(jsRequest)) |js_signal| { + return js_signal; + } + const signal = request.getSignal(globalThis); + Request.signalSetCached(jsRequest, globalThis, signal); + return signal; + } + return .zero; + } + pub fn init( url: bun.String, headers: ?*FetchHeaders, diff --git a/test/js/bun/http/serve.test.ts b/test/js/bun/http/serve.test.ts index ea3fc3941a..799f0a80fd 100644 --- a/test/js/bun/http/serve.test.ts +++ b/test/js/bun/http/serve.test.ts @@ -1509,3 +1509,80 @@ it("should work with dispose keyword", async () => { } expect(fetch(url)).rejects.toThrow(); }); + +it("it should call abort when the request is aborted in the middle of a stream", async () => { + const { promise, resolve } = Promise.withResolvers(); + const payload = Buffer.from("data: hello\n\n"); + using server = Bun.serve({ + port: 0, + fetch(req) { + let keepAlive = true; + req.signal.addEventListener("abort", () => { + keepAlive = false; + }); + return new Response( + new ReadableStream({ + async pull(controller) { + while (!req.signal.aborted) { + controller.enqueue(payload); + await Bun.sleep(10); + } + resolve(keepAlive); + }, + }), + { + headers: { + "Cache-Control": "no-store", + "Content-Type": "text/event-stream", + Connection: "keep-alive", + }, + }, + ); + }, + }); + + const abortController = new AbortController(); + const response = await fetch(server.url, { signal: abortController.signal }); + expect(response.status).toBe(200); + abortController.abort(); + expect(await promise).toBe(false); +}); + +it("it should call abort when the request is aborted in the middle of a stream using async fetch", async () => { + const { promise, resolve } = Promise.withResolvers(); + const payload = Buffer.from("data: hello\n\n"); + using server = Bun.serve({ + port: 0, + async fetch(req) { + await Bun.sleep(10); + let keepAlive = true; + req.signal.addEventListener("abort", () => { + keepAlive = false; + }); + return new Response( + new ReadableStream({ + async pull(controller) { + while (!req.signal.aborted) { + controller.enqueue(payload); + await Bun.sleep(10); + } + resolve(keepAlive); + }, + }), + { + headers: { + "Cache-Control": "no-store", + "Content-Type": "text/event-stream", + Connection: "keep-alive", + }, + }, + ); + }, + }); + + const abortController = new AbortController(); + const response = await fetch(server.url, { signal: abortController.signal }); + expect(response.status).toBe(200); + abortController.abort(); + expect(await promise).toBe(false); +});