Compare commits

...

2 Commits

Author SHA1 Message Date
cirospaciari
8503bdee56 Apply formatting changes 2024-06-11 11:45:20 -03:00
cirospaciari
cbc0a96ead keep JS instead of native 2024-06-11 11:45:20 -03:00
3 changed files with 129 additions and 28 deletions

View File

@@ -1407,7 +1407,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
/// this prevents an extra pthread_getspecific() call which shows up in profiling
allocator: std.mem.Allocator,
req: *uws.Request,
signal: ?*JSC.WebCore.AbortSignal = null,
signal: JSC.Strong = .{},
method: HTTP.Method,
flags: NewFlags(debug_mode) = .{},
@@ -1446,6 +1446,13 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// TODO: support builtin compression
const can_sendfile = !ssl_enabled and !Environment.isWindows;
pub fn getSignal(this: *const RequestContext) ?*JSC.WebCore.AbortSignal {
if (this.signal.get()) |js_signal| {
return js_signal.as(JSC.WebCore.AbortSignal);
}
return null;
}
pub inline fn isAsync(this: *const RequestContext) bool {
return this.defer_deinit_until_callback_completes == null;
}
@@ -1837,6 +1844,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return true;
}
/// Check if we are aborted and signal the abort if we are
/// The signal ref will be cleaned up
/// Returns true if we are aborted
fn checkAndCleanAbortSignal(this: *RequestContext) bool {
if (this.getSignal()) |signal| {
var _signal = this.signal;
this.signal = .{};
defer _signal.deinit();
if (this.flags.aborted and !signal.aborted()) {
const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
return true;
}
}
return false;
}
pub fn onAbort(this: *RequestContext, resp: *App.Response) void {
assert(this.resp == resp);
assert(!this.flags.aborted);
@@ -1853,17 +1878,9 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
// if signal is not aborted, abort the signal
if (this.signal) |signal| {
this.signal = null;
if (!signal.aborted()) {
const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
any_js_calls = true;
}
_ = signal.unref();
if (this.checkAndCleanAbortSignal()) {
any_js_calls = true;
}
//if have sink, call onAborted on sink
if (this.sink) |wrapper| {
wrapper.sink.abort();
@@ -1943,15 +1960,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
// if signal is not aborted, abort the signal
if (this.signal) |signal| {
this.signal = null;
if (this.flags.aborted and !signal.aborted()) {
const reason = JSC.WebCore.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
}
_ = signal.unref();
}
_ = this.checkAndCleanAbortSignal();
if (this.request_body) |body| {
ctxLog("finalizeWithoutDeinit: request_body != null", .{});
@@ -6227,14 +6236,11 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable;
ctx.request_body = body;
var signal = JSC.WebCore.AbortSignal.new(this.globalThis);
ctx.signal = signal;
request_object.* = .{
.method = ctx.method,
.request_context = AnyRequestContext.init(ctx),
.https = ssl_enabled,
.signal = signal.ref(),
.signal = null,
.body = body.ref(),
};
@@ -6297,6 +6303,8 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
const request_value = args[0];
request_value.ensureStillAlive();
// keep a strong ref so we can signal when the request is aborted (We need to keep JS alive not only the native part)
ctx.signal = JSC.Strong.create(Request.getSignalFromJS(request_value, this.globalThis), this.globalThis);
const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
defer {
@@ -6345,15 +6353,12 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
var body = JSC.WebCore.InitRequestBodyValue(.{ .Null = {} }) catch unreachable;
ctx.request_body = body;
var signal = JSC.WebCore.AbortSignal.new(this.globalThis);
ctx.signal = signal;
request_object.* = .{
.method = ctx.method,
.request_context = AnyRequestContext.init(ctx),
.upgrader = ctx,
.https = ssl_enabled,
.signal = signal.ref(),
.signal = null,
.body = body.ref(),
};
ctx.upgrade_context = upgrade_ctx;
@@ -6365,6 +6370,9 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
};
const request_value = args[0];
request_value.ensureStillAlive();
// keep a strong ref so we can signal when the request is aborted (We need to keep JS alive not only the native part)
ctx.signal = JSC.Strong.create(Request.getSignalFromJS(request_value, this.globalThis), this.globalThis);
const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
defer {
if (!ctx.didUpgradeWebSocket()) {}

View File

@@ -98,6 +98,22 @@ pub const Request = struct {
}
}
/// Returns cached signal or generate a new JS signal and cache it.
pub fn getSignalFromJS(
jsRequest: JSC.JSValue,
globalThis: *JSC.JSGlobalObject,
) JSC.JSValue {
if (jsRequest.as(Request)) |request| {
if (Request.signalGetCached(jsRequest)) |js_signal| {
return js_signal;
}
const signal = request.getSignal(globalThis);
Request.signalSetCached(jsRequest, globalThis, signal);
return signal;
}
return .zero;
}
pub fn init(
url: bun.String,
headers: ?*FetchHeaders,

View File

@@ -1509,3 +1509,80 @@ it("should work with dispose keyword", async () => {
}
expect(fetch(url)).rejects.toThrow();
});
it("it should call abort when the request is aborted in the middle of a stream", async () => {
const { promise, resolve } = Promise.withResolvers();
const payload = Buffer.from("data: hello\n\n");
using server = Bun.serve({
port: 0,
fetch(req) {
let keepAlive = true;
req.signal.addEventListener("abort", () => {
keepAlive = false;
});
return new Response(
new ReadableStream({
async pull(controller) {
while (!req.signal.aborted) {
controller.enqueue(payload);
await Bun.sleep(10);
}
resolve(keepAlive);
},
}),
{
headers: {
"Cache-Control": "no-store",
"Content-Type": "text/event-stream",
Connection: "keep-alive",
},
},
);
},
});
const abortController = new AbortController();
const response = await fetch(server.url, { signal: abortController.signal });
expect(response.status).toBe(200);
abortController.abort();
expect(await promise).toBe(false);
});
it("it should call abort when the request is aborted in the middle of a stream using async fetch", async () => {
const { promise, resolve } = Promise.withResolvers();
const payload = Buffer.from("data: hello\n\n");
using server = Bun.serve({
port: 0,
async fetch(req) {
await Bun.sleep(10);
let keepAlive = true;
req.signal.addEventListener("abort", () => {
keepAlive = false;
});
return new Response(
new ReadableStream({
async pull(controller) {
while (!req.signal.aborted) {
controller.enqueue(payload);
await Bun.sleep(10);
}
resolve(keepAlive);
},
}),
{
headers: {
"Cache-Control": "no-store",
"Content-Type": "text/event-stream",
Connection: "keep-alive",
},
},
);
},
});
const abortController = new AbortController();
const response = await fetch(server.url, { signal: abortController.signal });
expect(response.status).toBe(200);
abortController.abort();
expect(await promise).toBe(false);
});