fix bun server segfault with abortsignal (#2261)

* removed redundant tests, fixed server segfault

* fix onRejectStream, safer unassign signal

* fix abort Bun.serve signal.addEventListener on async

* move ctx.signal null check up

* keep original behavior of streams onAborted
This commit is contained in:
Ciro Spaciari
2023-03-02 02:40:11 -03:00
committed by GitHub
parent b9137dbdc8
commit 1be834b073
3 changed files with 79 additions and 133 deletions

View File

@@ -650,6 +650,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
/// this prevents an extra pthread_getspecific() call which shows up in profiling
allocator: std.mem.Allocator,
req: *uws.Request,
signal: ?*JSC.AbortSignal = null,
method: HTTP.Method,
aborted: bool = false,
finalized: bun.DebugOnly(bool) = bun.DebugOnlyDefault(false),
@@ -698,11 +699,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn onResolve(_: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue {
ctxLog("onResolve", .{});
const arguments = callframe.arguments(2);
var ctx = arguments.ptr[1].asPromisePtr(@This());
const result = arguments.ptr[0];
result.ensureStillAlive();
if (ctx.request_js_object != null and ctx.signal == null) {
var request_js = ctx.request_js_object.?.value();
request_js.ensureStillAlive();
if (request_js.as(Request)) |request_object| {
if (request_object.signal) |signal| {
ctx.signal = signal;
_ = signal.ref();
}
}
}
ctx.pending_promises_for_abort -|= 1;
if (ctx.aborted) {
ctx.finalizeForAbort();
@@ -745,10 +759,23 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn onReject(_: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) callconv(.C) JSValue {
ctxLog("onReject", .{});
const arguments = callframe.arguments(2);
var ctx = arguments.ptr[1].asPromisePtr(@This());
const err = arguments.ptr[0];
if (ctx.request_js_object != null and ctx.signal == null) {
var request_js = ctx.request_js_object.?.value();
request_js.ensureStillAlive();
if (request_js.as(Request)) |request_object| {
if (request_object.signal) |signal| {
ctx.signal = signal;
_ = signal.ref();
}
}
}
ctx.pending_promises_for_abort -|= 1;
if (ctx.aborted) {
@@ -992,13 +1019,24 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
std.debug.assert(!this.aborted);
//mark request as aborted
this.aborted = true;
// if signal is not aborted, abort the signal
if (this.signal) |signal| {
this.signal = null;
if (!signal.aborted()) {
const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
}
_ = signal.unref();
}
//if have sink, call onAborted on sink
if (this.sink) |wrapper| {
wrapper.detach();
wrapper.sink.onAborted(resp);
this.sink = null;
wrapper.sink.destroy();
this.finalizeForAbort();
return;
}
@@ -1022,7 +1060,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// User called .blob(), .json(), text(), or .arrayBuffer() on the Request object
// but we received nothing or the connection was aborted
if (request_js.as(Request)) |req| {
this._signalAbort(req);
// the promise is pending
if (req.body == .Locked and (req.body.Locked.action != .none or req.body.Locked.promise != null)) {
@@ -1059,20 +1096,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
}
pub fn _signalAbort(this: *RequestContext, req: *Request) void {
//only call when actually aborted
if (!this.aborted) return;
//check if have a valid signal
if (req.signal) |signal| {
// if signal is not aborted, abort the signal
if (!signal.aborted()) {
const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
}
}
}
pub fn markComplete(this: *RequestContext) void {
if (!this.has_marked_complete) this.server.onRequestComplete();
this.has_marked_complete = true;
@@ -1098,6 +1121,17 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.response_jsvalue = JSC.JSValue.zero;
}
// if signal is not aborted, abort the signal
if (this.signal) |signal| {
this.signal = null;
if (this.aborted and !signal.aborted()) {
const reason = JSC.AbortSignal.createAbortError(JSC.ZigString.static("The user aborted a request"), &JSC.ZigString.Empty, this.server.globalThis);
reason.ensureStillAlive();
_ = signal.signal(reason);
}
_ = signal.unref();
}
if (this.request_js_object != null) {
ctxLog("finalizeWithoutDeinit: request_js_object != null", .{});
@@ -1110,7 +1144,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// User called .blob(), .json(), text(), or .arrayBuffer() on the Request object
// but we received nothing or the connection was aborted
if (request_js.as(Request)) |req| {
this._signalAbort(req);
// the promise is pending
if (req.body == .Locked and req.body.Locked.action != .none and req.body.Locked.promise != null) {
req.body.toErrorInstance(JSC.toTypeError(.ABORT_ERR, "Request aborted", .{}, this.server.globalThis), this.server.globalThis);
@@ -1734,6 +1767,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
switch (promise.status(vm.global.vm())) {
.Pending => {},
.Fulfilled => {
if (ctx.signal == null) {
if (request_object.signal) |signal| {
ctx.signal = signal;
_ = signal.ref();
}
}
const fulfilled_value = promise.result(vm.global.vm());
// if you return a Response object or a Promise<Response>
@@ -1776,6 +1815,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return;
},
.Rejected => {
if (ctx.signal == null) {
if (request_object.signal) |signal| {
ctx.signal = signal;
_ = signal.ref();
}
}
ctx.handleReject(promise.result(vm.global.vm()));
return;
},
@@ -1816,8 +1861,11 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn handleResolveStream(req: *RequestContext) void {
streamLog("handleResolveStream", .{});
//aborted already called finalizeForAbort at this stage
if (req.aborted) return;
//aborted so call finalizeForAbort
if (req.aborted) {
req.finalizeForAbort();
return;
}
var wrote_anything = false;
if (req.sink) |wrapper| {
@@ -1869,9 +1917,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn handleRejectStream(req: *@This(), globalThis: *JSC.JSGlobalObject, err: JSValue) void {
//aborted already called finalizeForAbort at this stage
if (req.aborted) return;
streamLog("handleRejectStream", .{});
var wrote_anything = req.has_written_status;
@@ -1895,6 +1940,12 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
streamLog("onReject({any})", .{wrote_anything});
//aborted so call finalizeForAbort
if (req.aborted) {
req.finalizeForAbort();
return;
}
if (!err.isEmptyOrUndefinedOrNull() and !wrote_anything) {
req.response_jsvalue.unprotect();
req.response_jsvalue = JSValue.zero;
@@ -4696,8 +4747,12 @@ pub fn NewServer(comptime ssl_enabled_: bool, comptime debug_mode_: bool) type {
ctx.request_js_object = args[0].asObjectRef();
const request_value = args[0];
request_value.ensureStillAlive();
const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
const response_value = this.config.onRequest.callWithThis(this.globalThis, this.thisObject, &args);
if (request_object.signal) |signal| {
ctx.signal = signal;
_ = signal.ref();
}
ctx.onResponse(
this,
req,

View File

@@ -2760,9 +2760,9 @@ pub fn HTTPServerWritable(comptime ssl: bool) type {
pub fn onAborted(this: *@This(), _: *UWSResponse) void {
log("onAborted()", .{});
this.signal.close(null);
this.done = true;
this.aborted = true;
this.signal.close(null);
this.flushPromise();
this.finalize();
}

View File

@@ -27,115 +27,6 @@ afterEach(() => {
const payload = new Uint8Array(1024 * 1024 * 2);
crypto.getRandomValues(payload);
describe("AbortSignalStreamTest", async () => {
async function abortOnStage(body, stage) {
let error = undefined;
var abortController = new AbortController();
{
const server = getServer({
async fetch(request) {
let chunk_count = 0;
const reader = request.body.getReader();
return Response(
new ReadableStream({
async pull(controller) {
while (true) {
chunk_count++;
const { done, value } = await reader.read();
if (chunk_count == stage) {
abortController.abort();
}
if (done) {
controller.close();
return;
}
controller.enqueue(value);
}
},
}),
);
},
});
try {
const signal = abortController.signal;
await fetch(`http://127.0.0.1:${server.port}`, { method: "POST", body, signal: signal }).then(res =>
res.arrayBuffer(),
);
} catch (ex) {
error = ex;
}
expect(error.name).toBe("AbortError");
expect(error.message).toBe("The operation was aborted.");
expect(error instanceof DOMException).toBeTruthy();
}
}
for (let i = 1; i < 7; i++) {
it(`Abort after ${i} chunks`, async () => {
await abortOnStage(payload, i);
});
}
});
describe("AbortSignalDirectStreamTest", () => {
async function abortOnStage(body, stage) {
let error = undefined;
var abortController = new AbortController();
{
const server = getServer({
async fetch(request) {
let chunk_count = 0;
const reader = request.body.getReader();
return Response(
new ReadableStream({
type: "direct",
async pull(controller) {
while (true) {
chunk_count++;
const { done, value } = await reader.read();
if (chunk_count == stage) {
abortController.abort();
}
if (done) {
controller.end();
return;
}
controller.write(value);
}
},
}),
);
},
});
try {
const signal = abortController.signal;
await fetch(`http://127.0.0.1:${server.port}`, { method: "POST", body, signal: signal }).then(res =>
res.arrayBuffer(),
);
} catch (ex) {
error = ex;
}
expect(error.name).toBe("AbortError");
expect(error.message).toBe("The operation was aborted.");
expect(error instanceof DOMException).toBeTruthy();
}
}
for (let i = 1; i < 7; i++) {
it(`Abort after ${i} chunks`, async () => {
await abortOnStage(payload, i);
});
}
});
describe("AbortSignal", () => {
var server;
beforeEach(() => {