Compare commits

...

1 Commits

Author SHA1 Message Date
Jarred Sumner
ee239d7159 Introduce server.abort(request: Request) to abort an in-flight HTTP/HTTPs request 2025-01-28 00:01:02 -08:00
8 changed files with 299 additions and 22 deletions

View File

@@ -4223,6 +4223,21 @@ declare module "bun" {
*/
requestIP(request: Request): SocketAddress | null;
/**
* Abort an in-flight HTTP(s) request, triggering the `"abort"` event and leading to a TCP RST ("Connection reset by peer")
*
* @param request The request to abort
* @returns true if the request was aborted, false if it was already aborted or if the request is not in-flight
*
* If called multiple times, it will only return true the first time.
*
* The associated `AbortSignal` will be signaled, causing the `"abort"`
* event to fire. If a `ReadableStream` is attached to the `Response`, it will
* be cancelled. If the request body has a pending promise (like `.text()`), it will
* be rejected.
*/
abort(request: Request): boolean;
/**
* Reset the idleTimeout of the given Request to the number in seconds. 0 means no timeout.
*

View File

@@ -121,6 +121,11 @@ public:
return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr);
}
void abort() {
this->uncorkWithoutSending();
us_socket_close(SSL, (us_socket_t *) this, LIBUS_SOCKET_CLOSE_CODE_CONNECTION_RESET, nullptr);
}
void corkUnchecked() {
/* What if another socket is corked? */
getLoopData()->setCorkedSocket(this, SSL);

View File

@@ -25,6 +25,10 @@ function generate(name) {
fn: "doReload",
length: 2,
},
abort: {
fn: "doAbort",
length: 1,
},
"@@dispose": {
fn: "dispose",
length: 0,

View File

@@ -1590,6 +1590,7 @@ fn NewFlags(comptime debug_mode: bool) type {
has_written_status: bool = false,
response_protected: bool = false,
aborted: bool = false,
user_called_abort: bool = false,
has_finalized: bun.DebugOnly(bool) = bun.DebugOnlyDefault(false),
is_error_promise_pending: bool = false,
@@ -1640,6 +1641,28 @@ pub const AnyRequestContext = struct {
return self.tagged_pointer.get(T);
}
pub fn abort(self: AnyRequestContext) bool {
if (self.tagged_pointer.isNull()) {
return false;
}
switch (self.tagged_pointer.tag()) {
@field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPServer.RequestContext))) => {
return self.tagged_pointer.as(HTTPServer.RequestContext).abort();
},
@field(Pointer.Tag, bun.meta.typeBaseName(@typeName(HTTPSServer.RequestContext))) => {
return self.tagged_pointer.as(HTTPSServer.RequestContext).abort();
},
@field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPServer.RequestContext))) => {
return self.tagged_pointer.as(DebugHTTPServer.RequestContext).abort();
},
@field(Pointer.Tag, bun.meta.typeBaseName(@typeName(DebugHTTPSServer.RequestContext))) => {
return self.tagged_pointer.as(DebugHTTPSServer.RequestContext).abort();
},
else => @panic("Unexpected AnyRequestContext tag"),
}
}
pub fn setTimeout(self: AnyRequestContext, seconds: c_uint) bool {
if (self.tagged_pointer.isNull()) {
return false;
@@ -1934,7 +1957,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
fn handleResolve(ctx: *RequestContext, value: JSC.JSValue) void {
if (ctx.isAbortedOrEnded() or ctx.didUpgradeWebSocket()) {
if (ctx.isAbortedOrEnded() or ctx.didUpgradeWebSocket() or ctx.flags.user_called_abort) {
return;
}
@@ -1954,7 +1977,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctx.response_jsvalue = value;
assert(!ctx.flags.response_protected);
ctx.flags.response_protected = true;
JSC.C.JSValueProtect(ctx.server.?.globalThis, value.asObjectRef());
value.protect();
if (ctx.method == .HEAD) {
if (ctx.resp) |resp| {
@@ -1993,6 +2016,16 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return true;
}
pub fn abort(this: *RequestContext) bool {
if (this.isAbortedOrEnded() or this.flags.user_called_abort) return false;
if (this.resp) |resp| {
this.flags.user_called_abort = true;
resp.abort();
return true;
}
return false;
}
/// destroy RequestContext, should be only called by deref or if defer_deinit_until_callback_completes is ref is set to true
fn deinit(this: *RequestContext) void {
this.detachResponse();
@@ -2053,7 +2086,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
fn handleReject(ctx: *RequestContext, value: JSC.JSValue) void {
if (ctx.isAbortedOrEnded()) {
if (ctx.isAbortedOrEnded() or ctx.flags.user_called_abort) {
return;
}
@@ -2074,7 +2107,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
}
// check again in case it get aborted after runErrorHandler
if (ctx.isAbortedOrEnded()) {
if (ctx.isAbortedOrEnded() or ctx.flags.user_called_abort) {
return;
}
@@ -2322,21 +2355,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
assert(this.resp == resp);
assert(this.server != null);
var any_js_calls = false;
var vm = this.server.?.vm;
const globalThis = this.server.?.globalThis;
defer {
// This is a task in the event loop.
// If we called into JavaScript, we must drain the microtask queue
if (any_js_calls) {
vm.drainMicrotasks();
}
}
if (this.request_weakref.get()) |request| {
if (request.internal_event_callback.trigger(Request.InternalJSEventCallback.EventType.timeout, globalThis)) {
any_js_calls = true;
}
const globalThis = this.server.?.globalThis;
request.internal_event_callback.triggerAtTopOfEventLoop(Request.InternalJSEventCallback.EventType.timeout, globalThis);
}
}
@@ -2349,19 +2370,19 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.detachResponse();
var any_js_calls = false;
var vm = this.server.?.vm;
const vm: *JSC.VirtualMachine = this.server.?.vm;
const globalThis = this.server.?.globalThis;
const loop = vm.eventLoop();
defer {
// This is a task in the event loop.
// If we called into JavaScript, we must drain the microtask queue
if (any_js_calls) {
vm.drainMicrotasks();
loop.exit();
}
this.deref();
}
if (this.request_weakref.get()) |request| {
request.request_context = AnyRequestContext.Null;
loop.enter();
if (request.internal_event_callback.trigger(Request.InternalJSEventCallback.EventType.abort, globalThis)) {
any_js_calls = true;
}
@@ -2377,6 +2398,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
signal.unref();
}
if (!signal.aborted()) {
if (!any_js_calls) {
loop.enter();
}
signal.signal(globalThis, .ConnectionClosed);
any_js_calls = true;
}
@@ -2384,6 +2408,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
//if have sink, call onAborted on sink
if (this.sink) |wrapper| {
if (!any_js_calls) {
loop.enter();
}
wrapper.sink.abort();
return;
}
@@ -2392,6 +2419,10 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
if (this.isDeadRequest()) {
this.finalizeWithoutDeinit();
} else {
if (!any_js_calls) {
loop.enter();
}
if (this.endRequestStreaming()) {
any_js_calls = true;
}
@@ -2402,6 +2433,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
response.body.value.Locked.readable = .{};
defer strong_readable.deinit();
if (strong_readable.get()) |readable| {
if (!any_js_calls) {
loop.enter();
}
readable.abort(globalThis);
any_js_calls = true;
}
@@ -3248,6 +3282,11 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return;
}
// if the user called server.abort(request), we don't mind if they don't return a Response.
if (ctx.flags.user_called_abort) {
return;
}
if (response_value.isEmptyOrUndefinedOrNull()) {
ctx.renderMissingInvalidResponse(response_value);
return;
@@ -6008,6 +6047,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
pub const doFetch = onFetch;
pub const doRequestIP = JSC.wrapInstanceMethod(ThisServer, "requestIP", false);
pub const doTimeout = JSC.wrapInstanceMethod(ThisServer, "timeout", false);
pub const doAbort = JSC.wrapInstanceMethod(ThisServer, "abort", false);
pub fn getPlugins(
this: *ThisServer,
@@ -6100,6 +6140,10 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
JSValue.jsNull();
}
pub fn abort(_: *ThisServer, request: *JSC.WebCore.Request) JSC.JSValue {
return JSValue.jsBoolean(request.request_context.abort());
}
pub fn memoryCost(this: *ThisServer) usize {
return @sizeOf(ThisServer) +
this.base_url_string_for_joining.len +

View File

@@ -131,6 +131,14 @@ pub const Request = struct {
return this.function.has();
}
pub fn triggerAtTopOfEventLoop(this: *InternalJSEventCallback, eventType: EventType, globalThis: *JSC.JSGlobalObject) void {
if (this.function.get()) |callback| {
globalThis.bunVM().eventLoop().runCallback(callback, globalThis, .undefined, &.{JSC.JSValue.jsNumber(
@intFromEnum(eventType),
)});
}
}
pub fn trigger(this: *InternalJSEventCallback, eventType: EventType, globalThis: *JSC.JSGlobalObject) bool {
if (this.function.get()) |callback| {
_ = callback.call(globalThis, JSC.JSValue.jsUndefined(), &.{JSC.JSValue.jsNumber(

View File

@@ -1207,6 +1207,16 @@ extern "C"
}
}
void uws_res_abort(int ssl, uws_res_r res) {
if (ssl) {
uWS::HttpResponse<true> *uwsRes = (uWS::HttpResponse<true> *)res;
uwsRes->abort();
} else {
uWS::HttpResponse<false> *uwsRes = (uWS::HttpResponse<false> *)res;
uwsRes->abort();
}
}
void uws_res_end_without_body(int ssl, uws_res_r res, bool close_connection)
{
if (ssl)

View File

@@ -3705,6 +3705,10 @@ pub fn NewApp(comptime ssl: bool) type {
return uws_res_has_responded(ssl_flag, res.downcast());
}
pub fn abort(res: *Response) void {
uws_res_abort(ssl_flag, res.downcast());
}
pub fn getNativeHandle(res: *Response) bun.FileDescriptor {
if (comptime Environment.isWindows) {
// on windows uSockets exposes SOCKET
@@ -4613,3 +4617,5 @@ pub fn onThreadExit() void {
extern fn uws_app_clear_routes(ssl_flag: c_int, app: *uws_app_t) void;
pub extern fn us_socket_upgrade_to_tls(s: *Socket, new_context: *SocketContext, sni: ?[*:0]const u8) ?*Socket;
extern fn uws_res_abort(ssl_flag: c_int, res: *uws_res) void;

View File

@@ -0,0 +1,185 @@
import { describe, expect, test, mock } from "bun:test";
import { bunEnv, bunExe, rejectUnauthorizedScope, tempDirWithFiles, tls } from "harness";
describe("server.abort()", async () => {
test("after sleep", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
await Bun.sleep(0);
server.abort(request);
return new Response("Hello, world!");
},
});
expect(async () => {
const response = await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("before sleep", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
expect(server.abort(request)).toBe(true);
await Bun.sleep(0);
// calling it again should do nothing
expect(server.abort(request)).toBe(false);
return new Response("Hello, world!");
},
});
expect(async () => {
await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("slightly after response is returned", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
queueMicrotask(() => {
expect(server.abort(request)).toBe(true);
});
return new Response("hello!");
},
});
expect(async () => {
await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("after response was probably sent does nothing", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
setTimeout(() => {
expect(server.abort(request)).toBe(false);
}, 0);
return new Response("hello!");
},
});
const response = await fetch(`http://localhost:${server.port}`);
expect(response.status).toBe(200);
expect(await response.text()).toBe("hello!");
});
test("triggers AbortSignal", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
const fn = mock(() => {
// already aborted.
expect(server.abort(request)).toBe(false);
});
request.signal.addEventListener("abort", fn);
expect(server.abort(request)).toBe(true);
// you can return undefined and it should not trigger an uncaught exception
},
});
expect(async () => {
await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("triggers AbortSignal after sleep", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
const fn = mock(() => {
// already aborted.
expect(server.abort(request)).toBe(false);
});
request.signal.addEventListener("abort", fn);
await Bun.sleep(0);
expect(server.abort(request)).toBe(true);
// you can return undefined and it should not trigger an uncaught exception
},
});
expect(async () => {
await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("works inside of a ReadableStream on the original Request with sleep", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
return new Response(
new ReadableStream({
async pull(controller) {
await Bun.sleep(0);
server.abort(request);
controller.close();
},
}),
);
},
});
expect(async () => {
await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("works inside of a ReadableStream on the original Request without sleep", async () => {
using server = Bun.serve({
port: 0,
async fetch(request, server) {
return new Response(
new ReadableStream({
pull(controller) {
server.abort(request);
controller.close();
},
}),
);
},
});
expect(async () => {
await fetch(`http://localhost:${server.port}`);
}).toThrow("The socket connection was closed");
});
test("works inside of a ReadableStream on the original Request without sleep, with SSL", async () => {
using server = Bun.serve({
port: 0,
tls: tls,
async fetch(request, server) {
return new Response(
new ReadableStream({
pull(controller) {
server.abort(request);
controller.close();
},
}),
);
},
});
expect(async () => {
await fetch(`https://localhost:${server.port}`, {
tls: {
rejectUnauthorized: false,
},
});
}).toThrow("The socket connection was closed");
});
});