fix(server) fix abrupt stop (#12472)

Co-authored-by: Jarred Sumner <jarred@jarredsumner.com>
Co-authored-by: Jarred-Sumner <Jarred-Sumner@users.noreply.github.com>
Co-authored-by: cirospaciari <cirospaciari@users.noreply.github.com>
This commit is contained in:
Ciro Spaciari
2024-07-11 18:22:23 -07:00
committed by GitHub
parent 3ac9c3cc1c
commit 11f8d3cb24
5 changed files with 201 additions and 69 deletions

View File

@@ -580,7 +580,19 @@ public:
httpResponseData->onAborted = std::move(handler);
return this;
}
HttpResponse* clearOnWritableAndAborted() {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
httpResponseData->onWritable = nullptr;
httpResponseData->onAborted = nullptr;
return this;
}
HttpResponse* clearOnAborted() {
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
httpResponseData->onAborted = nullptr;
return this;
}
/* Attach a read handler for data sent. Will be called with FIN set true if last segment. */
void onData(MoveOnlyFunction<void(std::string_view, bool)> &&handler) {
HttpResponseData<SSL> *data = getHttpResponseData();

View File

@@ -1477,7 +1477,8 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
result.ensureStillAlive();
ctx.pending_promises_for_abort -|= 1;
if (ctx.flags.aborted) {
if (ctx.isAbortedOrEnded()) {
ctx.finalizeForAbort();
return JSValue.jsUndefined();
}
@@ -1552,27 +1553,35 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctx.pending_promises_for_abort -|= 1;
if (ctx.flags.aborted) {
ctx.finalizeForAbort();
return JSValue.jsUndefined();
}
handleReject(ctx, if (!err.isEmptyOrUndefinedOrNull()) err else JSC.JSValue.jsUndefined());
return JSValue.jsUndefined();
}
fn handleReject(ctx: *RequestContext, value: JSC.JSValue) void {
if (ctx.resp == null) {
if (ctx.isAbortedOrEnded()) {
ctx.finalizeForAbort();
return;
}
const resp = ctx.resp.?;
const has_responded = resp.hasResponded();
if (!has_responded)
if (!has_responded) {
const original_state = ctx.defer_deinit_until_callback_completes;
var should_deinit_context = false;
ctx.defer_deinit_until_callback_completes = &should_deinit_context;
ctx.runErrorHandler(
value,
);
ctx.defer_deinit_until_callback_completes = original_state;
// we try to deinit inside runErrorHandler so we just return here and let it deinit
if (should_deinit_context) {
ctx.deinit();
return;
}
}
if (ctx.flags.aborted) {
// check again in case it get aborted after runErrorHandler
if (ctx.isAbortedOrEnded()) {
ctx.finalizeForAbort();
return;
}
@@ -1735,7 +1744,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
resp.clearOnData();
}
resp.end(data, closeConnection);
this.resp = null;
this.detachResponse();
}
}
@@ -1752,8 +1761,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// We cannot call this function if the Content-Length header was previously set
if (resp.state().isResponsePending())
resp.endStream(closeConnection);
this.resp = null;
this.detachResponse();
}
}
@@ -1764,7 +1772,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
resp.clearOnData();
}
resp.endWithoutBody(closeConnection);
this.resp = null;
this.detachResponse();
}
}
@@ -1772,7 +1780,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctxLog("onWritableResponseBuffer", .{});
assert(this.resp == resp);
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return false;
}
@@ -1786,7 +1794,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctxLog("onWritableCompleteResponseBufferAndMetadata", .{});
assert(this.resp == resp);
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return false;
}
@@ -1807,7 +1815,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn onWritableCompleteResponseBuffer(this: *RequestContext, write_offset: u64, resp: *App.Response) callconv(.C) bool {
ctxLog("onWritableCompleteResponseBuffer", .{});
assert(this.resp == resp);
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return false;
}
@@ -1878,7 +1886,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
// if we can, free the request now.
if (this.isDeadRequest()) {
this.finalizeWithoutDeinit();
this.markComplete();
this.deinit();
} else {
this.pending_promises_for_abort = 0;
@@ -1922,11 +1929,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
}
pub fn markComplete(this: *RequestContext) void {
if (!this.flags.has_marked_complete) this.server.onRequestComplete();
this.flags.has_marked_complete = true;
}
// This function may be called multiple times
// so it's important that we can safely do that
pub fn finalizeWithoutDeinit(this: *RequestContext) void {
@@ -2008,15 +2010,22 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn finalize(this: *RequestContext) void {
ctxLog("finalize<d> ({*})<r>", .{this});
this.finalizeWithoutDeinit();
this.markComplete();
this.deinit();
}
pub fn deinit(this: *RequestContext) void {
ctxLog("deinit<d> ({*})<r>", .{this});
if (!this.isDeadRequest()) {
ctxLog("deinit<d> ({*})<r> waiting request", .{this});
return;
}
if (!this.flags.has_marked_complete) this.server.onRequestComplete();
this.flags.has_marked_complete = true;
this.detachResponse();
if (this.defer_deinit_until_callback_completes) |defer_deinit| {
defer_deinit.* = true;
ctxLog("deferred deinit <d> ({*})<r>", .{this});
@@ -2030,7 +2039,6 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
if (comptime Environment.allow_assert)
assert(this.flags.has_marked_complete);
var server = this.server;
this.request_body_buf.clearAndFree(this.allocator);
this.response_buf_owned.clearAndFree(this.allocator);
@@ -2039,7 +2047,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.request_body = null;
}
server.request_pool_allocator.put(this);
this.server.request_pool_allocator.put(this);
}
fn writeHeaders(
@@ -2087,7 +2095,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}};
pub fn onSendfile(this: *RequestContext) bool {
if (this.flags.aborted or this.resp == null) {
if (this.isAbortedOrEnded()) {
this.cleanupAndFinalizeAfterSendfile();
return false;
}
@@ -2107,7 +2115,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
this.sendfile.remain -|= @as(Blob.SizeType, @intCast(this.sendfile.offset -| start));
if (errcode != .SUCCESS or this.flags.aborted or this.sendfile.remain == 0 or val == 0) {
if (errcode != .SUCCESS or this.isAbortedOrEnded() or this.sendfile.remain == 0 or val == 0) {
if (errcode != .AGAIN and errcode != .SUCCESS and errcode != .PIPE and errcode != .NOTCONN) {
Output.prettyErrorln("Error: {s}", .{@tagName(errcode)});
Output.flush();
@@ -2129,7 +2137,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
const wrote = @as(Blob.SizeType, @intCast(sbytes));
this.sendfile.offset +|= wrote;
this.sendfile.remain -|= wrote;
if (errcode != .AGAIN or this.flags.aborted or this.sendfile.remain == 0 or sbytes == 0) {
if (errcode != .AGAIN or this.isAbortedOrEnded() or this.sendfile.remain == 0 or sbytes == 0) {
if (errcode != .AGAIN and errcode != .SUCCESS and errcode != .PIPE and errcode != .NOTCONN) {
Output.prettyErrorln("Error: {s}", .{@tagName(errcode)});
Output.flush();
@@ -2154,7 +2162,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn onWritableBytes(this: *RequestContext, write_offset: u64, resp: *App.Response) callconv(.C) bool {
ctxLog("onWritableBytes", .{});
assert(this.resp == resp);
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return false;
}
@@ -2288,7 +2296,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
.remain = this.blob.Blob.offset + original_size,
.offset = this.blob.Blob.offset,
.auto_close = auto_close,
.socket_fd = if (!this.flags.aborted) resp.getNativeHandle() else bun.invalid_fd,
.socket_fd = if (!this.isAbortedOrEnded()) resp.getNativeHandle() else bun.invalid_fd,
};
// if we are sending only part of a file, include the content-range header
@@ -2322,7 +2330,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
pub fn doSendfile(this: *RequestContext, blob: Blob) void {
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -2342,7 +2350,8 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn onReadFile(this: *RequestContext, result: Blob.ReadFile.ResultType) void {
this.flags.has_pending_read = false;
if (this.flags.aborted or this.resp == null) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -2395,7 +2404,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
fn renderWithBlobFromBodyValue(this: *RequestContext) void {
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -2422,7 +2431,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctxLog("doRenderStream", .{});
var this = pair.this;
var stream = pair.stream;
if (this.resp == null or this.flags.aborted) {
if (this.isAbortedOrEnded()) {
stream.cancel(this.server.globalThis);
this.readable_stream_ref.deinit();
this.finalizeForAbort();
@@ -2477,7 +2486,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
if (assignment_result.toError()) |err_value| {
streamLog("returned an error", .{});
if (!this.flags.aborted) resp.clearAborted();
if (!this.isAbortedOrEnded()) resp.clearAborted();
response_stream.detach();
this.sink = null;
response_stream.sink.destroy();
@@ -2485,7 +2494,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
if (resp.hasResponded()) {
if (!this.flags.aborted) resp.clearAborted();
if (!this.isAbortedOrEnded()) resp.clearAborted();
streamLog("done", .{});
response_stream.detach();
this.sink = null;
@@ -2553,7 +2562,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
} else {
// if is not a promise we treat it as Error
streamLog("returned an error", .{});
if (!this.flags.aborted) resp.clearAborted();
if (!this.isAbortedOrEnded()) resp.clearAborted();
response_stream.detach();
this.sink = null;
response_stream.sink.destroy();
@@ -2561,7 +2570,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
}
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
response_stream.detach();
stream.cancel(globalThis);
defer this.readable_stream_ref.deinit();
@@ -2635,6 +2644,15 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
ctx.setAbortHandler();
}
fn detachResponse(this: *RequestContext) void {
this.resp = null;
}
fn isAbortedOrEnded(this: *const RequestContext) bool {
// resp == null or aborted or server.stop(true)
return this.resp == null or this.flags.aborted or this.server.flags.terminated;
}
// Each HTTP request or TCP socket connection is effectively a "task".
//
// However, unlike the regular task queue, we don't drain the microtask
@@ -2659,7 +2677,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
response_value.ensureStillAlive();
ctx.drainMicrotasks();
if (ctx.flags.aborted) {
if (ctx.isAbortedOrEnded()) {
ctx.finalizeForAbort();
return;
}
@@ -2801,7 +2819,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
streamLog("onResolve({any})", .{wrote_anything});
//aborted so call finalizeForAbort
if (req.flags.aborted or req.resp == null) {
if (req.isAbortedOrEnded()) {
req.finalizeForAbort();
return;
}
@@ -2862,7 +2880,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
// aborted so call finalizeForAbort
if (req.flags.aborted) {
if (req.isAbortedOrEnded()) {
req.finalizeForAbort();
return;
}
@@ -2895,7 +2913,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
.Error => {
const err = value.Error;
_ = value.use();
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -2913,7 +2931,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
return;
},
.Locked => |*lock| {
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -3028,7 +3046,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
}
}
if (this.flags.aborted or this.resp == null) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -3076,7 +3094,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn doRender(this: *RequestContext) void {
ctxLog("doRender", .{});
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
this.finalizeForAbort();
return;
}
@@ -3433,7 +3451,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
assert(this.resp == resp);
this.flags.is_waiting_for_request_body = last == false;
if (this.flags.aborted or this.flags.has_marked_complete) return;
if (this.isAbortedOrEnded() or this.flags.has_marked_complete) return;
if (!last and chunk.len == 0) {
// Sometimes, we get back an empty chunk
// We have to ignore those chunks unless it's the last one
@@ -3532,7 +3550,7 @@ fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, comp
pub fn onStartStreamingRequestBody(this: *RequestContext) JSC.WebCore.DrainResult {
ctxLog("onStartStreamingRequestBody", .{});
if (this.flags.aborted) {
if (this.isAbortedOrEnded()) {
return JSC.WebCore.DrainResult{
.aborted = {},
};
@@ -5388,7 +5406,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
}
var upgrader = bun.cast(*RequestContext, request.upgrader.?);
if (upgrader.flags.aborted or upgrader.resp == null) {
if (upgrader.isAbortedOrEnded()) {
return JSC.jsBoolean(false);
}
@@ -6311,7 +6329,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
// uWS request will not live longer than this function
request_object.request_context = JSC.API.AnyRequestContext.Null;
}
const original_state = ctx.defer_deinit_until_callback_completes;
var should_deinit_context = false;
ctx.defer_deinit_until_callback_completes = &should_deinit_context;
ctx.onResponse(
@@ -6321,7 +6339,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
request_value,
response_value,
);
ctx.defer_deinit_until_callback_completes = null;
ctx.defer_deinit_until_callback_completes = original_state;
if (should_deinit_context) {
request_object.request_context = JSC.API.AnyRequestContext.Null;
@@ -6380,6 +6398,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
request_object.request_context = JSC.API.AnyRequestContext.Null;
}
const original_state = ctx.defer_deinit_until_callback_completes;
var should_deinit_context = false;
ctx.defer_deinit_until_callback_completes = &should_deinit_context;
ctx.onResponse(
@@ -6389,7 +6408,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
request_value,
response_value,
);
ctx.defer_deinit_until_callback_completes = null;
ctx.defer_deinit_until_callback_completes = original_state;
if (should_deinit_context) {
request_object.request_context = JSC.API.AnyRequestContext.Null;

View File

@@ -1021,15 +1021,13 @@ extern "C"
if (ssl)
{
uWS::HttpResponse<true> *uwsRes = (uWS::HttpResponse<true> *)res;
uwsRes->getHttpResponseData()->onWritable = nullptr;
uwsRes->onAborted(nullptr);
uwsRes->clearOnWritableAndAborted();
uwsRes->end(std::string_view(data, length), close_connection);
}
else
{
uWS::HttpResponse<false> *uwsRes = (uWS::HttpResponse<false> *)res;
uwsRes->getHttpResponseData()->onWritable = nullptr;
uwsRes->onAborted(nullptr);
uwsRes->clearOnWritableAndAborted();
uwsRes->end(std::string_view(data, length), close_connection);
}
}
@@ -1039,15 +1037,13 @@ extern "C"
if (ssl)
{
uWS::HttpResponse<true> *uwsRes = (uWS::HttpResponse<true> *)res;
uwsRes->getHttpResponseData()->onWritable = nullptr;
uwsRes->onAborted(nullptr);
uwsRes->clearOnWritableAndAborted();
uwsRes->sendTerminatingChunk(close_connection);
}
else
{
uWS::HttpResponse<false> *uwsRes = (uWS::HttpResponse<false> *)res;
uwsRes->getHttpResponseData()->onWritable = nullptr;
uwsRes->onAborted(nullptr);
uwsRes->clearOnWritableAndAborted();
uwsRes->sendTerminatingChunk(close_connection);
}
}
@@ -1264,7 +1260,7 @@ extern "C"
}
else
{
uwsRes->onAborted(nullptr);
uwsRes->clearOnAborted();
}
}
else
@@ -1278,7 +1274,7 @@ extern "C"
}
else
{
uwsRes->onAborted(nullptr);
uwsRes->clearOnAborted();
}
}
}
@@ -1561,8 +1557,7 @@ extern "C"
uWS::HttpResponse<true> *uwsRes = (uWS::HttpResponse<true> *)res;
auto pair = uwsRes->tryEnd(std::string_view(bytes, len), total_len, close);
if (pair.first) {
uwsRes->getHttpResponseData()->onWritable = nullptr;
uwsRes->onAborted(nullptr);
uwsRes->clearOnWritableAndAborted();
}
return pair.first;
@@ -1572,8 +1567,7 @@ extern "C"
uWS::HttpResponse<false> *uwsRes = (uWS::HttpResponse<false> *)res;
auto pair = uwsRes->tryEnd(std::string_view(bytes, len), total_len, close);
if (pair.first) {
uwsRes->getHttpResponseData()->onWritable = nullptr;
uwsRes->onAborted(nullptr);
uwsRes->clearOnWritableAndAborted();
}
return pair.first;

View File

@@ -1,8 +1,8 @@
const s = Bun.serve({
using s = Bun.serve({
fetch(req, res) {
s.stop(true);
throw new Error("1");
},
port: 0,
});
fetch(`http://${s.hostname}:${s.port}`).then(res => console.log(res.status));
await fetch(`http://${s.hostname}:${s.port}`).then(res => console.log(res.status));

View File

@@ -47,6 +47,82 @@ afterAll(() => {
}
});
it("should be able to abruptly stop the server many times", async () => {
async function run() {
const stopped = Promise.withResolvers();
const server = Bun.serve({
port: 0,
error() {
return new Response("Error", { status: 500 });
},
async fetch(req, server) {
await Bun.sleep(50);
server.stop(true);
await Bun.sleep(50);
server = undefined;
if (stopped.resolve) {
stopped.resolve();
stopped.resolve = undefined;
}
return new Response("Hello, World!");
},
});
const url = server.url;
async function request() {
try {
await fetch(url, { keepalive: true }).then(res => res.text());
expect.unreachable();
} catch (e) {
expect(e.code).toBe("ConnectionClosed");
}
}
const requests = new Array(20);
for (let i = 0; i < 20; i++) {
requests[i] = request();
}
await Promise.all(requests);
await stopped.promise;
Bun.gc(true);
}
const runs = new Array(10);
for (let i = 0; i < 10; i++) {
runs[i] = run();
}
await Promise.all(runs);
Bun.gc(true);
});
// This test reproduces a crash in Bun v1.1.18 and earlier
it("should be able to abruptly stop the server", async () => {
for (let i = 0; i < 2; i++) {
const controller = new AbortController();
using server = Bun.serve({
port: 0,
error() {
return new Response("Error", { status: 500 });
},
async fetch(req, server) {
server.stop(true);
await Bun.sleep(10);
return new Response();
},
});
await fetch(server.url, {
signal: controller.signal,
})
.then(res => {
return res.blob();
})
.catch(() => {});
}
});
describe("1000 uploads & downloads in batches of 64 do not leak ReadableStream", () => {
for (let isDirect of [true, false] as const) {
it(
@@ -1096,6 +1172,7 @@ describe("should support Content-Range with Bun.file()", () => {
});
it("formats error responses correctly", async () => {
const { promise, resolve, reject } = Promise.withResolvers();
const c = spawn(bunExe(), ["./error-response.js"], { cwd: import.meta.dir, env: bunEnv });
var output = "";
@@ -1103,9 +1180,16 @@ it("formats error responses correctly", async () => {
output += chunk.toString();
});
c.stderr.on("end", () => {
expect(output).toContain('throw new Error("1");');
c.kill();
try {
expect(output).toContain('throw new Error("1");');
resolve();
} catch (e) {
reject(e);
} finally {
c.kill();
}
});
await promise;
});
it("request body and signal life cycle", async () => {
@@ -1542,3 +1626,26 @@ it("should be able to stop in the middle of a file response", async () => {
process.kill();
}
}, 60_000);
it("should be able to abrupt stop the server", async () => {
for (let i = 0; i < 10; i++) {
using server = Bun.serve({
port: 0,
error() {
return new Response("Error", { status: 500 });
},
async fetch(req, server) {
server.stop(true);
await Bun.sleep(100);
return new Response("Hello, World!");
},
});
try {
await fetch(server.url).then(res => res.text());
expect.unreachable();
} catch (e) {
expect(e.code).toBe("ConnectionClosed");
}
}
});