diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index bfe9476bef..24b97e617b 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -4223,6 +4223,21 @@ declare module "bun" { */ requestIP(request: Request): SocketAddress | null; + /** + * Abort an in-flight HTTP(s) request, triggering the `"abort"` event and leading to a TCP RST ("Connection reset by peer") + * + * @param request The request to abort + * @returns true if the request was aborted, false if it was already aborted or if the request is not in-flight + * + * If called multiple times, it will only return true the first time. + * + * The associated `AbortSignal` will be signaled, causing the `"abort"` + * event to fire. If a `ReadableStream` is attached to the `Response`, it will + * be cancelled. If the request body has a pending promise (like `.text()`), it will + * be rejected. + */ + abort(request: Request): boolean; + /** * Reset the idleTimeout of the given Request to the number in seconds. 0 means no timeout. * diff --git a/packages/bun-uws/src/AsyncSocket.h b/packages/bun-uws/src/AsyncSocket.h index 4a0d82968c..6130035077 100644 --- a/packages/bun-uws/src/AsyncSocket.h +++ b/packages/bun-uws/src/AsyncSocket.h @@ -121,6 +121,11 @@ public: return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr); } + void abort() { + this->uncorkWithoutSending(); + us_socket_close(SSL, (us_socket_t *) this, LIBUS_SOCKET_CLOSE_CODE_CONNECTION_RESET, nullptr); + } + void corkUnchecked() { /* What if another socket is corked? */ getLoopData()->setCorkedSocket(this, SSL); diff --git a/src/bun.js/api/server.classes.ts b/src/bun.js/api/server.classes.ts index 34b08171dd..783ddc4f84 100644 --- a/src/bun.js/api/server.classes.ts +++ b/src/bun.js/api/server.classes.ts @@ -25,6 +25,10 @@ function generate(name) { fn: "doReload", length: 2, }, + abort: { + fn: "doAbort", + length: 1, + }, "@@dispose": { fn: "dispose", length: 0, diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index aa087a29f0..b8f250cb69 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -1590,6 +1590,7 @@ fn NewFlags(comptime debug_mode: bool) type { has_written_status: bool = false, response_protected: bool = false, aborted: bool = false, + user_called_abort: bool = false, has_finalized: bun.DebugOnly(bool) = bun.DebugOnlyDefault(false), is_error_promise_pending: bool = false, @@ -1640,6 +1641,28 @@ pub const AnyRequestContext = struct { return self.tagged_pointer.get(T); } + pub fn abort(self: AnyRequestContext) bool { + if (self.tagged_pointer.isNull()) { + return false; + } + + switch (self.tagged_pointer.tag()) { + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPServer.RequestContext).abort(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(HTTPSServer.RequestContext).abort(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPServer.RequestContext).abort(); + }, + @field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => { + return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).abort(); + }, + else => @panic("Unexpected AnyRequestContext tag"), + } + } + pub fn setTimeout(self: AnyRequestContext, seconds: c_uint) bool { if (self.tagged_pointer.isNull()) { return false; @@ -1934,7 +1957,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } fn handleResolve(ctx: *RequestContext, value: JSC.JSValue) void { - if (ctx.isAbortedOrEnded() or ctx.didUpgradeWebSocket()) { + if (ctx.isAbortedOrEnded() or ctx.didUpgradeWebSocket() or ctx.flags.user_called_abort) { return; } @@ -1954,7 +1977,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp ctx.response_jsvalue = value; assert(!ctx.flags.response_protected); ctx.flags.response_protected = true; - JSC.C.JSValueProtect(ctx.server.?.globalThis, value.asObjectRef()); + value.protect(); if (ctx.method == .HEAD) { if (ctx.resp) |resp| { @@ -1993,6 +2016,16 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp return true; } + pub fn abort(this: *RequestContext) bool { + if (this.isAbortedOrEnded() or this.flags.user_called_abort) return false; + if (this.resp) |resp| { + this.flags.user_called_abort = true; + resp.abort(); + return true; + } + return false; + } + /// destroy RequestContext, should be only called by deref or if defer_deinit_until_callback_completes is ref is set to true fn deinit(this: *RequestContext) void { this.detachResponse(); @@ -2053,7 +2086,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } fn handleReject(ctx: *RequestContext, value: JSC.JSValue) void { - if (ctx.isAbortedOrEnded()) { + if (ctx.isAbortedOrEnded() or ctx.flags.user_called_abort) { return; } @@ -2074,7 +2107,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp } } // check again in case it get aborted after runErrorHandler - if (ctx.isAbortedOrEnded()) { + if (ctx.isAbortedOrEnded() or ctx.flags.user_called_abort) { return; } @@ -2322,21 +2355,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp assert(this.resp == resp); assert(this.server != null); - var any_js_calls = false; - var vm = this.server.?.vm; - const globalThis = this.server.?.globalThis; - defer { - // This is a task in the event loop. - // If we called into JavaScript, we must drain the microtask queue - if (any_js_calls) { - vm.drainMicrotasks(); - } - } - if (this.request_weakref.get()) |request| { - if (request.internal_event_callback.trigger(Request.InternalJSEventCallback.EventType.timeout, globalThis)) { - any_js_calls = true; - } + const globalThis = this.server.?.globalThis; + request.internal_event_callback.triggerAtTopOfEventLoop(Request.InternalJSEventCallback.EventType.timeout, globalThis); } } @@ -2349,19 +2370,19 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp this.detachResponse(); var any_js_calls = false; - var vm = this.server.?.vm; + const vm: *JSC.VirtualMachine = this.server.?.vm; const globalThis = this.server.?.globalThis; + const loop = vm.eventLoop(); defer { - // This is a task in the event loop. - // If we called into JavaScript, we must drain the microtask queue if (any_js_calls) { - vm.drainMicrotasks(); + loop.exit(); } this.deref(); } if (this.request_weakref.get()) |request| { request.request_context = AnyRequestContext.Null; + loop.enter(); if (request.internal_event_callback.trigger(Request.InternalJSEventCallback.EventType.abort, globalThis)) { any_js_calls = true; } @@ -2377,6 +2398,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp signal.unref(); } if (!signal.aborted()) { + if (!any_js_calls) { + loop.enter(); + } signal.signal(globalThis, .ConnectionClosed); any_js_calls = true; } @@ -2384,6 +2408,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp //if have sink, call onAborted on sink if (this.sink) |wrapper| { + if (!any_js_calls) { + loop.enter(); + } wrapper.sink.abort(); return; } @@ -2392,6 +2419,10 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp if (this.isDeadRequest()) { this.finalizeWithoutDeinit(); } else { + if (!any_js_calls) { + loop.enter(); + } + if (this.endRequestStreaming()) { any_js_calls = true; } @@ -2402,6 +2433,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp response.body.value.Locked.readable = .{}; defer strong_readable.deinit(); if (strong_readable.get()) |readable| { + if (!any_js_calls) { + loop.enter(); + } readable.abort(globalThis); any_js_calls = true; } @@ -3248,6 +3282,11 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp return; } + // if the user called server.abort(request), we don't mind if they don't return a Response. + if (ctx.flags.user_called_abort) { + return; + } + if (response_value.isEmptyOrUndefinedOrNull()) { ctx.renderMissingInvalidResponse(response_value); return; @@ -6008,6 +6047,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp pub const doFetch = onFetch; pub const doRequestIP = JSC.wrapInstanceMethod(ThisServer, "requestIP", false); pub const doTimeout = JSC.wrapInstanceMethod(ThisServer, "timeout", false); + pub const doAbort = JSC.wrapInstanceMethod(ThisServer, "abort", false); pub fn getPlugins( this: *ThisServer, @@ -6100,6 +6140,10 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp JSValue.jsNull(); } + pub fn abort(_: *ThisServer, request: *JSC.WebCore.Request) JSC.JSValue { + return JSValue.jsBoolean(request.request_context.abort()); + } + pub fn memoryCost(this: *ThisServer) usize { return @sizeOf(ThisServer) + this.base_url_string_for_joining.len + diff --git a/src/bun.js/webcore/request.zig b/src/bun.js/webcore/request.zig index a3a63b8e22..d127872cf7 100644 --- a/src/bun.js/webcore/request.zig +++ b/src/bun.js/webcore/request.zig @@ -131,6 +131,14 @@ pub const Request = struct { return this.function.has(); } + pub fn triggerAtTopOfEventLoop(this: *InternalJSEventCallback, eventType: EventType, globalThis: *JSC.JSGlobalObject) void { + if (this.function.get()) |callback| { + globalThis.bunVM().eventLoop().runCallback(callback, globalThis, .undefined, &.{JSC.JSValue.jsNumber( + @intFromEnum(eventType), + )}); + } + } + pub fn trigger(this: *InternalJSEventCallback, eventType: EventType, globalThis: *JSC.JSGlobalObject) bool { if (this.function.get()) |callback| { _ = callback.call(globalThis, JSC.JSValue.jsUndefined(), &.{JSC.JSValue.jsNumber( diff --git a/src/deps/libuwsockets.cpp b/src/deps/libuwsockets.cpp index d3aafb25af..59ccb227a6 100644 --- a/src/deps/libuwsockets.cpp +++ b/src/deps/libuwsockets.cpp @@ -1207,6 +1207,16 @@ extern "C" } } + void uws_res_abort(int ssl, uws_res_r res) { + if (ssl) { + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + uwsRes->abort(); + } else { + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + uwsRes->abort(); + } + } + void uws_res_end_without_body(int ssl, uws_res_r res, bool close_connection) { if (ssl) diff --git a/src/deps/uws.zig b/src/deps/uws.zig index f08d381e0b..05934f74fd 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -3705,6 +3705,10 @@ pub fn NewApp(comptime ssl: bool) type { return uws_res_has_responded(ssl_flag, res.downcast()); } + pub fn abort(res: *Response) void { + uws_res_abort(ssl_flag, res.downcast()); + } + pub fn getNativeHandle(res: *Response) bun.FileDescriptor { if (comptime Environment.isWindows) { // on windows uSockets exposes SOCKET @@ -4613,3 +4617,5 @@ pub fn onThreadExit() void { extern fn uws_app_clear_routes(ssl_flag: c_int, app: *uws_app_t) void; pub extern fn us_socket_upgrade_to_tls(s: *Socket, new_context: *SocketContext, sni: ?[*:0]const u8) ?*Socket; + +extern fn uws_res_abort(ssl_flag: c_int, res: *uws_res) void; diff --git a/test/js/bun/http/bun-serve-abort.test.ts b/test/js/bun/http/bun-serve-abort.test.ts new file mode 100644 index 0000000000..938da23d40 --- /dev/null +++ b/test/js/bun/http/bun-serve-abort.test.ts @@ -0,0 +1,185 @@ +import { describe, expect, test, mock } from "bun:test"; +import { bunEnv, bunExe, rejectUnauthorizedScope, tempDirWithFiles, tls } from "harness"; + +describe("server.abort()", async () => { + test("after sleep", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + await Bun.sleep(0); + server.abort(request); + return new Response("Hello, world!"); + }, + }); + + expect(async () => { + const response = await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("before sleep", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + expect(server.abort(request)).toBe(true); + await Bun.sleep(0); + // calling it again should do nothing + expect(server.abort(request)).toBe(false); + + return new Response("Hello, world!"); + }, + }); + + expect(async () => { + await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("slightly after response is returned", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + queueMicrotask(() => { + expect(server.abort(request)).toBe(true); + }); + return new Response("hello!"); + }, + }); + + expect(async () => { + await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("after response was probably sent does nothing", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + setTimeout(() => { + expect(server.abort(request)).toBe(false); + }, 0); + return new Response("hello!"); + }, + }); + + const response = await fetch(`http://localhost:${server.port}`); + expect(response.status).toBe(200); + expect(await response.text()).toBe("hello!"); + }); + + test("triggers AbortSignal", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + const fn = mock(() => { + // already aborted. + expect(server.abort(request)).toBe(false); + }); + request.signal.addEventListener("abort", fn); + expect(server.abort(request)).toBe(true); + + // you can return undefined and it should not trigger an uncaught exception + }, + }); + + expect(async () => { + await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("triggers AbortSignal after sleep", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + const fn = mock(() => { + // already aborted. + expect(server.abort(request)).toBe(false); + }); + request.signal.addEventListener("abort", fn); + + await Bun.sleep(0); + expect(server.abort(request)).toBe(true); + + // you can return undefined and it should not trigger an uncaught exception + }, + }); + + expect(async () => { + await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("works inside of a ReadableStream on the original Request with sleep", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + return new Response( + new ReadableStream({ + async pull(controller) { + await Bun.sleep(0); + server.abort(request); + controller.close(); + }, + }), + ); + }, + }); + + expect(async () => { + await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("works inside of a ReadableStream on the original Request without sleep", async () => { + using server = Bun.serve({ + port: 0, + + async fetch(request, server) { + return new Response( + new ReadableStream({ + pull(controller) { + server.abort(request); + controller.close(); + }, + }), + ); + }, + }); + + expect(async () => { + await fetch(`http://localhost:${server.port}`); + }).toThrow("The socket connection was closed"); + }); + + test("works inside of a ReadableStream on the original Request without sleep, with SSL", async () => { + using server = Bun.serve({ + port: 0, + tls: tls, + async fetch(request, server) { + return new Response( + new ReadableStream({ + pull(controller) { + server.abort(request); + controller.close(); + }, + }), + ); + }, + }); + + expect(async () => { + await fetch(`https://localhost:${server.port}`, { + tls: { + rejectUnauthorized: false, + }, + }); + }).toThrow("The socket connection was closed"); + }); +});