diff --git a/packages/bun-uws/src/AsyncSocket.h b/packages/bun-uws/src/AsyncSocket.h
index f61b23a5de..941cccd668 100644
--- a/packages/bun-uws/src/AsyncSocket.h
+++ b/packages/bun-uws/src/AsyncSocket.h
@@ -182,9 +182,9 @@ public:
     }
 
     /* Returns the user space backpressure. */
-    unsigned int getBufferedAmount() {
+    size_t getBufferedAmount() {
         /* We return the actual amount of bytes in backbuffer, including pendingRemoval */
-        return (unsigned int) getAsyncSocketData()->buffer.totalLength();
+        return getAsyncSocketData()->buffer.totalLength();
     }
 
     /* Returns the text representation of an IPv4 or IPv6 address */
@@ -222,6 +222,63 @@ public:
         return addressAsText(getRemoteAddress());
     }
 
+    /**
+     * Flushes the socket buffer by writing as much data as possible to the underlying socket.
+     *
+     * @return The total number of bytes successfully written to the socket
+     */
+    size_t flush() {
+        /* Check if socket is valid for operations */
+        if (us_socket_is_closed(SSL, (us_socket_t *) this)) {
+            /* Socket is closed, no flushing is possible */
+            return 0;
+        }
+
+        /* Get the associated asynchronous socket data structure */
+        AsyncSocketData<SSL> *asyncSocketData = getAsyncSocketData();
+        size_t total_written = 0;
+
+        /* Continue flushing as long as we have data in the buffer */
+        while (asyncSocketData->buffer.length()) {
+            /* Get current buffer size */
+            size_t buffer_len = asyncSocketData->buffer.length();
+
+            /* Limit write size to INT_MAX as the underlying socket API uses int for length */
+            int max_flush_len = std::min(buffer_len, (size_t)INT_MAX);
+
+            /* Attempt to write data to the socket */
+            int written = us_socket_write(SSL, (us_socket_t *) this, asyncSocketData->buffer.data(), max_flush_len, 0);
+            total_written += written;
+
+            /* Check if we couldn't write the entire buffer */
+            if ((unsigned int) written < buffer_len) {
+                /* Remove the successfully written data from the buffer */
+                asyncSocketData->buffer.erase((unsigned int) written);
+
+                /* If we wrote less than we attempted, the socket buffer is likely full.
+                 * [[likely]] is used as an optimization hint to the compiler,
+                 * since written < buffer_len is very likely to be true.
+                 */
+                if (written < max_flush_len) {
+                    [[likely]]
+                    /* Cannot write more at this time, return what we've written so far */
+                    return total_written;
+                }
+                /* If we wrote exactly max_flush_len, we might be able to write more, so continue.
+                 * This is unlikely to happen, because it would mean INT_MAX bytes were written in one go,
+                 * but we keep this check for completeness.
+                 */
+                continue;
+            }
+
+            /* Successfully wrote the entire buffer, clear the buffer */
+            asyncSocketData->buffer.clear();
+        }
+
+        /* Return the total number of bytes written during this flush operation */
+        return total_written;
+    }
+
     /* Write in three levels of prioritization: cork-buffer, syscall, socket-buffer. Always drain if possible.
      * Returns pair of bytes written (anywhere) and wheter or not this call resulted in the polling for
      * writable (or we are in a state that implies polling for writable).
     */
@@ -233,7 +290,6 @@ public:
         LoopData *loopData = getLoopData();
         AsyncSocketData<SSL> *asyncSocketData = getAsyncSocketData();
 
-        /* We are limited if we have a per-socket buffer */
         if (asyncSocketData->buffer.length()) {
             size_t buffer_len = asyncSocketData->buffer.length();
@@ -261,7 +317,7 @@ public:
             asyncSocketData->buffer.clear();
         }
 
-        if (length) {
+        if (length) {
            if (loopData->isCorkedWith(this)) {
                /* We are corked */
                if (LoopData::CORK_BUFFER_SIZE - loopData->getCorkOffset() >= (unsigned int) length) {
diff --git a/packages/bun-uws/src/HttpContext.h b/packages/bun-uws/src/HttpContext.h
index dcc16b641a..67cd550a3e 100644
--- a/packages/bun-uws/src/HttpContext.h
+++ b/packages/bun-uws/src/HttpContext.h
@@ -365,11 +365,32 @@ private:
         auto *asyncSocket = reinterpret_cast<AsyncSocket<SSL> *>(s);
         auto *httpResponseData = reinterpret_cast<HttpResponseData<SSL> *>(asyncSocket->getAsyncSocketData());
 
+        /* Attempt to drain the socket buffer before triggering onWritable callback */
+        size_t bufferedAmount = asyncSocket->getBufferedAmount();
+        if (bufferedAmount > 0) {
+            /* Try to flush pending data from the socket's buffer to the network */
+            bufferedAmount -= asyncSocket->flush();
+
+            /* Check if there's still data waiting to be sent after flush attempt */
+            if (bufferedAmount > 0) {
+                /* Socket buffer is not completely empty yet
+                 * - Reset the timeout to prevent premature connection closure
+                 * - This allows time for another writable event or new request
+                 * - Return the socket to indicate we're still processing
+                 */
+                reinterpret_cast<HttpResponse<SSL> *>(s)->resetTimeout();
+                return s;
+            }
+            /* If bufferedAmount is now 0, we've successfully flushed everything
+             * and will fall through to the next section of code
+             */
+        }
+
         /* Ask the developer to write data and return success (true) or failure (false), OR skip sending anything and return success (true). */
         if (httpResponseData->onWritable) {
             /* We are now writable, so hang timeout again, the user does not have to do anything so we should hang until end or tryEnd rearms timeout */
             us_socket_timeout(SSL, s, 0);
-
+
             /* We expect the developer to return whether or not write was successful (true).
              * If write was never called, the developer should still return true so that we may drain. */
             bool success = httpResponseData->callOnWritable(reinterpret_cast<HttpResponse<SSL> *>(asyncSocket), httpResponseData->offset);
@@ -384,7 +405,7 @@ private:
        }
 
        /* Drain any socket buffer, this might empty our backpressure and thus finish the request */
-       /*auto [written, failed] = */asyncSocket->write(nullptr, 0, true, 0);
+       asyncSocket->flush();
 
        /* Should we close this connection after a response - and is this response really done? */
        if (httpResponseData->state & HttpResponseData<SSL>::HTTP_CONNECTION_CLOSE) {
diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h
index 7524cf2324..b0a6651fe8 100644
--- a/packages/bun-uws/src/HttpResponse.h
+++ b/packages/bun-uws/src/HttpResponse.h
@@ -122,15 +122,10 @@ public:
         /* We do not have tryWrite-like functionalities, so ignore optional in this path */
 
-        /* Do not allow sending 0 chunk here */
-        if (data.length()) {
-            Super::write("\r\n", 2);
-            writeUnsignedHex((unsigned int) data.length());
-            Super::write("\r\n", 2);
-
-            /* Ignoring optional for now */
-            Super::write(data.data(), (int) data.length());
-        }
+
+        /* Write the chunked data if there is any (this will not send zero chunks) */
+        this->write(data, nullptr);
+
         /* Terminating 0 chunk */
         Super::write("\r\n0\r\n\r\n", 7);
@@ -480,6 +475,40 @@ public:
             return true;
         }
 
+        size_t length = data.length();
+
+        // Special handling for extremely large data (greater than UINT_MAX bytes)
+        // most clients expect at most UINT_MAX, so we need to split the write into multiple writes
+        if (length > UINT_MAX) {
+            bool has_failed = false;
+            size_t total_written = 0;
+            // Process full-sized chunks until remaining data is less than UINT_MAX
+            while (length > UINT_MAX) {
+                size_t written = 0;
+                // Write a UINT_MAX-sized chunk and check for failure
+                // even after failure we continue writing because the data will be buffered
+                if (!this->write(data.substr(0, UINT_MAX), &written)) {
+                    has_failed = true;
+                }
+                total_written += written;
+                length -= UINT_MAX;
+                data = data.substr(UINT_MAX);
+            }
+            // Handle the final chunk (less than UINT_MAX bytes)
+            if (length > 0) {
+                size_t written = 0;
+                if (!this->write(data, &written)) {
+                    has_failed = true;
+                }
+                total_written += written;
+            }
+            if (writtenPtr) {
+                *writtenPtr = total_written;
+            }
+            return !has_failed;
+        }
+
+
         HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
 
         if (!(httpResponseData->state & HttpResponseData<SSL>::HTTP_WROTE_CONTENT_LENGTH_HEADER) && !httpResponseData->fromAncientRequest) {
@@ -499,17 +528,36 @@ public:
             Super::write("\r\n", 2);
             httpResponseData->state |= HttpResponseData<SSL>::HTTP_WRITE_CALLED;
         }
+        size_t total_written = 0;
+        bool has_failed = false;
 
-        auto [written, failed] = Super::write(data.data(), (int) data.length());
+        // Handle data larger than INT_MAX by writing it in chunks of INT_MAX bytes
+        while (length > INT_MAX) {
+            // Write the maximum allowed chunk size (INT_MAX)
+            auto [written, failed] = Super::write(data.data(), INT_MAX);
+            // If the write failed, set the has_failed flag; we continue writing because the data will be buffered
+            has_failed = has_failed || failed;
+            total_written += written;
+            length -= INT_MAX;
+            data = data.substr(INT_MAX);
+        }
+        // Handle the remaining data (less than INT_MAX bytes)
+        if (length > 0) {
+            // Write the final chunk with exact remaining length
+            auto [written, failed] = Super::write(data.data(), (int) length);
+            has_failed = has_failed || failed;
+            total_written += written;
+        }
+
         /* Reset timeout on each sended chunk */
         this->resetTimeout();
 
         if (writtenPtr) {
-            *writtenPtr = written;
+            *writtenPtr = total_written;
         }
 
         /* If we did not fail the write, accept more */
-        return !failed;
+        return !has_failed;
     }
 
     /* Get the current byte write offset for this Http response */
diff --git a/packages/bun-uws/src/WebSocketContext.h b/packages/bun-uws/src/WebSocketContext.h
index 25c6e216ac..16d8092fb0 100644
--- a/packages/bun-uws/src/WebSocketContext.h
+++ b/packages/bun-uws/src/WebSocketContext.h
@@ -339,7 +339,7 @@ private:
             /* We store old backpressure since it is unclear whether write drained anything,
              * however, in case of coming here with 0 backpressure we still need to emit drain event */
-            unsigned int backpressure = asyncSocket->getBufferedAmount();
+            size_t backpressure = asyncSocket->getBufferedAmount();
 
             /* Drain as much as possible */
             asyncSocket->write(nullptr, 0);
diff --git a/src/bun.js/api/server/NodeHTTPResponse.zig b/src/bun.js/api/server/NodeHTTPResponse.zig
index 7d61a6d960..2b6d58cb87 100644
--- a/src/bun.js/api/server/NodeHTTPResponse.zig
+++ b/src/bun.js/api/server/NodeHTTPResponse.zig
@@ -747,6 +747,7 @@ fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool
    defer this.deref();
    response.clearOnWritable();
    if (this.socket_closed or this.request_has_completed) {
+        // return false means we don't have anything to drain
        return false;
    }
    const on_writable = this.onWritableCallback.trySwap() orelse return false;
@@ -754,10 +755,7 @@ fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool
    const vm = globalThis.bunVM();
 
    response.corked(JSC.EventLoop.runCallback, .{ vm.eventLoop(), on_writable, globalThis, .undefined, &.{JSC.JSValue.jsNumberFromUint64(offset)} });
-    if (this.socket_closed or this.request_has_completed) {
-        return false;
-    }
-
+    // return true means we may have something to drain
    return true;
 }
@@ -868,7 +866,8 @@ fn writeOrEnd(
                this.onWritableCallback.set(globalObject, callback_value.withAsyncContextIfNeeded(globalObject));
                this.raw_response.onWritable(*NodeHTTPResponse, onDrain, this);
            }
-            return JSC.JSValue.jsNumberFromInt64(-@as(i64, @intCast(written)));
+
+            return JSC.JSValue.jsNumberFromInt64(-@as(i64, @intCast(@min(written, std.math.maxInt(i64)))));
        },
    }
 }
diff --git a/src/deps/libuwsockets.cpp b/src/deps/libuwsockets.cpp
index 910ab15e47..c33bd399ca 100644
--- a/src/deps/libuwsockets.cpp
+++ b/src/deps/libuwsockets.cpp
@@ -1018,7 +1018,7 @@ extern "C"
                            (uWS::OpCode)(unsigned char)opcode, compress);
  }
 
-  unsigned int uws_ws_get_buffered_amount(int ssl, uws_websocket_t *ws)
+  size_t uws_ws_get_buffered_amount(int ssl, uws_websocket_t *ws)
  {
    if (ssl)
    {
diff --git a/src/deps/uws.zig b/src/deps/uws.zig
index c503ce1877..eba9e59d43 100644
--- a/src/deps/uws.zig
+++ b/src/deps/uws.zig
@@ -2807,7 +2807,7 @@ pub const AnyWebSocket = union(enum) {
            compress,
        );
    }
-    pub fn getBufferedAmount(this: AnyWebSocket) u32 {
+    pub fn getBufferedAmount(this: AnyWebSocket) usize {
        return switch (this) {
            .ssl => uws_ws_get_buffered_amount(1, this.ssl.raw()),
            .tcp => uws_ws_get_buffered_amount(0, this.tcp.raw()),
@@ -4014,7 +4014,7 @@ extern fn uws_ws_is_subscribed(ssl: i32, ws: ?*RawWebSocket, topic: [*c]const u8
 extern fn uws_ws_iterate_topics(ssl: i32, ws: ?*RawWebSocket, callback: ?*const fn ([*c]const u8, usize, ?*anyopaque) callconv(.C) void, user_data: ?*anyopaque) void;
 extern fn uws_ws_publish(ssl: i32, ws: ?*RawWebSocket, topic: [*c]const u8, topic_length: usize, message: [*c]const u8, message_length: usize) bool;
 extern fn uws_ws_publish_with_options(ssl: i32, ws: ?*RawWebSocket, topic: [*c]const u8, topic_length: usize, message: [*c]const u8, message_length: usize, opcode: Opcode, compress: bool) bool;
-extern fn uws_ws_get_buffered_amount(ssl: i32, ws: ?*RawWebSocket) c_uint;
+extern fn uws_ws_get_buffered_amount(ssl: i32, ws: ?*RawWebSocket) usize;
 extern fn uws_ws_get_remote_address(ssl: i32, ws: ?*RawWebSocket, dest: *[*]u8) usize;
 extern fn uws_ws_get_remote_address_as_text(ssl: i32, ws: ?*RawWebSocket, dest: *[*]u8) usize;
 extern fn uws_res_get_remote_address_info(res: *uws_res, dest: *[*]const u8, port: *i32, is_ipv6: *bool) usize;
diff --git a/src/js/node/http.ts b/src/js/node/http.ts
index d6c287dd20..41cf358953 100644
--- a/src/js/node/http.ts
+++ b/src/js/node/http.ts
@@ -1922,6 +1922,10 @@ function callWriteHeadIfObservable(self, headerState) {
  }
 }
 
+function allowWritesToContinue() {
+  this._callPendingCallbacks();
+  this.emit("drain");
+}
 const ServerResponsePrototype = {
  constructor: ServerResponse,
  __proto__: OutgoingMessage.prototype,
@@ -2119,11 +2123,10 @@
        // If handle.writeHead throws, we don't want headersSent to be set to true.
        // So we set it here.
        this[headerStateSymbol] = NodeHTTPHeaderState.sent;
-
-        result = handle.write(chunk, encoding);
+        result = handle.write(chunk, encoding, allowWritesToContinue.bind(this));
      });
    } else {
-      result = handle.write(chunk, encoding);
+      result = handle.write(chunk, encoding, allowWritesToContinue.bind(this));
    }
 
    if (result < 0) {
diff --git a/test/js/node/http/node-http-backpressure-max.test.ts b/test/js/node/http/node-http-backpressure-max.test.ts
new file mode 100644
index 0000000000..143d545b55
--- /dev/null
+++ b/test/js/node/http/node-http-backpressure-max.test.ts
@@ -0,0 +1,50 @@
+/**
+ * All new tests in this file should also run in Node.js.
+ *
+ * Do not add any tests that only run in Bun.
+ *
+ * A handful of older tests do not run in Node in this file. These tests should be updated to run in Node, or deleted.
+ */
+import { once } from "node:events";
+import http from "node:http";
+import type { AddressInfo } from "node:net";
+import { isCI, isLinux } from "harness";
+
+describe("backpressure", () => {
+  // Linux CI only has 8GB, which is not enough because we will clone all or most of this 4GB into memory
+  it.skipIf(isCI && isLinux)(
+    "should handle backpressure with the maximum allowed bytes",
+    async () => {
+      // max allowed by node:http to be sent in one go, more will throw an error
+      const payloadSize = 4 * 1024 * 1024 * 1024;
+      await using server = http.createServer((req, res) => {
+        res.writeHead(200, {
+          "Content-Type": "application/octet-stream",
+          "Transfer-Encoding": "chunked",
+        });
+        const payload = Buffer.allocUnsafe(payloadSize);
+        res.write(payload, () => {
+          res.end();
+        });
+      });
+
+      await once(server.listen(0), "listening");
+
+      const PORT = (server.address() as AddressInfo).port;
+      const response = await fetch(`http://localhost:${PORT}/`);
+      const reader = (response.body as ReadableStream).getReader();
+      let totalBytes = 0;
+      while (true) {
+        const { done, value } = await reader.read();
+
+        if (value) {
+          totalBytes += value.byteLength;
+        }
+        if (done) break;
+      }
+
+      expect(totalBytes).toBe(payloadSize);
+    },
+    60_000,
+  );
+});
diff --git a/test/js/node/http/node-http-backpressure.test.ts b/test/js/node/http/node-http-backpressure.test.ts
new file mode 100644
index 0000000000..1c44e4a3a3
--- /dev/null
+++ b/test/js/node/http/node-http-backpressure.test.ts
@@ -0,0 +1,98 @@
+/**
+ * All new tests in this file should also run in Node.js.
+ *
+ * Do not add any tests that only run in Bun.
+ *
+ * A handful of older tests do not run in Node in this file. These tests should be updated to run in Node, or deleted.
+ */
+import { once } from "node:events";
+import http from "node:http";
+import type { AddressInfo } from "node:net";
+
+describe("backpressure", () => {
+  // INT_MAX is the maximum we can send to the socket in one call
+  const TwoGBPayload = Buffer.allocUnsafe(1024 * 1024 * 1024 * 2);
+  it("should handle backpressure", async () => {
+    await using server = http.createServer((req, res) => {
+      res.writeHead(200, {
+        "Content-Type": "application/octet-stream",
+        "Transfer-Encoding": "chunked",
+      });
+      // send 3 chunks of 1MB each, which is more than the socket buffer and will trigger a backpressure event
+      const payload = Buffer.alloc(1024 * 1024, "a");
+      res.write(payload, () => {
+        res.write(payload, () => {
+          res.write(payload, () => {
+            res.end();
+          });
+        });
+      });
+    });
+    await once(server.listen(0), "listening");
+
+    const PORT = (server.address() as AddressInfo).port;
+    const bytes = await fetch(`http://localhost:${PORT}/`).then(res => res.arrayBuffer());
+    expect(bytes.byteLength).toBe(1024 * 1024 * 3);
+  });
+  it("should handle backpressure with INT_MAX bytes", async () => {
+    await using server = http.createServer((req, res) => {
+      res.writeHead(200, {
+        "Content-Type": "application/octet-stream",
+        "Transfer-Encoding": "chunked",
+      });
+
+      res.write(TwoGBPayload, () => {
+        res.end();
+      });
+    });
+
+    await once(server.listen(0), "listening");
+
+    const PORT = (server.address() as AddressInfo).port;
+    const response = await fetch(`http://localhost:${PORT}/`);
+    const reader = (response.body as ReadableStream).getReader();
+    let totalBytes = 0;
+    while (true) {
+      const { done, value } = await reader.read();
+
+      if (value) {
+        totalBytes += value.byteLength;
+      }
+      if (done) break;
+    }
+
+    expect(totalBytes).toBe(TwoGBPayload.byteLength);
+  }, 30_000);
+
+  it("should handle backpressure with more than INT_MAX bytes", async () => {
+    // enough to fill the socket buffer
+    const smallPayloadSize = 1024 * 1024;
+    await using server = http.createServer((req, res) => {
+      res.writeHead(200, {
+        "Content-Type": "application/octet-stream",
+        "Transfer-Encoding": "chunked",
+      });
+      res.write(Buffer.alloc(smallPayloadSize, "a"));
+      res.write(TwoGBPayload, () => {
+        res.end();
+      });
+    });
+
+    await once(server.listen(0), "listening");
+
+    const PORT = (server.address() as AddressInfo).port;
+    const response = await fetch(`http://localhost:${PORT}/`);
+    const reader = (response.body as ReadableStream).getReader();
+    let totalBytes = 0;
+    while (true) {
+      const { done, value } = await reader.read();
+
+      if (value) {
+        totalBytes += value.byteLength;
+      }
+      if (done) break;
+    }
+
+    expect(totalBytes).toBe(TwoGBPayload.byteLength + smallPayloadSize);
+  }, 30_000);
+});
diff --git a/test/js/node/http/node-http.test.ts b/test/js/node/http/node-http.test.ts
index b53c5f95be..a722abe78e 100644
--- a/test/js/node/http/node-http.test.ts
+++ b/test/js/node/http/node-http.test.ts
@@ -23,6 +23,7 @@ import http, {
   validateHeaderName,
   validateHeaderValue,
 } from "node:http";
+import type { AddressInfo } from "node:net";
 import https, { createServer as createHttpsServer } from "node:https";
 import { tmpdir } from "node:os";
 import * as path from "node:path";
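
Note on the user-facing pattern these changes are meant to serve (illustrative only, not part of the diff): with the write-completion callback now wired through allowWritesToContinue(), which runs pending callbacks and emits 'drain', the standard node:http backpressure loop behaves as the tests above expect. The sketch below assumes nothing beyond the public node:http API; the chunk size and chunk count are arbitrary.

import http from "node:http";

// Minimal sketch: stream many chunks while respecting backpressure.
// res.write() returns false once data is being buffered; continue on 'drain'.
const server = http.createServer((req, res) => {
  res.writeHead(200, {
    "Content-Type": "application/octet-stream",
    "Transfer-Encoding": "chunked",
  });

  const chunk = Buffer.alloc(1024 * 1024, "a"); // 1 MiB per write (arbitrary)
  let remaining = 64; // number of chunks to send (arbitrary)

  const writeMore = () => {
    while (remaining > 0) {
      remaining--;
      // A false return value signals backpressure; wait for 'drain' before writing more.
      if (!res.write(chunk)) {
        res.once("drain", writeMore);
        return;
      }
    }
    res.end();
  };

  writeMore();
});

server.listen(0);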
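
For readers following the HttpResponse::write changes: the new code splits oversized payloads into bounded chunks because the low-level write path takes an int length, and it keeps writing after a failed call since the remainder is buffered as backpressure. The TypeScript sketch below is only a rough, hypothetical rendering of that bookkeeping; writeRaw is a made-up stand-in for the low-level writer, not a real API.

// Hypothetical low-level writer: accepts one bounded chunk per call and
// reports how many bytes it accepted and whether the call failed.
type RawWrite = (chunk: Uint8Array) => { written: number; failed: boolean };

const INT32_MAX = 2 ** 31 - 1;

// Mirror of the diff's bookkeeping: write bounded chunks, keep going even
// after a failed call (the data is still buffered), accumulate the total
// written, and report success only if no call failed.
function writeBounded(data: Uint8Array, writeRaw: RawWrite): { written: number; ok: boolean } {
  let totalWritten = 0;
  let hasFailed = false;

  for (let offset = 0; offset < data.length; offset += INT32_MAX) {
    const chunk = data.subarray(offset, Math.min(offset + INT32_MAX, data.length));
    const { written, failed } = writeRaw(chunk);
    totalWritten += written;
    hasFailed = hasFailed || failed;
  }

  return { written: totalWritten, ok: !hasFailed };
}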