From 85271f9dd917f0cc0c902660907742ae02b7f2c1 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Tue, 23 Sep 2025 16:46:59 -0700 Subject: [PATCH] fix(node:http) allow CONNECT in node http/https servers (#22756) ### What does this PR do? Fixes https://github.com/oven-sh/bun/issues/22755 Fixes https://github.com/oven-sh/bun/issues/19790 Fixes https://github.com/oven-sh/bun/issues/16372 ### How did you verify your code works? --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- packages/bun-uws/src/App.h | 8 +- packages/bun-uws/src/HttpContext.h | 37 +- packages/bun-uws/src/HttpContextData.h | 5 + packages/bun-uws/src/HttpParser.h | 66 ++- packages/bun-uws/src/HttpResponse.h | 9 +- packages/bun-uws/src/HttpResponseData.h | 1 + src/bun.js/api/server.zig | 19 +- src/bun.js/api/server/NodeHTTPResponse.zig | 142 +++--- src/bun.js/api/server/ServerWebSocket.zig | 189 ++++---- src/bun.js/bindings/NodeHTTP.cpp | 298 +++++++++++- src/deps/libuwsockets.cpp | 24 +- src/deps/uws/Response.zig | 21 + src/deps/uws/us_socket_t.zig | 116 +++++ src/js/node/_http_server.ts | 106 ++++- test/bun.lock | 27 +- test/internal/ban-limits.json | 2 +- test/js/node/http/node-http-connect.node.mts | 413 +++++++++++++++++ test/js/node/http/node-http-connect.test.ts | 464 +++++++++++++++++++ test/js/node/http/node-http.test.ts | 79 +++- test/package.json | 1 + 20 files changed, 1784 insertions(+), 243 deletions(-) create mode 100644 test/js/node/http/node-http-connect.node.mts create mode 100644 test/js/node/http/node-http-connect.test.ts diff --git a/packages/bun-uws/src/App.h b/packages/bun-uws/src/App.h index d98389e787..6840c23fed 100644 --- a/packages/bun-uws/src/App.h +++ b/packages/bun-uws/src/App.h @@ -627,9 +627,15 @@ public: return std::move(*this); } - void setOnClose(HttpContextData::OnSocketClosedCallback onClose) { + void setOnSocketClosed(HttpContextData::OnSocketClosedCallback onClose) { httpContext->getSocketContextData()->onSocketClosed = onClose; } + void setOnSocketDrain(HttpContextData::OnSocketDrainCallback onDrain) { + httpContext->getSocketContextData()->onSocketDrain = onDrain; + } + void setOnSocketData(HttpContextData::OnSocketDataCallback onData) { + httpContext->getSocketContextData()->onSocketData = onData; + } void setOnClientError(HttpContextData::OnClientErrorCallback onClientError) { httpContext->getSocketContextData()->onClientError = std::move(onClientError); diff --git a/packages/bun-uws/src/HttpContext.h b/packages/bun-uws/src/HttpContext.h index c0866ffdde..6fc803295d 100644 --- a/packages/bun-uws/src/HttpContext.h +++ b/packages/bun-uws/src/HttpContext.h @@ -193,23 +193,32 @@ private: auto *httpResponseData = reinterpret_cast *>(us_socket_ext(SSL, s)); - /* Call filter */ HttpContextData *httpContextData = getSocketContextDataS(s); + + if(httpResponseData && httpResponseData->isConnectRequest) { + if (httpResponseData->socketData && httpContextData->onSocketData) { + httpContextData->onSocketData(httpResponseData->socketData, SSL, s, "", 0, true); + } + if(httpResponseData->inStream) { + httpResponseData->inStream(reinterpret_cast *>(s), "", 0, true, httpResponseData->userData); + httpResponseData->inStream = nullptr; + } + } for (auto &f : httpContextData->filterHandlers) { f((HttpResponse *) s, -1); } + if (httpResponseData->socketData && httpContextData->onSocketClosed) { + httpContextData->onSocketClosed(httpResponseData->socketData, SSL, s); + } /* Signal broken HTTP request only if we have a pending request */ if (httpResponseData->onAborted != nullptr && httpResponseData->userData != nullptr) { httpResponseData->onAborted((HttpResponse *)s, httpResponseData->userData); } - if (httpResponseData->socketData && httpContextData->onSocketClosed) { - httpContextData->onSocketClosed(httpResponseData->socketData, SSL, s); - } /* Destruct socket ext */ httpResponseData->~HttpResponseData(); @@ -254,7 +263,9 @@ private: /* The return value is entirely up to us to interpret. The HttpParser cares only for whether the returned value is DIFFERENT from passed user */ - auto result = httpResponseData->consumePostPadded(httpContextData->maxHeaderSize, httpContextData->flags.requireHostHeader,httpContextData->flags.useStrictMethodValidation, data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * { + auto result = httpResponseData->consumePostPadded(httpContextData->maxHeaderSize, httpResponseData->isConnectRequest, httpContextData->flags.requireHostHeader,httpContextData->flags.useStrictMethodValidation, data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * { + + /* For every request we reset the timeout and hang until user makes action */ /* Warning: if we are in shutdown state, resetting the timer is a security issue! */ us_socket_timeout(SSL, (us_socket_t *) s, 0); @@ -330,7 +341,12 @@ private: /* Continue parsing */ return s; - }, [httpResponseData](void *user, std::string_view data, bool fin) -> void * { + }, [httpResponseData, httpContextData](void *user, std::string_view data, bool fin) -> void * { + + + if (httpResponseData->isConnectRequest && httpResponseData->socketData && httpContextData->onSocketData) { + httpContextData->onSocketData(httpResponseData->socketData, SSL, (struct us_socket_t *) user, data.data(), data.length(), fin); + } /* We always get an empty chunk even if there is no data */ if (httpResponseData->inStream) { @@ -449,7 +465,7 @@ private: us_socket_context_on_writable(SSL, getSocketContext(), [](us_socket_t *s) { auto *asyncSocket = reinterpret_cast *>(s); auto *httpResponseData = reinterpret_cast *>(asyncSocket->getAsyncSocketData()); - + /* Attempt to drain the socket buffer before triggering onWritable callback */ size_t bufferedAmount = asyncSocket->getBufferedAmount(); if (bufferedAmount > 0) { @@ -470,6 +486,12 @@ private: */ } + auto *httpContextData = getSocketContextDataS(s); + + + if (httpResponseData->isConnectRequest && httpResponseData->socketData && httpContextData->onSocketDrain) { + httpContextData->onSocketDrain(httpResponseData->socketData, SSL, (struct us_socket_t *) s); + } /* Ask the developer to write data and return success (true) or failure (false), OR skip sending anything and return success (true). */ if (httpResponseData->onWritable) { /* We are now writable, so hang timeout again, the user does not have to do anything so we should hang until end or tryEnd rearms timeout */ @@ -514,6 +536,7 @@ private: us_socket_context_on_end(SSL, getSocketContext(), [](us_socket_t *s) { auto *asyncSocket = reinterpret_cast *>(s); asyncSocket->uncorkWithoutSending(); + /* We do not care for half closed sockets */ return asyncSocket->close(); }); diff --git a/packages/bun-uws/src/HttpContextData.h b/packages/bun-uws/src/HttpContextData.h index 48ec202dd1..a595927d56 100644 --- a/packages/bun-uws/src/HttpContextData.h +++ b/packages/bun-uws/src/HttpContextData.h @@ -44,7 +44,10 @@ struct alignas(16) HttpContextData { private: std::vector *, int)>> filterHandlers; using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); + using OnSocketDataCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket, const char *data, int length, bool last); + using OnSocketDrainCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); using OnClientErrorCallback = MoveOnlyFunction; + MoveOnlyFunction missingServerNameHandler; @@ -61,6 +64,8 @@ private: void *upgradedWebSocket = nullptr; /* Used to simulate Node.js socket events. */ OnSocketClosedCallback onSocketClosed = nullptr; + OnSocketDrainCallback onSocketDrain = nullptr; + OnSocketDataCallback onSocketData = nullptr; OnClientErrorCallback onClientError = nullptr; uint64_t maxHeaderSize = 0; // 0 means no limit diff --git a/packages/bun-uws/src/HttpParser.h b/packages/bun-uws/src/HttpParser.h index f3755a1455..8b6f26fcec 100644 --- a/packages/bun-uws/src/HttpParser.h +++ b/packages/bun-uws/src/HttpParser.h @@ -117,18 +117,19 @@ namespace uWS struct ConsumeRequestLineResult { char *position; bool isAncientHTTP; + bool isConnect; HTTPHeaderParserError headerParserError; public: static ConsumeRequestLineResult error(HTTPHeaderParserError error) { - return ConsumeRequestLineResult{nullptr, false, error}; + return ConsumeRequestLineResult{nullptr, false, false, error}; } - static ConsumeRequestLineResult success(char *position, bool isAncientHTTP = false) { - return ConsumeRequestLineResult{position, isAncientHTTP, HTTP_HEADER_PARSER_ERROR_NONE}; + static ConsumeRequestLineResult success(char *position, bool isAncientHTTP = false, bool isConnect = false) { + return ConsumeRequestLineResult{position, isAncientHTTP, isConnect, HTTP_HEADER_PARSER_ERROR_NONE}; } - static ConsumeRequestLineResult shortRead(bool isAncientHTTP = false) { - return ConsumeRequestLineResult{nullptr, isAncientHTTP, HTTP_HEADER_PARSER_ERROR_NONE}; + static ConsumeRequestLineResult shortRead(bool isAncientHTTP = false, bool isConnect = false) { + return ConsumeRequestLineResult{nullptr, isAncientHTTP, isConnect, HTTP_HEADER_PARSER_ERROR_NONE}; } bool isErrorOrShortRead() { @@ -551,7 +552,10 @@ namespace uWS return ConsumeRequestLineResult::shortRead(); } - if (data[0] == 32 && (__builtin_expect(data[1] == '/', 1) || isHTTPorHTTPSPrefixForProxies(data + 1, end) == 1)) [[likely]] { + + bool isHTTPMethod = (__builtin_expect(data[1] == '/', 1)); + bool isConnect = !isHTTPMethod && (isHTTPorHTTPSPrefixForProxies(data + 1, end) == 1 || ((data - start) == 7 && memcmp(start, "CONNECT", 7) == 0)); + if (isHTTPMethod || isConnect) [[likely]] { header.key = {start, (size_t) (data - start)}; data++; if(!isValidMethod(header.key, useStrictMethodValidation)) { @@ -577,22 +581,22 @@ namespace uWS if (nextPosition >= end) { /* Whatever we have must be part of the version string */ if (memcmp(" HTTP/1.1\r\n", data, std::min(11, (unsigned int) (end - data))) == 0) { - return ConsumeRequestLineResult::shortRead(); + return ConsumeRequestLineResult::shortRead(false, isConnect); } else if (memcmp(" HTTP/1.0\r\n", data, std::min(11, (unsigned int) (end - data))) == 0) { /*Indicates that the request line is ancient HTTP*/ - return ConsumeRequestLineResult::shortRead(true); + return ConsumeRequestLineResult::shortRead(true, isConnect); } return ConsumeRequestLineResult::error(HTTP_HEADER_PARSER_ERROR_INVALID_HTTP_VERSION); } if (memcmp(" HTTP/1.1\r\n", data, 11) == 0) { - return ConsumeRequestLineResult::success(nextPosition); + return ConsumeRequestLineResult::success(nextPosition, false, isConnect); } else if (memcmp(" HTTP/1.0\r\n", data, 11) == 0) { /*Indicates that the request line is ancient HTTP*/ - return ConsumeRequestLineResult::success(nextPosition, true); + return ConsumeRequestLineResult::success(nextPosition, true, isConnect); } /* If we stand at the post padded CR, we have fragmented input so try again later */ if (data[0] == '\r') { - return ConsumeRequestLineResult::shortRead(); + return ConsumeRequestLineResult::shortRead(false, isConnect); } /* This is an error */ return ConsumeRequestLineResult::error(HTTP_HEADER_PARSER_ERROR_INVALID_HTTP_VERSION); @@ -602,14 +606,14 @@ namespace uWS /* If we stand at the post padded CR, we have fragmented input so try again later */ if (data[0] == '\r') { - return ConsumeRequestLineResult::shortRead(); + return ConsumeRequestLineResult::shortRead(false, isConnect); } if (data[0] == 32) { switch (isHTTPorHTTPSPrefixForProxies(data + 1, end)) { // If we haven't received enough data to check if it's http:// or https://, let's try again later case -1: - return ConsumeRequestLineResult::shortRead(); + return ConsumeRequestLineResult::shortRead(false, isConnect); // Otherwise, if it's not http:// or https://, return 400 default: return ConsumeRequestLineResult::error(HTTP_HEADER_PARSER_ERROR_INVALID_REQUEST); @@ -635,7 +639,7 @@ namespace uWS } /* End is only used for the proxy parser. The HTTP parser recognizes "\ra" as invalid "\r\n" scan and breaks. */ - static HttpParserResult getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved, bool &isAncientHTTP, bool useStrictMethodValidation, uint64_t maxHeaderSize) { + static HttpParserResult getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved, bool &isAncientHTTP, bool &isConnectRequest, bool useStrictMethodValidation, uint64_t maxHeaderSize) { char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer; #ifdef UWS_WITH_PROXY /* ProxyParser is passed as reserved parameter */ @@ -689,6 +693,9 @@ namespace uWS if(requestLineResult.isAncientHTTP) { isAncientHTTP = true; } + if(requestLineResult.isConnect) { + isConnectRequest = true; + } /* No request headers found */ const char * headerStart = (headers[0].key.length() > 0) ? headers[0].key.data() : end; @@ -798,7 +805,7 @@ namespace uWS /* This is the only caller of getHeaders and is thus the deepest part of the parser. */ template - HttpParserResult fenceAndConsumePostPadded(uint64_t maxHeaderSize, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction &requestHandler, MoveOnlyFunction &dataHandler) { + HttpParserResult fenceAndConsumePostPadded(uint64_t maxHeaderSize, bool& isConnectRequest, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction &requestHandler, MoveOnlyFunction &dataHandler) { /* How much data we CONSUMED (to throw away) */ unsigned int consumedTotal = 0; @@ -809,7 +816,7 @@ namespace uWS data[length + 1] = 'a'; /* Anything that is not \n, to trigger "invalid request" */ req->ancientHttp = false; for (;length;) { - auto result = getHeaders(data, data + length, req->headers, reserved, req->ancientHttp, useStrictMethodValidation, maxHeaderSize); + auto result = getHeaders(data, data + length, req->headers, reserved, req->ancientHttp, isConnectRequest, useStrictMethodValidation, maxHeaderSize); if(result.isError()) { return result; } @@ -916,6 +923,10 @@ namespace uWS length -= emittable; consumedTotal += emittable; } + } else if(isConnectRequest) { + // This only server to mark that the connect request read all headers + // and can starting emitting data + remainingStreamingBytes = STATE_IS_CHUNKED; } else { /* If we came here without a body; emit an empty data chunk to signal no data */ dataHandler(user, {}, true); @@ -931,15 +942,16 @@ namespace uWS } public: - HttpParserResult consumePostPadded(uint64_t maxHeaderSize, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction &&requestHandler, MoveOnlyFunction &&dataHandler) { - + HttpParserResult consumePostPadded(uint64_t maxHeaderSize, bool& isConnectRequest, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction &&requestHandler, MoveOnlyFunction &&dataHandler) { /* This resets BloomFilter by construction, but later we also reset it again. * Optimize this to skip resetting twice (req could be made global) */ HttpRequest req; if (remainingStreamingBytes) { - - /* It's either chunked or with a content-length */ - if (isParsingChunkedEncoding(remainingStreamingBytes)) { + if (isConnectRequest) { + dataHandler(user, std::string_view(data, length), false); + return HttpParserResult::success(0, user); + } else if (isParsingChunkedEncoding(remainingStreamingBytes)) { + /* It's either chunked or with a content-length */ std::string_view dataToConsume(data, length); for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) { dataHandler(user, chunk, chunk.length() == 0); @@ -950,6 +962,7 @@ public: data = (char *) dataToConsume.data(); length = (unsigned int) dataToConsume.length(); } else { + // this is exactly the same as below! // todo: refactor this if (remainingStreamingBytes >= length) { @@ -980,7 +993,7 @@ public: fallback.append(data, maxCopyDistance); // break here on break - HttpParserResult consumed = fenceAndConsumePostPadded(maxHeaderSize, requireHostHeader, useStrictMethodValidation, fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler); + HttpParserResult consumed = fenceAndConsumePostPadded(maxHeaderSize, isConnectRequest, requireHostHeader, useStrictMethodValidation, fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler); /* Return data will be different than user if we are upgraded to WebSocket or have an error */ if (consumed.returnedData != user) { return consumed; @@ -997,8 +1010,11 @@ public: length -= consumedBytes - had; if (remainingStreamingBytes) { - /* It's either chunked or with a content-length */ - if (isParsingChunkedEncoding(remainingStreamingBytes)) { + if(isConnectRequest) { + dataHandler(user, std::string_view(data, length), false); + return HttpParserResult::success(0, user); + } else if (isParsingChunkedEncoding(remainingStreamingBytes)) { + /* It's either chunked or with a content-length */ std::string_view dataToConsume(data, length); for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) { dataHandler(user, chunk, chunk.length() == 0); @@ -1037,7 +1053,7 @@ public: } } - HttpParserResult consumed = fenceAndConsumePostPadded(maxHeaderSize, requireHostHeader, useStrictMethodValidation, data, length, user, reserved, &req, requestHandler, dataHandler); + HttpParserResult consumed = fenceAndConsumePostPadded(maxHeaderSize, isConnectRequest, requireHostHeader, useStrictMethodValidation, data, length, user, reserved, &req, requestHandler, dataHandler); /* Return data will be different than user if we are upgraded to WebSocket or have an error */ if (consumed.returnedData != user) { return consumed; diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h index 03c82ca77d..209e0e79df 100644 --- a/packages/bun-uws/src/HttpResponse.h +++ b/packages/bun-uws/src/HttpResponse.h @@ -243,7 +243,7 @@ public: /* Manually upgrade to WebSocket. Typically called in upgrade handler. Immediately calls open handler. * NOTE: Will invalidate 'this' as socket might change location in memory. Throw away after use. */ template - us_socket_t *upgrade(UserData &&userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol, + us_socket_t *upgrade(UserData&& userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol, std::string_view secWebSocketExtensions, struct us_socket_context_t *webSocketContext) { @@ -350,7 +350,8 @@ public: us_socket_timeout(SSL, (us_socket_t *) webSocket, webSocketContextData->idleTimeoutComponents.first); /* Move construct the UserData right before calling open handler */ - new (webSocket->getUserData()) UserData(std::move(userData)); + new (webSocket->getUserData()) UserData(std::forward(userData)); + /* Emit open event and start the timeout */ if (webSocketContextData->openHandler) { @@ -741,6 +742,10 @@ public: return httpResponseData->socketData; } + bool isConnectRequest() { + HttpResponseData *httpResponseData = getHttpResponseData(); + return httpResponseData->isConnectRequest; + } void setWriteOffset(uint64_t offset) { HttpResponseData *httpResponseData = getHttpResponseData(); diff --git a/packages/bun-uws/src/HttpResponseData.h b/packages/bun-uws/src/HttpResponseData.h index 26c3428049..8fb572d900 100644 --- a/packages/bun-uws/src/HttpResponseData.h +++ b/packages/bun-uws/src/HttpResponseData.h @@ -108,6 +108,7 @@ struct HttpResponseData : AsyncSocketData, HttpParser { uint8_t state = 0; uint8_t idleTimeout = 10; // default HTTP_TIMEOUT 10 seconds bool fromAncientRequest = false; + bool isConnectRequest = false; bool isIdle = true; bool shouldCloseOnceIdle = false; diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index ec60a78a7b..0ea9ff488b 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -961,18 +961,13 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d // obviously invalid pointer marks it as used upgrader.upgrade_context = @as(*uws.SocketContext, @ptrFromInt(std.math.maxInt(usize))); const signal = upgrader.signal; - upgrader.signal = null; upgrader.resp = null; request.request_context = AnyRequestContext.Null; upgrader.request_weakref.deref(); data_value.ensureStillAlive(); - const ws = ServerWebSocket.new(.{ - .handler = &this.config.websocket.?.handler, - .this_value = data_value, - .signal = signal, - }); + const ws = ServerWebSocket.init(&this.config.websocket.?.handler, data_value, signal); data_value.ensureStillAlive(); var sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator); @@ -2643,7 +2638,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d // If onNodeHTTPRequest is configured, it might be needed for Node.js compatibility layer // for specific Node API routes, even if it's not the main "/*" handler. if (this.config.onNodeHTTPRequest != .zero) { - NodeHTTP_assignOnCloseFunction(ssl_enabled, app); + NodeHTTP_assignOnNodeJSCompat(ssl_enabled, app); } return route_list_value; @@ -2815,7 +2810,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d pub fn onClientErrorCallback(this: *ThisServer, socket: *uws.Socket, error_code: u8, raw_packet: []const u8) void { if (this.on_clienterror.get()) |callback| { const is_ssl = protocol_enum == .https; - const node_socket = bun.jsc.fromJSHostCall(this.globalThis, @src(), Bun__createNodeHTTPServerSocket, .{ is_ssl, socket, this.globalThis }) catch return; + const node_socket = bun.jsc.fromJSHostCall(this.globalThis, @src(), Bun__createNodeHTTPServerSocketForClientError, .{ is_ssl, socket, this.globalThis }) catch return; if (node_socket.isUndefinedOrNull()) return; const error_code_value = JSValue.jsNumber(error_code); @@ -3313,9 +3308,8 @@ extern fn NodeHTTPServer__onRequest_https( node_response_ptr: *?*NodeHTTPResponse, ) jsc.JSValue; -extern fn Bun__createNodeHTTPServerSocket(bool, *anyopaque, *jsc.JSGlobalObject) jsc.JSValue; -extern fn NodeHTTP_assignOnCloseFunction(bool, *anyopaque) void; -extern fn NodeHTTP_setUsingCustomExpectHandler(bool, *anyopaque, bool) void; +extern fn Bun__createNodeHTTPServerSocketForClientError(bool, *anyopaque, *jsc.JSGlobalObject) jsc.JSValue; + extern "c" fn Bun__ServerRouteList__callRoute( globalObject: *jsc.JSGlobalObject, index: u32, @@ -3344,6 +3338,9 @@ fn throwSSLErrorIfNecessary(globalThis: *jsc.JSGlobalObject) bool { return false; } +extern fn NodeHTTP_assignOnNodeJSCompat(bool, *anyopaque) void; +extern fn NodeHTTP_setUsingCustomExpectHandler(bool, *anyopaque, bool) void; + const string = []const u8; const Sys = @import("../../sys.zig"); diff --git a/src/bun.js/api/server/NodeHTTPResponse.zig b/src/bun.js/api/server/NodeHTTPResponse.zig index bc322f087e..94f9dc44e9 100644 --- a/src/bun.js/api/server/NodeHTTPResponse.zig +++ b/src/bun.js/api/server/NodeHTTPResponse.zig @@ -17,7 +17,7 @@ raw_response: uws.AnyResponse, flags: Flags = .{}, -js_ref: jsc.Ref = .{}, +poll_ref: jsc.Ref = .{}, body_read_state: BodyReadState = .none, body_read_ref: jsc.Ref = .{}, @@ -122,17 +122,19 @@ pub fn getServerSocketValue(this: *NodeHTTPResponse) jsc.JSValue { pub fn pauseSocket(this: *NodeHTTPResponse) void { log("pauseSocket", .{}); - if (this.flags.socket_closed or this.flags.upgraded) { + if (this.flags.socket_closed or this.flags.upgraded or this.raw_response.isConnectRequest()) { return; } + this.raw_response.pause(); } pub fn resumeSocket(this: *NodeHTTPResponse) void { log("resumeSocket", .{}); - if (this.flags.socket_closed or this.flags.upgraded) { + if (this.flags.socket_closed or this.flags.upgraded or this.raw_response.isConnectRequest()) { return; } + this.raw_response.@"resume"(); } @@ -145,7 +147,7 @@ const OnBeforeOpen = struct { Bun__setNodeHTTPServerSocketUsSocketValue(ctx.socketValue, socket.asSocket()); ServerWebSocket.js.gc.socket.set(js_websocket, ctx.globalObject, ctx.socketValue); ctx.this.flags.upgraded = true; - defer ctx.this.js_ref.unref(ctx.globalObject.bunVM()); + defer ctx.this.poll_ref.unref(ctx.globalObject.bunVM()); switch (ctx.this.raw_response) { .SSL => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))), .TCP => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))), @@ -168,10 +170,7 @@ pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_proto } data_value.ensureStillAlive(); - const ws = ServerWebSocket.new(.{ - .handler = ws_handler, - .this_value = data_value, - }); + const ws = ServerWebSocket.init(ws_handler, data_value, null); var sec_websocket_protocol_str: ?ZigString.Slice = null; defer if (sec_websocket_protocol_str) |*str| str.deinit(); @@ -231,6 +230,7 @@ pub fn maybeStopReadingBody(this: *NodeHTTPResponse, vm: *jsc.VirtualMachine, th { const had_ref = this.body_read_ref.has; if (!this.flags.upgraded and !this.flags.socket_closed) { + log("clearOnData", .{}); this.raw_response.clearOnData(); } @@ -275,7 +275,7 @@ fn markRequestAsDone(this: *NodeHTTPResponse) void { this.buffered_request_body_data_during_pause.clearAndFree(bun.default_allocator); const server = this.server; - this.js_ref.unref(jsc.VirtualMachine.get()); + this.poll_ref.unref(jsc.VirtualMachine.get()); this.deref(); server.onRequestComplete(); } @@ -331,7 +331,7 @@ pub fn create( if (has_body.*) { response.body_read_ref.ref(vm); } - response.js_ref.ref(vm); + response.poll_ref.ref(vm); const js_this = response.toJS(globalObject); node_response_ptr.* = response; return js_this; @@ -400,14 +400,14 @@ pub fn getBufferedAmount(this: *const NodeHTTPResponse, _: *jsc.JSGlobalObject) pub fn jsRef(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue { if (!this.isDone()) { - this.js_ref.ref(globalObject.bunVM()); + this.poll_ref.ref(globalObject.bunVM()); } return .js_undefined; } pub fn jsUnref(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue { if (!this.isDone()) { - this.js_ref.unref(globalObject.bunVM()); + this.poll_ref.unref(globalObject.bunVM()); } return .js_undefined; } @@ -570,18 +570,16 @@ pub fn onTimeout(this: *NodeHTTPResponse, _: uws.AnyResponse) void { this.handleAbortOrTimeout(.timeout, .zero); } -pub fn doPause(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame, thisValue: jsc.JSValue) bun.JSError!jsc.JSValue { +pub fn doPause(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame, _: jsc.JSValue) bun.JSError!jsc.JSValue { log("doPause", .{}); if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) { return .false; } - if (this.body_read_ref.has and js.onDataGetCached(thisValue) == null) { - this.flags.is_data_buffered_during_pause = true; - this.raw_response.onData(*NodeHTTPResponse, onBufferRequestBodyWhilePaused, this); - } + this.flags.is_data_buffered_during_pause = true; + this.raw_response.onData(*NodeHTTPResponse, onBufferRequestBodyWhilePaused, this); + // TODO: figure out why windows is not emitting EOF with UV_DISCONNECT if (!Environment.isWindows) { - // TODO: figure out why windows is not emitting EOF with UV_DISCONNECT pauseSocket(this); } return .true; @@ -592,6 +590,7 @@ pub fn drainRequestBody(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObje } fn drainBufferedRequestBodyFromPause(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject) ?jsc.JSValue { + log("drainBufferedRequestBodyFromPause {d}", .{this.buffered_request_body_data_during_pause.len}); if (this.buffered_request_body_data_during_pause.len > 0) { const result = jsc.JSValue.createBuffer(globalObject, this.buffered_request_body_data_during_pause.slice()); this.buffered_request_body_data_during_pause = .{}; @@ -605,12 +604,10 @@ pub fn doResume(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, _: * if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) { return .false; } - + this.setOnAbortedHandler(); + this.raw_response.onData(*NodeHTTPResponse, onData, this); + this.flags.is_data_buffered_during_pause = false; var result: jsc.JSValue = .true; - if (this.flags.is_data_buffered_during_pause) { - this.raw_response.clearOnData(); - this.flags.is_data_buffered_during_pause = false; - } if (this.drainBufferedRequestBodyFromPause(globalObject)) |buffered_data| { result = buffered_data; @@ -626,7 +623,7 @@ pub fn onRequestComplete(this: *NodeHTTPResponse) void { } log("onRequestComplete", .{}); this.flags.request_has_completed = true; - this.js_ref.unref(jsc.VirtualMachine.get()); + this.poll_ref.unref(jsc.VirtualMachine.get()); this.markRequestAsDoneIfNecessary(); } @@ -644,6 +641,7 @@ pub export fn Bun__NodeHTTPRequest__onResolve(globalObject: *jsc.JSGlobalObject, if (this_value != .zero) { js.onAbortedSetCached(this_value, globalObject, .zero); } + log("clearOnData", .{}); this.raw_response.clearOnData(); this.raw_response.clearOnWritable(); this.raw_response.clearTimeout(); @@ -670,6 +668,7 @@ pub export fn Bun__NodeHTTPRequest__onReject(globalObject: *jsc.JSGlobalObject, if (this_value != .zero) { js.onAbortedSetCached(this_value, globalObject, .zero); } + log("clearOnData", .{}); this.raw_response.clearOnData(); this.raw_response.clearOnWritable(); this.raw_response.clearTimeout(); @@ -695,6 +694,7 @@ pub fn abort(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame) return .js_undefined; } resumeSocket(this); + log("clearOnData", .{}); this.raw_response.clearOnData(); this.raw_response.clearOnWritable(); this.raw_response.clearTimeout(); @@ -718,7 +718,43 @@ fn onBufferRequestBodyWhilePaused(this: *NodeHTTPResponse, chunk: []const u8, la } } +fn getBytes(this: *NodeHTTPResponse, globalThis: *jsc.JSGlobalObject, chunk: []const u8) jsc.JSValue { + // TODO: we should have a error event for this but is better than ignoring it + // right now the socket instead of emitting an error event it will reportUncaughtException + // this makes the behavior aligned with current implementation, but not ideal + const bytes: jsc.JSValue = brk: { + if (chunk.len > 0 and this.buffered_request_body_data_during_pause.len > 0) { + const buffer = jsc.JSValue.createBufferFromLength(globalThis, chunk.len + this.buffered_request_body_data_during_pause.len) catch |err| { + globalThis.reportUncaughtExceptionFromError(err); + return .js_undefined; + }; + + const array_buffer = buffer.asArrayBuffer(globalThis).?; + + defer this.buffered_request_body_data_during_pause.clearAndFree(bun.default_allocator); + var input = array_buffer.slice(); + @memcpy(input[0..this.buffered_request_body_data_during_pause.len], this.buffered_request_body_data_during_pause.slice()); + @memcpy(input[this.buffered_request_body_data_during_pause.len..], chunk); + break :brk buffer; + } + + if (this.drainBufferedRequestBodyFromPause(globalThis)) |buffered_data| { + break :brk buffered_data; + } + + if (chunk.len > 0) { + break :brk jsc.ArrayBuffer.createBuffer(globalThis, chunk) catch |err| { + globalThis.reportUncaughtExceptionFromError(err); + return .js_undefined; + }; + } + break :brk .js_undefined; + }; + return bytes; +} + fn onDataOrAborted(this: *NodeHTTPResponse, chunk: []const u8, last: bool, event: AbortEvent, thisValue: jsc.JSValue) void { + log("onDataOrAborted({d}, {})", .{ chunk.len, last }); if (last) { this.ref(); this.body_read_state = .done; @@ -743,27 +779,7 @@ fn onDataOrAborted(this: *NodeHTTPResponse, chunk: []const u8, last: bool, event const globalThis = jsc.VirtualMachine.get().global; const event_loop = globalThis.bunVM().eventLoop(); - const bytes: jsc.JSValue = brk: { - if (chunk.len > 0 and this.buffered_request_body_data_during_pause.len > 0) { - const buffer = jsc.JSValue.createBufferFromLength(globalThis, chunk.len + this.buffered_request_body_data_during_pause.len) catch return; // TODO: properly propagate exception upwards - this.buffered_request_body_data_during_pause.clearAndFree(bun.default_allocator); - if (buffer.asArrayBuffer(globalThis)) |array_buffer| { - var input = array_buffer.slice(); - @memcpy(input[0..this.buffered_request_body_data_during_pause.len], this.buffered_request_body_data_during_pause.slice()); - @memcpy(input[this.buffered_request_body_data_during_pause.len..], chunk); - break :brk buffer; - } - } - - if (this.drainBufferedRequestBodyFromPause(globalThis)) |buffered_data| { - break :brk buffered_data; - } - - if (chunk.len > 0) { - break :brk jsc.ArrayBuffer.createBuffer(globalThis, chunk) catch return; // TODO: properly propagate exception upwards - } - break :brk .js_undefined; - }; + const bytes = this.getBytes(globalThis, chunk); event_loop.runCallback(callback, globalThis, .js_undefined, &.{ bytes, @@ -779,23 +795,29 @@ pub fn onData(this: *NodeHTTPResponse, chunk: []const u8, last: bool) void { onDataOrAborted(this, chunk, last, .none, this.getThisValue()); } -fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool { - log("onDrain({d})", .{offset}); +fn onDrainCorked(this: *NodeHTTPResponse, offset: u64) void { + log("onDrainCorked({d})", .{offset}); this.ref(); defer this.deref(); - response.clearOnWritable(); + + const thisValue = this.getThisValue(); + const on_writable = js.onWritableGetCached(thisValue) orelse return; + const globalThis = jsc.VirtualMachine.get().global; + js.onWritableSetCached(thisValue, globalThis, .js_undefined); // TODO(@heimskr): is this necessary? + const vm = globalThis.bunVM(); + + vm.eventLoop().runCallback(on_writable, globalThis, .js_undefined, &.{jsc.JSValue.jsNumberFromUint64(offset)}); +} + +fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool { + log("onDrain({d})", .{offset}); + if (this.flags.socket_closed or this.flags.request_has_completed or this.flags.upgraded) { // return false means we don't have anything to drain return false; } - const thisValue = this.getThisValue(); - const on_writable = js.onWritableGetCached(thisValue) orelse return false; - const globalThis = jsc.VirtualMachine.get().global; - js.onWritableSetCached(thisValue, globalThis, .js_undefined); // TODO(@heimskr): is this necessary? - const vm = globalThis.bunVM(); - - response.corked(jsc.EventLoop.runCallback, .{ vm.eventLoop(), on_writable, globalThis, .js_undefined, &.{jsc.JSValue.jsNumberFromUint64(offset)} }); + response.corked(onDrainCorked, .{ this, offset }); // return true means we may have something to drain return true; } @@ -995,12 +1017,15 @@ pub fn setHasCustomOnData(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, value } fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject) void { + log("clearOnDataCallback", .{}); if (this.body_read_state != .none) { if (thisValue != .zero) { js.onDataSetCached(thisValue, globalObject, .js_undefined); } - if (!this.flags.socket_closed and !this.flags.upgraded) + if (!this.flags.socket_closed and !this.flags.upgraded) { + log("clearOnData", .{}); this.raw_response.clearOnData(); + } if (this.body_read_state != .done) { this.body_read_state = .done; } @@ -1018,6 +1043,7 @@ pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: switch (this.body_read_state) { .pending, .done => { if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) { + log("clearOnData", .{}); this.raw_response.clearOnData(); } this.body_read_state = .done; @@ -1133,12 +1159,12 @@ pub fn finalize(this: *NodeHTTPResponse) void { fn deinit(this: *NodeHTTPResponse) void { bun.debugAssert(!this.body_read_ref.has); - bun.debugAssert(!this.js_ref.has); + bun.debugAssert(!this.poll_ref.has); bun.debugAssert(!this.flags.is_request_pending); bun.debugAssert(this.flags.socket_closed or this.flags.request_has_completed); this.buffered_request_body_data_during_pause.deinit(bun.default_allocator); - this.js_ref.unref(jsc.VirtualMachine.get()); + this.poll_ref.unref(jsc.VirtualMachine.get()); this.body_read_ref.unref(jsc.VirtualMachine.get()); this.promise.deinit(); diff --git a/src/bun.js/api/server/ServerWebSocket.zig b/src/bun.js/api/server/ServerWebSocket.zig index 57b50867c0..ffd5bec449 100644 --- a/src/bun.js/api/server/ServerWebSocket.zig +++ b/src/bun.js/api/server/ServerWebSocket.zig @@ -1,9 +1,9 @@ const ServerWebSocket = @This(); -handler: *WebSocketServer.Handler, -this_value: JSValue = .zero, -flags: Flags = .{}, -signal: ?*bun.webcore.AbortSignal = null, +#handler: *WebSocketServer.Handler, +#this_value: jsc.JSRef = .empty(), +#flags: Flags = .{}, +#signal: ?*bun.webcore.AbortSignal = null, // We pack the per-socket data into this struct below const Flags = packed struct(u64) { @@ -26,7 +26,7 @@ const Flags = packed struct(u64) { }; inline fn websocket(this: *const ServerWebSocket) uws.AnyWebSocket { - return this.flags.websocket(); + return this.#flags.websocket(); } pub const js = jsc.Codegen.JSServerWebSocket; @@ -34,10 +34,25 @@ pub const toJS = js.toJS; pub const fromJS = js.fromJS; pub const fromJSDirect = js.fromJSDirect; -pub const new = bun.TrivialNew(ServerWebSocket); +const new = bun.TrivialNew(ServerWebSocket); + +/// Initialize a ServerWebSocket with the given handler, data value, and signal. +/// The signal will not be ref'd inside the ServerWebSocket init function, but will unref itself when the ServerWebSocket is destroyed. +pub fn init(handler: *WebSocketServer.Handler, data_value: jsc.JSValue, signal: ?*bun.webcore.AbortSignal) *ServerWebSocket { + const globalObject = handler.globalObject; + const this = ServerWebSocket.new(.{ + .#handler = handler, + .#signal = signal, + }); + // Get a strong ref and downgrade when terminating/close and GC will be able to collect the newly created value + const this_value = this.toJS(globalObject); + this.#this_value = .initStrong(this_value, globalObject); + js.dataSetCached(this_value, globalObject, data_value); + return this; +} pub fn memoryCost(this: *const ServerWebSocket) usize { - if (this.flags.closed) { + if (this.#flags.closed) { return @sizeOf(ServerWebSocket); } return this.websocket().memoryCost() + @sizeOf(ServerWebSocket); @@ -48,15 +63,12 @@ const log = Output.scoped(.WebSocketServer, .visible); pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void { log("OnOpen", .{}); - this.flags.packed_websocket_ptr = @truncate(@intFromPtr(ws.raw())); - this.flags.closed = false; - this.flags.ssl = ws == .ssl; + this.#flags.packed_websocket_ptr = @truncate(@intFromPtr(ws.raw())); + this.#flags.closed = false; + this.#flags.ssl = ws == .ssl; - // the this value is initially set to whatever the user passed in - const value_to_cache = this.this_value; - - var handler = this.handler; - const vm = this.handler.vm; + var handler = this.#handler; + const vm = this.#handler.vm; handler.active_connections +|= 1; const globalObject = handler.globalObject; const onOpenHandler = handler.onOpen; @@ -66,25 +78,19 @@ pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void { return; } - this.this_value = .zero; - this.flags.opened = false; - if (value_to_cache != .zero) { - const current_this = this.getThisValue(); - js.dataSetCached(current_this, globalObject, value_to_cache); - } + this.#flags.opened = false; if (onOpenHandler.isEmptyOrUndefinedOrNull()) { - if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| { + if (bun.take(&this.#handler.onBeforeOpen)) |on_before_open| { // Only create the "this" value if needed. - const this_value = this.getThisValue(); - on_before_open.callback(on_before_open.ctx, this_value, ws.raw()); + on_before_open.callback(on_before_open.ctx, this.#this_value.tryGet() orelse .js_undefined, ws.raw()); } return; } - const this_value = this.getThisValue(); + const this_value = this.#this_value.tryGet() orelse .js_undefined; var args = [_]JSValue{this_value}; - if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| { + if (bun.take(&this.#handler.onBeforeOpen)) |on_before_open| { on_before_open.callback(on_before_open.ctx, this_value, ws.raw()); } @@ -99,12 +105,12 @@ pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void { }; ws.cork(&corker, Corker.run); const result = corker.result; - this.flags.opened = true; + this.#flags.opened = true; if (result.toError()) |err_value| { log("onOpen exception", .{}); - if (!this.flags.closed) { - this.flags.closed = true; + if (!this.#flags.closed) { + this.#flags.closed = true; // we un-gracefully close the connection if there was an exception // we don't want any event handlers to fire after this for anything other than error() // https://github.com/oven-sh/bun/issues/1480 @@ -117,16 +123,6 @@ pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void { } } -pub fn getThisValue(this: *ServerWebSocket) JSValue { - var this_value = this.this_value; - if (this_value == .zero) { - this_value = this.toJS(this.handler.globalObject); - this_value.protect(); - this.this_value = this_value; - } - return this_value; -} - pub fn onMessage( this: *ServerWebSocket, ws: uws.AnyWebSocket, @@ -137,11 +133,11 @@ pub fn onMessage( @intFromEnum(opcode), message, }); - const onMessageHandler = this.handler.onMessage; + const onMessageHandler = this.#handler.onMessage; if (onMessageHandler.isEmptyOrUndefinedOrNull()) return; - var globalObject = this.handler.globalObject; + var globalObject = this.#handler.globalObject; // This is the start of a task. - const vm = this.handler.vm; + const vm = this.#handler.vm; if (vm.isShuttingDown()) { log("onMessage called after script execution", .{}); ws.close(); @@ -153,7 +149,7 @@ pub fn onMessage( defer loop.exit(); const arguments = [_]JSValue{ - this.getThisValue(), + this.#this_value.tryGet() orelse .js_undefined, switch (opcode) { .text => bun.String.createUTF8ForJS(globalObject, message) catch .zero, // TODO: properly propagate exception upwards .binary => this.binaryToJS(globalObject, message) catch .zero, // TODO: properly propagate exception upwards @@ -173,7 +169,7 @@ pub fn onMessage( if (result.isEmptyOrUndefinedOrNull()) return; if (result.toError()) |err_value| { - this.handler.runErrorCallback(vm, globalObject, err_value); + this.#handler.runErrorCallback(vm, globalObject, err_value); return; } @@ -190,13 +186,13 @@ pub fn onMessage( } pub inline fn isClosed(this: *const ServerWebSocket) bool { - return this.flags.closed; + return this.#flags.closed; } pub fn onDrain(this: *ServerWebSocket, _: uws.AnyWebSocket) void { log("onDrain", .{}); - const handler = this.handler; + const handler = this.#handler; const vm = handler.vm; if (this.isClosed() or vm.isShuttingDown()) return; @@ -205,7 +201,7 @@ pub fn onDrain(this: *ServerWebSocket, _: uws.AnyWebSocket) void { const globalObject = handler.globalObject; var corker = Corker{ - .args = &[_]jsc.JSValue{this.getThisValue()}, + .args = &[_]jsc.JSValue{this.#this_value.tryGet() orelse .js_undefined}, .globalObject = globalObject, .callback = handler.onDrain, }; @@ -222,7 +218,7 @@ pub fn onDrain(this: *ServerWebSocket, _: uws.AnyWebSocket) void { } fn binaryToJS(this: *const ServerWebSocket, globalThis: *jsc.JSGlobalObject, data: []const u8) bun.JSError!jsc.JSValue { - return switch (this.flags.binary_type) { + return switch (this.#flags.binary_type) { .Buffer => jsc.ArrayBuffer.createBuffer( globalThis, data, @@ -243,7 +239,7 @@ fn binaryToJS(this: *const ServerWebSocket, globalThis: *jsc.JSGlobalObject, dat pub fn onPing(this: *ServerWebSocket, _: uws.AnyWebSocket, data: []const u8) void { log("onPing: {s}", .{data}); - const handler = this.handler; + const handler = this.#handler; var cb = handler.onPing; const vm = handler.vm; if (cb.isEmptyOrUndefinedOrNull() or vm.isShuttingDown()) return; @@ -257,7 +253,7 @@ pub fn onPing(this: *ServerWebSocket, _: uws.AnyWebSocket, data: []const u8) voi _ = cb.call( globalThis, .js_undefined, - &[_]jsc.JSValue{ this.getThisValue(), this.binaryToJS(globalThis, data) catch .zero }, // TODO: properly propagate exception upwards + &[_]jsc.JSValue{ this.#this_value.tryGet() orelse .js_undefined, this.binaryToJS(globalThis, data) catch .zero }, // TODO: properly propagate exception upwards ) catch |e| { const err = globalThis.takeException(e); log("onPing error", .{}); @@ -268,7 +264,7 @@ pub fn onPing(this: *ServerWebSocket, _: uws.AnyWebSocket, data: []const u8) voi pub fn onPong(this: *ServerWebSocket, _: uws.AnyWebSocket, data: []const u8) void { log("onPong: {s}", .{data}); - const handler = this.handler; + const handler = this.#handler; var cb = handler.onPong; if (cb.isEmptyOrUndefinedOrNull()) return; @@ -285,7 +281,7 @@ pub fn onPong(this: *ServerWebSocket, _: uws.AnyWebSocket, data: []const u8) voi _ = cb.call( globalThis, .js_undefined, - &[_]jsc.JSValue{ this.getThisValue(), this.binaryToJS(globalThis, data) catch .zero }, // TODO: properly propagate exception upwards + &[_]jsc.JSValue{ this.#this_value.tryGet() orelse .js_undefined, this.binaryToJS(globalThis, data) catch .zero }, // TODO: properly propagate exception upwards ) catch |e| { const err = globalThis.takeException(e); log("onPong error", .{}); @@ -295,26 +291,27 @@ pub fn onPong(this: *ServerWebSocket, _: uws.AnyWebSocket, data: []const u8) voi pub fn onClose(this: *ServerWebSocket, _: uws.AnyWebSocket, code: i32, message: []const u8) void { log("onClose", .{}); - var handler = this.handler; + // TODO: Can this called inside finalize? + var handler = this.#handler; const was_closed = this.isClosed(); - this.flags.closed = true; + this.#flags.closed = true; defer { if (!was_closed) { handler.active_connections -|= 1; } } - const signal = this.signal; - this.signal = null; - - if (js.socketGetCached(this.getThisValue())) |socket| { - Bun__callNodeHTTPServerSocketOnClose(socket); - } + const signal = this.#signal; + this.#signal = null; defer { if (signal) |sig| { sig.pendingActivityUnref(); sig.unref(); } + + if (this.#this_value.isNotEmpty()) { + this.#this_value.downgrade(); + } } const vm = handler.vm; @@ -337,14 +334,14 @@ pub fn onClose(this: *ServerWebSocket, _: uws.AnyWebSocket, code: i32, message: const message_js = bun.String.createUTF8ForJS(globalObject, message) catch |e| { const err = globalObject.takeException(e); - log("onClose error", .{}); + log("onClose error (message) {}", .{this.#this_value.isNotEmpty()}); handler.runErrorCallback(vm, globalObject, err); return; }; - _ = handler.onClose.call(globalObject, .js_undefined, &[_]jsc.JSValue{ this.getThisValue(), JSValue.jsNumber(code), message_js }) catch |e| { + _ = handler.onClose.call(globalObject, .js_undefined, &[_]jsc.JSValue{ this.#this_value.tryGet() orelse .js_undefined, JSValue.jsNumber(code), message_js }) catch |e| { const err = globalObject.takeException(e); - log("onClose error", .{}); + log("onClose error {}", .{this.#this_value.isNotEmpty()}); handler.runErrorCallback(vm, globalObject, err); return; }; @@ -358,8 +355,6 @@ pub fn onClose(this: *ServerWebSocket, _: uws.AnyWebSocket, code: i32, message: sig.signal(handler.globalObject, .ConnectionClosed); } } - - this.this_value.unprotect(); } pub fn behavior(comptime ServerType: type, comptime ssl: bool, opts: uws.WebSocketBehavior) uws.WebSocketBehavior { @@ -372,6 +367,12 @@ pub fn constructor(globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSE pub fn finalize(this: *ServerWebSocket) void { log("finalize", .{}); + this.#this_value.finalize(); + if (this.#signal) |signal| { + this.#signal = null; + signal.pendingActivityUnref(); + signal.unref(); + } bun.destroy(this); } @@ -387,11 +388,11 @@ pub fn publish( return globalThis.throw("publish requires at least 1 argument", .{}); } - const app = this.handler.app orelse { + const app = this.#handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const flags = this.handler.flags; + const flags = this.#handler.flags; const ssl = flags.ssl; const publish_to_self = flags.publish_to_self; @@ -474,11 +475,11 @@ pub fn publishText( return globalThis.throw("publish requires at least 1 argument", .{}); } - const app = this.handler.app orelse { + const app = this.#handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const flags = this.handler.flags; + const flags = this.#handler.flags; const ssl = flags.ssl; const publish_to_self = flags.publish_to_self; @@ -540,11 +541,11 @@ pub fn publishBinary( return globalThis.throw("publishBinary requires at least 1 argument", .{}); } - const app = this.handler.app orelse { + const app = this.#handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const flags = this.handler.flags; + const flags = this.#handler.flags; const ssl = flags.ssl; const publish_to_self = flags.publish_to_self; const topic_value = args.ptr[0]; @@ -595,11 +596,11 @@ pub fn publishBinaryWithoutTypeChecks( topic_str: *jsc.JSString, array: *jsc.JSUint8Array, ) bun.JSError!jsc.JSValue { - const app = this.handler.app orelse { + const app = this.#handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const flags = this.handler.flags; + const flags = this.#handler.flags; const ssl = flags.ssl; const publish_to_self = flags.publish_to_self; @@ -634,11 +635,11 @@ pub fn publishTextWithoutTypeChecks( topic_str: *jsc.JSString, str: *jsc.JSString, ) bun.JSError!jsc.JSValue { - const app = this.handler.app orelse { + const app = this.#handler.app orelse { log("publish() closed", .{}); return JSValue.jsNumber(0); }; - const flags = this.handler.flags; + const flags = this.#handler.flags; const ssl = flags.ssl; const publish_to_self = flags.publish_to_self; @@ -674,12 +675,9 @@ pub fn cork( this: *ServerWebSocket, globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame, - // Since we're passing the `this` value to the cork function, we need to - // make sure the `this` value is up to date. this_value: jsc.JSValue, ) bun.JSError!JSValue { const args = callframe.arguments_old(1); - this.this_value = this_value; if (args.len < 1) { return globalThis.throwNotEnoughArguments("cork", 1, 0); @@ -1040,10 +1038,13 @@ inline fn sendPing( } pub fn getData( - _: *ServerWebSocket, + this: *ServerWebSocket, _: *jsc.JSGlobalObject, ) JSValue { log("getData()", .{}); + if (this.#this_value.tryGet()) |this_value| { + return js.dataGetCached(this_value) orelse .js_undefined; + } return .js_undefined; } @@ -1053,7 +1054,9 @@ pub fn setData( value: jsc.JSValue, ) void { log("setData()", .{}); - js.dataSetCached(this.this_value, globalObject, value); + if (this.#this_value.tryGet()) |this_value| { + js.dataSetCached(this_value, globalObject, value); + } } pub fn getReadyState( @@ -1074,11 +1077,10 @@ pub fn close( globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame, // Since close() can lead to the close() callback being called, let's always ensure the `this` value is up to date. - this_value: jsc.JSValue, + _: jsc.JSValue, ) bun.JSError!JSValue { const args = callframe.arguments_old(2); log("close()", .{}); - this.this_value = this_value; if (this.isClosed()) { return .js_undefined; @@ -1104,31 +1106,24 @@ pub fn close( defer message_value.deinit(); - this.flags.closed = true; + this.#flags.closed = true; this.websocket().end(code, message_value.slice()); return .js_undefined; } pub fn terminate( this: *ServerWebSocket, - globalThis: *jsc.JSGlobalObject, - callframe: *jsc.CallFrame, - // Since terminate() can lead to close() being called, let's always ensure the `this` value is up to date. - this_value: jsc.JSValue, + _: *jsc.JSGlobalObject, + _: *jsc.CallFrame, + _: jsc.JSValue, ) bun.JSError!JSValue { - _ = globalThis; - const args = callframe.arguments_old(2); - _ = args; log("terminate()", .{}); - this.this_value = this_value; - if (this.isClosed()) { return .js_undefined; } - this.flags.closed = true; - this.this_value.unprotect(); + this.#flags.closed = true; this.websocket().close(); return .js_undefined; @@ -1140,7 +1135,7 @@ pub fn getBinaryType( ) JSValue { log("getBinaryType()", .{}); - return switch (this.flags.binary_type) { + return switch (this.#flags.binary_type) { .Uint8Array => bun.String.static("uint8array").toJS(globalThis), .Buffer => bun.String.static("nodebuffer").toJS(globalThis), .ArrayBuffer => bun.String.static("arraybuffer").toJS(globalThis), @@ -1156,7 +1151,7 @@ pub fn setBinaryType(this: *ServerWebSocket, globalThis: *jsc.JSGlobalObject, va // some other value which we don't support .Float64Array) { .ArrayBuffer, .Buffer, .Uint8Array => |val| { - this.flags.binary_type = val; + this.#flags.binary_type = val; return; }, else => { @@ -1295,8 +1290,6 @@ const Corker = struct { } }; -extern "c" fn Bun__callNodeHTTPServerSocketOnClose(jsc.JSValue) void; - const string = []const u8; const std = @import("std"); diff --git a/src/bun.js/bindings/NodeHTTP.cpp b/src/bun.js/bindings/NodeHTTP.cpp index 29eca691c0..70c60c2e5e 100644 --- a/src/bun.js/bindings/NodeHTTP.cpp +++ b/src/bun.js/bindings/NodeHTTP.cpp @@ -22,11 +22,32 @@ #include #include "JSSocketAddressDTO.h" +extern "C" { +struct us_socket_stream_buffer_t { + char* list_ptr = nullptr; + size_t list_cap = 0; + size_t listLen = 0; + size_t total_bytes_written = 0; + size_t cursor = 0; + + size_t bufferedSize() const + { + return listLen - cursor; + } + size_t totalBytesWritten() const + { + return total_bytes_written; + } +}; +} + extern "C" uint64_t uws_res_get_remote_address_info(void* res, const char** dest, int* port, bool* is_ipv6); extern "C" uint64_t uws_res_get_local_address_info(void* res, const char** dest, int* port, bool* is_ipv6); extern "C" void Bun__NodeHTTPResponse_setClosed(void* zigResponse); extern "C" void Bun__NodeHTTPResponse_onClose(void* zigResponse, JSC::EncodedJSValue jsValue); +extern "C" EncodedJSValue us_socket_buffered_js_write(void* socket, bool is_ssl, bool ended, us_socket_stream_buffer_t* streamBuffer, JSC::JSGlobalObject* globalObject, JSC::EncodedJSValue data, JSC::EncodedJSValue encoding); +extern "C" void us_socket_free_stream_buffer(us_socket_stream_buffer_t* streamBuffer); namespace Bun { using namespace JSC; @@ -38,9 +59,16 @@ JSC_DEFINE_CUSTOM_SETTER(noOpSetter, (JSGlobalObject * globalObject, JSC::Encode } JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterOnClose); +JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterOnDrain); JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterClosed); JSC_DECLARE_CUSTOM_SETTER(jsNodeHttpServerSocketSetterOnClose); +JSC_DECLARE_CUSTOM_SETTER(jsNodeHttpServerSocketSetterOnDrain); +JSC_DECLARE_CUSTOM_SETTER(jsNodeHttpServerSocketSetterOnData); +JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterOnData); +JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterBytesWritten); JSC_DECLARE_HOST_FUNCTION(jsFunctionNodeHTTPServerSocketClose); +JSC_DECLARE_HOST_FUNCTION(jsFunctionNodeHTTPServerSocketWrite); +JSC_DECLARE_HOST_FUNCTION(jsFunctionNodeHTTPServerSocketEnd); JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterResponse); JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterRemoteAddress); JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterLocalAddress); @@ -52,12 +80,17 @@ JSC_DECLARE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterIsSecureEstablished); // Create a static hash table of values containing an onclose DOMAttributeGetterSetter and a close function static const HashTableValue JSNodeHTTPServerSocketPrototypeTableValues[] = { { "onclose"_s, static_cast(PropertyAttribute::CustomAccessor), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterOnClose, jsNodeHttpServerSocketSetterOnClose } }, + { "ondrain"_s, static_cast(PropertyAttribute::CustomAccessor), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterOnDrain, jsNodeHttpServerSocketSetterOnDrain } }, + { "ondata"_s, static_cast(PropertyAttribute::CustomAccessor), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterOnData, jsNodeHttpServerSocketSetterOnData } }, + { "bytesWritten"_s, static_cast(PropertyAttribute::CustomAccessor), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterBytesWritten, noOpSetter } }, { "closed"_s, static_cast(PropertyAttribute::CustomAccessor | PropertyAttribute::ReadOnly), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterClosed, noOpSetter } }, { "response"_s, static_cast(PropertyAttribute::CustomAccessor | PropertyAttribute::ReadOnly), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterResponse, noOpSetter } }, { "duplex"_s, static_cast(PropertyAttribute::CustomAccessor), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterDuplex, jsNodeHttpServerSocketSetterDuplex } }, { "remoteAddress"_s, static_cast(PropertyAttribute::CustomAccessor | PropertyAttribute::ReadOnly), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterRemoteAddress, noOpSetter } }, { "localAddress"_s, static_cast(PropertyAttribute::CustomAccessor | PropertyAttribute::ReadOnly), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterLocalAddress, noOpSetter } }, { "close"_s, static_cast(PropertyAttribute::Function | PropertyAttribute::DontEnum), NoIntrinsic, { HashTableValue::NativeFunctionType, jsFunctionNodeHTTPServerSocketClose, 0 } }, + { "write"_s, static_cast(PropertyAttribute::Function | PropertyAttribute::DontEnum), NoIntrinsic, { HashTableValue::NativeFunctionType, jsFunctionNodeHTTPServerSocketWrite, 2 } }, + { "end"_s, static_cast(PropertyAttribute::Function | PropertyAttribute::DontEnum), NoIntrinsic, { HashTableValue::NativeFunctionType, jsFunctionNodeHTTPServerSocketEnd, 0 } }, { "secureEstablished"_s, static_cast(PropertyAttribute::CustomAccessor | PropertyAttribute::ReadOnly), NoIntrinsic, { HashTableValue::GetterSetterType, jsNodeHttpServerSocketGetterIsSecureEstablished, noOpSetter } }, }; @@ -102,6 +135,12 @@ private: class JSNodeHTTPServerSocket : public JSC::JSDestructibleObject { public: using Base = JSC::JSDestructibleObject; + us_socket_stream_buffer_t streamBuffer = {}; + us_socket_t* socket = nullptr; + unsigned is_ssl : 1 = 0; + unsigned ended : 1 = 0; + JSC::Strong strongThis = {}; + static JSNodeHTTPServerSocket* create(JSC::VM& vm, JSC::Structure* structure, us_socket_t* socket, bool is_ssl, WebCore::JSNodeHTTPResponse* response) { auto* object = new (JSC::allocateCell(vm)) JSNodeHTTPServerSocket(vm, structure, socket, is_ssl, response); @@ -161,6 +200,7 @@ public: clearSocketData(socket); } } + us_socket_free_stream_buffer(&streamBuffer); } JSNodeHTTPServerSocket(JSC::VM& vm, JSC::Structure* structure, us_socket_t* socket, bool is_ssl, WebCore::JSNodeHTTPResponse* response) @@ -172,15 +212,13 @@ public: } mutable WriteBarrier functionToCallOnClose; + mutable WriteBarrier functionToCallOnDrain; + mutable WriteBarrier functionToCallOnData; mutable WriteBarrier currentResponseObject; mutable WriteBarrier m_remoteAddress; mutable WriteBarrier m_localAddress; mutable WriteBarrier m_duplex; - unsigned is_ssl : 1; - us_socket_t* socket; - JSC::Strong strongThis = {}; - DECLARE_INFO; DECLARE_VISIT_CHILDREN; @@ -206,6 +244,7 @@ public: void onClose() { + this->socket = nullptr; if (auto* res = this->currentResponseObject.get(); res != nullptr && res->m_ctx != nullptr) { Bun__NodeHTTPResponse_setClosed(res->m_ctx); @@ -257,6 +296,107 @@ public: } } + void onDrain() + { + // This function can be called during GC! + Zig::GlobalObject* globalObject = static_cast(this->globalObject()); + if (!functionToCallOnDrain) { + return; + } + + auto bufferedSize = this->streamBuffer.bufferedSize(); + if (bufferedSize > 0) { + + auto* globalObject = defaultGlobalObject(this->globalObject()); + auto scope = DECLARE_CATCH_SCOPE(globalObject->vm()); + us_socket_buffered_js_write(this->socket, this->is_ssl, this->ended, &this->streamBuffer, globalObject, JSValue::encode(JSC::jsUndefined()), JSValue::encode(JSC::jsUndefined())); + if (scope.exception()) { + globalObject->reportUncaughtExceptionAtEventLoop(globalObject, scope.exception()); + return; + } + bufferedSize = this->streamBuffer.bufferedSize(); + + if (bufferedSize > 0) { + // need to drain more + return; + } + } + WebCore::ScriptExecutionContext* scriptExecutionContext = globalObject->scriptExecutionContext(); + + if (scriptExecutionContext) { + scriptExecutionContext->postTask([self = this](ScriptExecutionContext& context) { + WTF::NakedPtr exception; + auto* globalObject = defaultGlobalObject(context.globalObject()); + auto* thisObject = self; + auto* callbackObject = thisObject->functionToCallOnDrain.get(); + if (!callbackObject) { + return; + } + auto callData = JSC::getCallData(callbackObject); + MarkedArgumentBuffer args; + EnsureStillAliveScope ensureStillAlive(self); + + if (globalObject->scriptExecutionStatus(globalObject, thisObject) == ScriptExecutionStatus::Running) { + profiledCall(globalObject, JSC::ProfilingReason::API, callbackObject, callData, thisObject, args, exception); + + if (auto* ptr = exception.get()) { + exception.clear(); + globalObject->reportUncaughtExceptionAtEventLoop(globalObject, ptr); + } + } + }); + } + } + + void + onData(const char* data, int length, bool last) + { + // This function can be called during GC! + Zig::GlobalObject* globalObject = static_cast(this->globalObject()); + if (!functionToCallOnData) { + return; + } + + WebCore::ScriptExecutionContext* scriptExecutionContext = globalObject->scriptExecutionContext(); + + if (scriptExecutionContext) { + auto scope = DECLARE_CATCH_SCOPE(globalObject->vm()); + JSC::JSUint8Array* buffer = WebCore::createBuffer(globalObject, std::span(reinterpret_cast(data), length)); + auto chunk = JSC::JSValue(buffer); + if (scope.exception()) { + globalObject->reportUncaughtExceptionAtEventLoop(globalObject, scope.exception()); + return; + } + gcProtect(chunk); + scriptExecutionContext->postTask([self = this, chunk = chunk, last = last](ScriptExecutionContext& context) { + WTF::NakedPtr exception; + auto* globalObject = defaultGlobalObject(context.globalObject()); + auto* thisObject = self; + auto* callbackObject = thisObject->functionToCallOnData.get(); + EnsureStillAliveScope ensureChunkStillAlive(chunk); + gcUnprotect(chunk); + if (!callbackObject) { + return; + } + + auto callData = JSC::getCallData(callbackObject); + MarkedArgumentBuffer args; + args.append(chunk); + args.append(JSC::jsBoolean(last)); + EnsureStillAliveScope ensureStillAlive(self); + + if (globalObject->scriptExecutionStatus(globalObject, thisObject) == ScriptExecutionStatus::Running) { + profiledCall(globalObject, JSC::ProfilingReason::API, callbackObject, callData, thisObject, args, exception); + + if (auto* ptr = exception.get()) { + exception.clear(); + globalObject->reportUncaughtExceptionAtEventLoop(globalObject, ptr); + } + } + }); + } + } + static Structure* createStructure(JSC::VM& vm, JSC::JSGlobalObject* globalObject) { auto* structure = JSC::Structure::create(vm, globalObject, globalObject->objectPrototype(), JSC::TypeInfo(JSC::ObjectType, StructureFlags), JSNodeHTTPServerSocketPrototype::info()); @@ -284,6 +424,37 @@ JSC_DEFINE_HOST_FUNCTION(jsFunctionNodeHTTPServerSocketClose, (JSC::JSGlobalObje return JSValue::encode(JSC::jsUndefined()); } +JSC_DEFINE_HOST_FUNCTION(jsFunctionNodeHTTPServerSocketWrite, (JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame)) +{ + auto* thisObject = jsDynamicCast(callFrame->thisValue()); + if (!thisObject) [[unlikely]] { + return JSValue::encode(JSC::jsNumber(0)); + } + if (thisObject->isClosed() || thisObject->ended) { + return JSValue::encode(JSC::jsNumber(0)); + } + + return us_socket_buffered_js_write(thisObject->socket, thisObject->is_ssl, thisObject->ended, &thisObject->streamBuffer, globalObject, JSValue::encode(callFrame->argument(0)), JSValue::encode(callFrame->argument(1))); +} + +JSC_DEFINE_HOST_FUNCTION(jsFunctionNodeHTTPServerSocketEnd, (JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame)) +{ + auto* thisObject = jsDynamicCast(callFrame->thisValue()); + if (!thisObject) [[unlikely]] { + return JSValue::encode(JSC::jsUndefined()); + } + if (thisObject->isClosed()) { + return JSValue::encode(JSC::jsUndefined()); + } + + thisObject->ended = true; + auto bufferedSize = thisObject->streamBuffer.bufferedSize(); + if (bufferedSize == 0) { + return us_socket_buffered_js_write(thisObject->socket, thisObject->is_ssl, thisObject->ended, &thisObject->streamBuffer, globalObject, JSValue::encode(JSC::jsUndefined()), JSValue::encode(JSC::jsUndefined())); + } + return JSValue::encode(JSC::jsUndefined()); +} + JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterIsSecureEstablished, (JSC::JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, JSC::PropertyName)) { auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); @@ -390,6 +561,66 @@ JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterOnClose, (JSC::JSGlobalObje return JSValue::encode(JSC::jsUndefined()); } +JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterOnDrain, (JSC::JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, JSC::PropertyName)) +{ + auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); + + if (thisObject->functionToCallOnDrain) { + return JSValue::encode(thisObject->functionToCallOnDrain.get()); + } + + return JSValue::encode(JSC::jsUndefined()); +} +JSC_DEFINE_CUSTOM_SETTER(jsNodeHttpServerSocketSetterOnDrain, (JSC::JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, JSC::EncodedJSValue encodedValue, JSC::PropertyName propertyName)) +{ + auto& vm = globalObject->vm(); + auto scope = DECLARE_THROW_SCOPE(vm); + + auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); + JSValue value = JSC::JSValue::decode(encodedValue); + + if (value.isUndefined() || value.isNull()) { + thisObject->functionToCallOnDrain.clear(); + return true; + } + + if (!value.isCallable()) { + return false; + } + + thisObject->functionToCallOnDrain.set(vm, thisObject, value.getObject()); + return true; +} +JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterOnData, (JSC::JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, JSC::PropertyName)) +{ + auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); + + if (thisObject->functionToCallOnData) { + return JSValue::encode(thisObject->functionToCallOnData.get()); + } + + return JSValue::encode(JSC::jsUndefined()); +} +JSC_DEFINE_CUSTOM_SETTER(jsNodeHttpServerSocketSetterOnData, (JSC::JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, JSC::EncodedJSValue encodedValue, JSC::PropertyName propertyName)) +{ + auto& vm = globalObject->vm(); + auto scope = DECLARE_THROW_SCOPE(vm); + + auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); + JSValue value = JSC::JSValue::decode(encodedValue); + + if (value.isUndefined() || value.isNull()) { + thisObject->functionToCallOnData.clear(); + return true; + } + + if (!value.isCallable()) { + return false; + } + + thisObject->functionToCallOnData.set(vm, thisObject, value.getObject()); + return true; +} JSC_DEFINE_CUSTOM_SETTER(jsNodeHttpServerSocketSetterOnClose, (JSC::JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, JSC::EncodedJSValue encodedValue, JSC::PropertyName propertyName)) { auto& vm = globalObject->vm(); @@ -417,6 +648,12 @@ JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterClosed, (JSGlobalObject * g return JSValue::encode(JSC::jsBoolean(thisObject->isClosed())); } +JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterBytesWritten, (JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, PropertyName propertyName)) +{ + auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); + return JSValue::encode(JSC::jsNumber(thisObject->streamBuffer.totalBytesWritten())); +} + JSC_DEFINE_CUSTOM_GETTER(jsNodeHttpServerSocketGetterResponse, (JSGlobalObject * globalObject, JSC::EncodedJSValue thisValue, PropertyName propertyName)) { auto* thisObject = jsCast(JSC::JSValue::decode(thisValue)); @@ -436,6 +673,8 @@ void JSNodeHTTPServerSocket::visitChildrenImpl(JSCell* cell, Visitor& visitor) visitor.append(fn->currentResponseObject); visitor.append(fn->functionToCallOnClose); + visitor.append(fn->functionToCallOnDrain); + visitor.append(fn->functionToCallOnData); visitor.append(fn->m_remoteAddress); visitor.append(fn->m_localAddress); visitor.append(fn->m_duplex); @@ -498,31 +737,45 @@ extern "C" void Bun__setNodeHTTPServerSocketUsSocketValue(EncodedJSValue thisVal response->socket = socket; } -extern "C" void Bun__callNodeHTTPServerSocketOnClose(EncodedJSValue thisValue) -{ - auto* response = jsCast(JSValue::decode(thisValue)); - response->onClose(); -} - -extern "C" JSC::EncodedJSValue Bun__createNodeHTTPServerSocket(bool isSSL, us_socket_t* us_socket, Zig::GlobalObject* globalObject) +extern "C" JSC::EncodedJSValue Bun__createNodeHTTPServerSocketForClientError(bool isSSL, us_socket_t* us_socket, Zig::GlobalObject* globalObject) { auto& vm = globalObject->vm(); auto scope = DECLARE_THROW_SCOPE(vm); RETURN_IF_EXCEPTION(scope, {}); + if (isSSL) { + uWS::HttpResponse* response = reinterpret_cast*>(us_socket); + auto* currentSocketDataPtr = reinterpret_cast(response->getHttpResponseData()->socketData); + if (currentSocketDataPtr) { + return JSValue::encode(currentSocketDataPtr); + } + } else { + uWS::HttpResponse* response = reinterpret_cast*>(us_socket); + auto* currentSocketDataPtr = reinterpret_cast(response->getHttpResponseData()->socketData); + if (currentSocketDataPtr) { + return JSValue::encode(currentSocketDataPtr); + } + } // socket without response because is not valid http JSNodeHTTPServerSocket* socket = JSNodeHTTPServerSocket::create( vm, globalObject->m_JSNodeHTTPServerSocketStructure.getInitializedOnMainThread(globalObject), us_socket, isSSL, nullptr); - + if (isSSL) { + uWS::HttpResponse* response = reinterpret_cast*>(us_socket); + response->getHttpResponseData()->socketData = socket; + } else { + uWS::HttpResponse* response = reinterpret_cast*>(us_socket); + response->getHttpResponseData()->socketData = socket; + } RETURN_IF_EXCEPTION(scope, {}); if (socket) { socket->strongThis.set(vm, socket); return JSValue::encode(socket); } + return JSValue::encode(JSC::jsNull()); } @@ -873,21 +1126,31 @@ static EncodedJSValue assignHeadersFromUWebSockets(uWS::HttpRequest* request, JS } template -static void assignOnCloseFunction(uWS::TemplatedApp* app) +static void assignOnNodeJSCompat(uWS::TemplatedApp* app) { - app->setOnClose([](void* socketData, int is_ssl, struct us_socket_t* rawSocket) -> void { + app->setOnSocketClosed([](void* socketData, int is_ssl, struct us_socket_t* rawSocket) -> void { auto* socket = reinterpret_cast(socketData); ASSERT(rawSocket == socket->socket || socket->socket == nullptr); socket->onClose(); }); + app->setOnSocketDrain([](void* socketData, int is_ssl, struct us_socket_t* rawSocket) -> void { + auto* socket = reinterpret_cast(socketData); + ASSERT(rawSocket == socket->socket || socket->socket == nullptr); + socket->onDrain(); + }); + app->setOnSocketData([](void* socketData, int is_ssl, struct us_socket_t* rawSocket, const char* data, int length, bool last) -> void { + auto* socket = reinterpret_cast(socketData); + ASSERT(rawSocket == socket->socket || socket->socket == nullptr); + socket->onData(data, length, last); + }); } -extern "C" void NodeHTTP_assignOnCloseFunction(bool is_ssl, void* uws_app) +extern "C" void NodeHTTP_assignOnNodeJSCompat(bool is_ssl, void* uws_app) { if (is_ssl) { - assignOnCloseFunction(reinterpret_cast*>(uws_app)); + assignOnNodeJSCompat(reinterpret_cast*>(uws_app)); } else { - assignOnCloseFunction(reinterpret_cast*>(uws_app)); + assignOnNodeJSCompat(reinterpret_cast*>(uws_app)); } } @@ -1481,6 +1744,7 @@ JSValue createNodeHTTPInternalBinding(Zig::GlobalObject* globalObject) obj->putDirectNativeFunction( vm, globalObject, JSC::PropertyName(JSC::Identifier::fromString(vm, "drainMicrotasks"_s)), 0, Bun__drainMicrotasksFromJS, ImplementationVisibility::Public, Intrinsic::NoIntrinsic, 0); + return obj; } diff --git a/src/deps/libuwsockets.cpp b/src/deps/libuwsockets.cpp index 07bcff0e42..64372ffc58 100644 --- a/src/deps/libuwsockets.cpp +++ b/src/deps/libuwsockets.cpp @@ -1624,7 +1624,7 @@ size_t uws_req_get_header(uws_req_t *res, const char *lower_case_header, uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; return uwsRes->template upgrade( - data ? std::move(data) : NULL, + data ? std::move(data) : nullptr, stringViewFromC(sec_web_socket_key, sec_web_socket_key_length), stringViewFromC(sec_web_socket_protocol, sec_web_socket_protocol_length), stringViewFromC(sec_web_socket_extensions, @@ -1634,7 +1634,7 @@ size_t uws_req_get_header(uws_req_t *res, const char *lower_case_header, uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; return uwsRes->template upgrade( - data ? std::move(data) : NULL, + data ? std::move(data) : nullptr, stringViewFromC(sec_web_socket_key, sec_web_socket_key_length), stringViewFromC(sec_web_socket_protocol, sec_web_socket_protocol_length), stringViewFromC(sec_web_socket_extensions, @@ -1811,6 +1811,26 @@ __attribute__((callback (corker, ctx))) } } + void *uws_res_get_socket_data(int ssl, uws_res_r res) { + if (ssl) { + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + return uwsRes->getSocketData(); + } else { + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + return uwsRes->getSocketData(); + } + } + + bool uws_res_is_connect_request(int ssl, uws_res_r res) + { + if (ssl) { + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + return uwsRes->isConnectRequest(); + } else { + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + return uwsRes->isConnectRequest(); + } + } void *uws_res_get_native_handle(int ssl, uws_res_r res) { if (ssl) diff --git a/src/deps/uws/Response.zig b/src/deps/uws/Response.zig index 18ec88e83d..757392d78a 100644 --- a/src/deps/uws/Response.zig +++ b/src/deps/uws/Response.zig @@ -34,6 +34,14 @@ pub fn NewResponse(ssl_flag: i32) type { return c.uws_res_try_end(ssl_flag, res.downcast(), data.ptr, data.len, total, close_); } + pub fn getSocketData(res: *Response) ?*anyopaque { + return c.uws_res_get_socket_data(ssl_flag, res.downcast()); + } + + pub fn isConnectRequest(res: *Response) bool { + return c.uws_res_is_connect_request(ssl_flag, res.downcast()); + } + pub fn flushHeaders(res: *Response) void { c.uws_res_flush_headers(ssl_flag, res.downcast()); } @@ -359,6 +367,11 @@ pub const AnyResponse = union(enum) { inline else => |resp| resp.downcast(), }; } + pub fn getSocketData(this: AnyResponse) ?*anyopaque { + return switch (this) { + inline else => |resp| resp.getSocketData(), + }; + } pub fn getRemoteSocketInfo(this: AnyResponse) ?SocketAddress { return switch (this) { inline else => |resp| resp.getRemoteSocketInfo(), @@ -554,6 +567,12 @@ pub const AnyResponse = union(enum) { } } + pub fn isConnectRequest(this: AnyResponse) bool { + return switch (this) { + inline else => |resp| resp.isConnectRequest(), + }; + } + pub fn endStream(this: AnyResponse, close_connection: bool) void { switch (this) { inline else => |resp| resp.endStream(close_connection), @@ -635,10 +654,12 @@ const c = struct { pub extern fn uws_res_write_mark(ssl: i32, res: *c.uws_res) void; pub extern fn us_socket_mark_needs_more_not_ssl(socket: ?*c.uws_res) void; pub extern fn uws_res_state(ssl: c_int, res: *const c.uws_res) State; + pub extern fn uws_res_is_connect_request(ssl: i32, res: *c.uws_res) bool; pub extern fn uws_res_get_remote_address_info(res: *c.uws_res, dest: *[*]const u8, port: *i32, is_ipv6: *bool) usize; pub extern fn uws_res_uncork(ssl: i32, res: *c.uws_res) void; pub extern fn uws_res_end(ssl: i32, res: *c.uws_res, data: [*c]const u8, length: usize, close_connection: bool) void; pub extern fn uws_res_flush_headers(ssl: i32, res: *c.uws_res) void; + pub extern fn uws_res_get_socket_data(ssl: i32, res: *c.uws_res) ?*uws.SocketData; pub extern fn uws_res_pause(ssl: i32, res: *c.uws_res) void; pub extern fn uws_res_resume(ssl: i32, res: *c.uws_res) void; pub extern fn uws_res_write_continue(ssl: i32, res: *c.uws_res) void; diff --git a/src/deps/uws/us_socket_t.zig b/src/deps/uws/us_socket_t.zig index bd84853b52..0af82ca70e 100644 --- a/src/deps/uws/us_socket_t.zig +++ b/src/deps/uws/us_socket_t.zig @@ -226,6 +226,119 @@ pub const c = struct { ) ?*us_socket_t; pub extern fn us_socket_get_error(ssl: i32, s: *uws.us_socket_t) c_int; pub extern fn us_socket_is_established(ssl: i32, s: *uws.us_socket_t) i32; + + const us_socket_stream_buffer_t = extern struct { + list_ptr: ?[*]u8 = null, + list_cap: usize = 0, + list_len: usize = 0, + total_bytes_written: usize = 0, + cursor: usize = 0, + + pub fn update(this: *us_socket_stream_buffer_t, stream_buffer: bun.io.StreamBuffer) void { + if (stream_buffer.list.capacity > 0) { + this.list_ptr = stream_buffer.list.items.ptr; + } else { + this.list_ptr = null; + } + this.list_len = stream_buffer.list.items.len; + this.list_cap = stream_buffer.list.capacity; + this.cursor = stream_buffer.cursor; + } + pub fn wrote(this: *us_socket_stream_buffer_t, written: usize) void { + this.total_bytes_written +|= written; + } + + pub fn toStreamBuffer(this: *us_socket_stream_buffer_t) bun.io.StreamBuffer { + return .{ + .list = if (this.list_ptr) |buffer_ptr| .{ + .allocator = bun.default_allocator, + .items = buffer_ptr[0..this.list_len], + .capacity = this.list_cap, + } else .{ + .allocator = bun.default_allocator, + .items = &.{}, + .capacity = 0, + }, + .cursor = this.cursor, + }; + } + + pub fn deinit(this: *us_socket_stream_buffer_t) void { + if (this.list_ptr) |buffer| { + bun.default_allocator.free(buffer[0..this.list_cap]); + } + } + }; + + export fn us_socket_free_stream_buffer(buffer: *us_socket_stream_buffer_t) void { + buffer.deinit(); + } + export fn us_socket_buffered_js_write( + socket: *uws.us_socket_t, + is_ssl: bool, + ended: bool, + buffer: *us_socket_stream_buffer_t, + globalObject: *jsc.JSGlobalObject, + data: jsc.JSValue, + encoding: jsc.JSValue, + ) jsc.JSValue { + // convever it back to StreamBuffer + var stream_buffer = buffer.toStreamBuffer(); + var total_written: usize = 0; + // update the buffer pointer to the new buffer + defer { + buffer.update(stream_buffer); + buffer.wrote(total_written); + } + + var stack_fallback = std.heap.stackFallback(16 * 1024, bun.default_allocator); + const node_buffer: jsc.Node.BlobOrStringOrBuffer = if (data.isUndefined()) + jsc.Node.BlobOrStringOrBuffer{ .string_or_buffer = jsc.Node.StringOrBuffer.empty } + else + jsc.Node.BlobOrStringOrBuffer.fromJSWithEncodingValueMaybeAsyncAllowRequestResponse(globalObject, stack_fallback.get(), data, encoding, false, true) catch { + return .zero; + } orelse { + if (!globalObject.hasException()) { + return globalObject.throwInvalidArgumentTypeValue("data", "string, buffer, or blob", data) catch .zero; + } + return .zero; + }; + + defer node_buffer.deinit(); + if (node_buffer == .blob and node_buffer.blob.needsToReadFile()) { + return globalObject.throw("File blob not supported yet in this function.", .{}) catch .zero; + } + + const data_slice = node_buffer.slice(); + if (stream_buffer.isNotEmpty()) { + // need to flush + const to_flush = stream_buffer.slice(); + const written: u32 = @max(0, socket.write(is_ssl, to_flush)); + stream_buffer.wrote(written); + total_written +|= written; + if (written < to_flush.len) { + if (data_slice.len > 0) { + bun.handleOom(stream_buffer.write(data_slice)); + } + return JSValue.jsBoolean(false); + } + // stream buffer is empty now + } + + if (data_slice.len > 0) { + const written: u32 = @max(0, socket.write(is_ssl, data_slice)); + total_written +|= written; + if (written < data_slice.len) { + bun.handleOom(stream_buffer.write(data_slice[written..])); + return JSValue.jsBoolean(false); + } + } + if (ended) { + // last part so we shutdown the writable side of the socket aka send FIN + socket.shutdown(is_ssl); + } + return JSValue.jsBoolean(true); + } }; const bun = @import("bun"); @@ -233,3 +346,6 @@ const std = @import("std"); const uws = @import("../uws.zig"); const SocketContext = uws.SocketContext; + +const jsc = bun.jsc; +const JSValue = jsc.JSValue; diff --git a/src/js/node/_http_server.ts b/src/js/node/_http_server.ts index aef1343eea..9d9314074a 100644 --- a/src/js/node/_http_server.ts +++ b/src/js/node/_http_server.ts @@ -352,8 +352,9 @@ Server.prototype[EventEmitter.captureRejectionSymbol] = function (err, event, .. Server.prototype[Symbol.asyncDispose] = function () { const { resolve, reject, promise } = Promise.withResolvers(); this.close(function (err, ...args) { - if (err) reject(err); - else resolve(...args); + if (err) { + reject(err); + } else resolve(...args); }); return promise; }; @@ -474,7 +475,6 @@ Server.prototype[kRealListen] = function (tls, port, host, socketPath, reusePort if (tls) { this.serverName = tls.serverName || host || "localhost"; } - this[serverSymbol] = Bun.serve({ idleTimeout: 0, // nodejs dont have a idleTimeout by default tls, @@ -528,10 +528,30 @@ Server.prototype[kRealListen] = function (tls, port, host, socketPath, reusePort if (isAncientHTTP) { http_req.httpVersion = "1.0"; } + if (method === "CONNECT") { + // Handle CONNECT method for HTTP tunneling/proxy + if (server.listenerCount("connect") > 0) { + // For CONNECT, emit the event and let the handler respond + // Don't assign the socket to a response for CONNECT + // The handler should write the raw response + socket[kEnableStreaming](true); + const { promise, resolve } = $newPromiseCapability(Promise); + socket.once("close", resolve); + server.emit("connect", http_req, socket, kEmptyBuffer); + return promise; + } else { + // Node.js will close the socket and will NOT respond with 400 Bad Request + socketHandle.close(); + } + return; + } + socket[kEnableStreaming](false); + const http_res = new ResponseClass(http_req, { [kHandle]: handle, [kRejectNonStandardBodyWrites]: server.rejectNonStandardBodyWrites, }); + setIsNextIncomingMessageHTTPS(prevIsNextIncomingMessageHTTPS); handle.onabort = onServerRequestEvent.bind(socket); // start buffering data if any, the user will need to resume() or .on("data") to read it @@ -677,6 +697,7 @@ Server.prototype[kRealListen] = function (tls, port, host, socketPath, reusePort // return promise; // }, }); + getBunServerAllClosedPromise(this[serverSymbol]).$then(emitCloseNTServer.bind(this)); isHTTPS = this[serverSymbol].protocol === "https"; // always set strict method validation to true for node.js compatibility @@ -784,14 +805,18 @@ function onServerClientError(ssl: boolean, socket: unknown, errorCode: number, r } } +const kBytesWritten = Symbol("kBytesWritten"); +const kEnableStreaming = Symbol("kEnableStreaming"); const NodeHTTPServerSocket = class Socket extends Duplex { bytesRead = 0; connecting = false; timeout = 0; + [kBytesWritten] = 0; [kHandle]; server: Server; _httpMessage; _secureEstablished = false; + #pendingCallback = null; constructor(server: Server, handle, encrypted) { super(); this.server = server; @@ -799,15 +824,56 @@ const NodeHTTPServerSocket = class Socket extends Duplex { this._secureEstablished = !!handle?.secureEstablished; handle.onclose = this.#onClose.bind(this); handle.duplex = this; + this.encrypted = encrypted; this.on("timeout", onNodeHTTPServerSocketTimeout); } get bytesWritten() { - return this[kHandle]?.response?.getBytesWritten?.() ?? 0; + const handle = this[kHandle]; + return handle + ? (handle.response?.getBytesWritten?.() ?? handle.bytesWritten ?? this[kBytesWritten] ?? 0) + : (this[kBytesWritten] ?? 0); + } + set bytesWritten(value) { + this[kBytesWritten] = value; } - set bytesWritten(value) {} + [kEnableStreaming](enable: boolean) { + const handle = this[kHandle]; + if (handle) { + if (enable) { + handle.ondata = this.#onData.bind(this); + handle.ondrain = this.#onDrain.bind(this); + } else { + handle.ondata = undefined; + handle.ondrain = undefined; + } + } + } + #onDrain() { + const handle = this[kHandle]; + this[kBytesWritten] = handle ? (handle.response?.getBytesWritten?.() ?? handle.bytesWritten ?? 0) : 0; + const callback = this.#pendingCallback; + if (callback) { + this.#pendingCallback = null; + (callback as Function)(); + } + this.emit("drain"); + } + #onData(chunk, last) { + if (chunk) { + this.push(chunk); + } + if (last) { + const handle = this[kHandle]; + if (handle) { + handle.ondata = undefined; + } + + this.push(null); + } + } #closeHandle(handle, callback) { this[kHandle] = undefined; handle.onclose = this.#onCloseForDestroy.bind(this, callback); @@ -822,8 +888,10 @@ const NodeHTTPServerSocket = class Socket extends Duplex { } #onClose() { this[kHandle] = null; + const message = this._httpMessage; const req = message?.req; + if (req && !req.complete && !req[kHandle]?.upgraded) { // At this point the socket is already destroyed; let's avoid UAF req[kHandle] = undefined; @@ -833,6 +901,7 @@ const NodeHTTPServerSocket = class Socket extends Duplex { req.destroy(); } } + this.emit("close"); } #onCloseForDestroy(closeCallback) { this.#onClose(); @@ -871,9 +940,10 @@ const NodeHTTPServerSocket = class Socket extends Duplex { $isCallable(callback) && callback(err); return; } + handle.ondata = undefined; if (handle.closed) { const onclose = handle.onclose; - handle.onclose = null; + handle.onclose = undefined; if ($isCallable(onclose)) { onclose.$call(handle); } @@ -890,7 +960,8 @@ const NodeHTTPServerSocket = class Socket extends Duplex { callback(); return; } - this.#closeHandle(handle, callback); + handle.end(); + callback(); } get localAddress() { @@ -998,7 +1069,20 @@ const NodeHTTPServerSocket = class Socket extends Duplex { return this; } - _write(_chunk, _encoding, _callback) {} + _write(_chunk, _encoding, _callback) { + const handle = this[kHandle]; + // only enable writting if we can drain + let err; + try { + if (handle && handle.ondrain && !handle.write(_chunk, _encoding)) { + this.#pendingCallback = _callback; + return false; + } + } catch (e) { + err = e; + } + err ? _callback(err) : _callback(); + } pause() { const handle = this[kHandle]; @@ -1006,6 +1090,7 @@ const NodeHTTPServerSocket = class Socket extends Duplex { if (response) { response.pause(); } + return super.pause(); } @@ -1138,8 +1223,12 @@ function ServerResponse(req, options): void { if (handle) { this[kHandle] = handle; + } else { + this[kHandle] = req[kHandle]; } this[kRejectNonStandardBodyWrites] = options[kRejectNonStandardBodyWrites] ?? false; + } else { + this[kHandle] = req[kHandle]; } this.statusCode = 200; @@ -1622,6 +1711,7 @@ function ServerResponse_finalDeprecated(chunk, encoding, callback) { chunk = Buffer.from(chunk, encoding); } const req = this.req; + const shouldEmitClose = req && req.emit && !this.finished; if (!this.headersSent) { let data = this[firstWriteSymbol]; diff --git a/test/bun.lock b/test/bun.lock index d5dc61d49f..ef010a555b 100644 --- a/test/bun.lock +++ b/test/bun.lock @@ -71,6 +71,7 @@ "postgres": "3.3.5", "prisma": "5.1.1", "prompts": "2.4.2", + "proxy": "2.2.0", "react": "file:../node_modules/react", "react-dom": "18.3.1", "reflect-metadata": "0.2.2", @@ -892,6 +893,8 @@ "argparse": ["argparse@2.0.1", "", {}, "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q=="], + "args": ["args@5.0.3", "", { "dependencies": { "camelcase": "5.0.0", "chalk": "2.4.2", "leven": "2.1.0", "mri": "1.1.4" } }, "sha512-h6k/zfFgusnv3i5TU08KQkVKuCPBtL/PWQbWkHUxvJrZ2nAyeaUupneemcrgn1xmqxPQsPIzwkUhOpoqPDRZuA=="], + "aria-query": ["aria-query@5.3.2", "", {}, "sha512-COROpnaoap1E2F000S62r6A60uHZnmlvomhfyT2DlTcrY1OrBKn2UhH7qn5wTC9zMvD0AY7csdPSNwKP+7WiQw=="], "array-flatten": ["array-flatten@1.1.1", "", {}, "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg=="], @@ -962,6 +965,8 @@ "basic-auth": ["basic-auth@2.0.1", "", { "dependencies": { "safe-buffer": "5.1.2" } }, "sha512-NF+epuEdnUYVlGuhaxbbq+dvJttwLnGY+YixlXlME5KpQ5W3CnXA5cVTneY3SPbPDRkcjMbifrwmFYcClgOZeg=="], + "basic-auth-parser": ["basic-auth-parser@0.0.2-1", "", {}, "sha512-GFj8iVxo9onSU6BnnQvVwqvxh60UcSHJEDnIk3z4B6iOjsKSmqe+ibW0Rsz7YO7IE1HG3D3tqCNIidP46SZVdQ=="], + "basic-ftp": ["basic-ftp@5.0.5", "", {}, "sha512-4Bcg1P8xhUuqcii/S0Z9wiHIrQVPMermM1any+MX5GeGD7faD3/msQUDGLol9wOcz4/jbg/WJnGqoJF6LiBdtg=="], "bcrypt-pbkdf": ["bcrypt-pbkdf@1.0.2", "", { "dependencies": { "tweetnacl": "^0.14.3" } }, "sha512-qeFIXtP4MSoi6NLqO12WfqARWWuCKi2Rn/9hJLEmtB5yTNr9DqFWkJRCf2qShWzPeAMRnOgCrq0sg/KLv5ES9w=="], @@ -1264,7 +1269,7 @@ "escape-html": ["escape-html@1.0.3", "", {}, "sha512-NiSupZ4OeuGwr68lGIeym/ksIZMJodUGOSCZ/FSnTxcrekbvqrgdUxlJOMpijaKZVjAJrWrGs/6Jy8OMuyj9ow=="], - "escape-string-regexp": ["escape-string-regexp@5.0.0", "", {}, "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw=="], + "escape-string-regexp": ["escape-string-regexp@1.0.5", "", {}, "sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg=="], "escodegen": ["escodegen@2.1.0", "", { "dependencies": { "esprima": "^4.0.1", "estraverse": "^5.2.0", "esutils": "^2.0.2" }, "optionalDependencies": { "source-map": "~0.6.1" }, "bin": { "esgenerate": "bin/esgenerate.js", "escodegen": "bin/escodegen.js" } }, "sha512-2NlIDTwUWJN0mRPQOdtQBzbUHvdGY2P1VXSyU83Q3xKxM7WHX2Ql8dKq782Q9TgQUNOLEzEYu9bzLNj1q88I5w=="], @@ -1670,6 +1675,8 @@ "kleur": ["kleur@4.1.5", "", {}, "sha512-o+NO+8WrRiQEE4/7nwRJhN1HWpVmJm511pBHUxPLtp0BUISzlBplORYSmTclCnJvQq2tKu/sgl3xVpkc7ZWuQQ=="], + "leven": ["leven@2.1.0", "", {}, "sha512-nvVPLpIHUxCUoRLrFqTgSxXJ614d8AgQoWl7zPe/2VadE8+1dpU3LBhowRuBAcuwruWtOdD8oYC9jDNJjXDPyA=="], + "light-my-request": ["light-my-request@6.6.0", "", { "dependencies": { "cookie": "^1.0.1", "process-warning": "^4.0.0", "set-cookie-parser": "^2.6.0" } }, "sha512-CHYbu8RtboSIoVsHZ6Ye4cj4Aw/yg2oAFimlF7mNvfDV192LR7nDiKtSIfCuLT7KokPSTn/9kfVLm5OGN0A28A=="], "lines-and-columns": ["lines-and-columns@1.2.4", "", {}, "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg=="], @@ -1872,6 +1879,8 @@ "morgan": ["morgan@1.10.0", "", { "dependencies": { "basic-auth": "~2.0.1", "debug": "2.6.9", "depd": "~2.0.0", "on-finished": "~2.3.0", "on-headers": "~1.0.2" } }, "sha512-AbegBVI4sh6El+1gNwvD5YIck7nSA36weD7xvIxG4in80j/UoK8AEGaWnnz8v1GxonMCltmlNs5ZKbGvl9b1XQ=="], + "mri": ["mri@1.1.4", "", {}, "sha512-6y7IjGPm8AzlvoUrwAaw1tLnUBudaS3752vcd8JtrpGGQn+rXIe63LFVHm/YMwtqAuh+LJPCFdlLYPWM1nYn6w=="], + "mrmime": ["mrmime@2.0.1", "", {}, "sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ=="], "ms": ["ms@2.1.3", "", {}, "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA=="], @@ -2110,6 +2119,8 @@ "protobufjs": ["protobufjs@7.3.2", "", { "dependencies": { "@protobufjs/aspromise": "^1.1.2", "@protobufjs/base64": "^1.1.2", "@protobufjs/codegen": "^2.0.4", "@protobufjs/eventemitter": "^1.1.0", "@protobufjs/fetch": "^1.1.0", "@protobufjs/float": "^1.0.2", "@protobufjs/inquire": "^1.1.0", "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", "@types/node": ">=13.7.0", "long": "^5.0.0" } }, "sha512-RXyHaACeqXeqAKGLDl68rQKbmObRsTIn4TYVUUug1KfS47YWCo5MacGITEryugIgZqORCvJWEk4l449POg5Txg=="], + "proxy": ["proxy@2.2.0", "", { "dependencies": { "args": "^5.0.3", "basic-auth-parser": "0.0.2-1", "debug": "^4.3.4" }, "bin": { "proxy": "dist/bin/proxy.js" } }, "sha512-nYclNIWj9UpXbVJ3W5EXIYiGR88AKZoGt90kyh3zoOBY5QW+7bbtPvMFgKGD4VJmpS3UXQXtlGXSg3lRNLOFLg=="], + "proxy-addr": ["proxy-addr@2.0.7", "", { "dependencies": { "forwarded": "0.2.0", "ipaddr.js": "1.9.1" } }, "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg=="], "proxy-agent": ["proxy-agent@6.4.0", "", { "dependencies": { "agent-base": "^7.0.2", "debug": "^4.3.4", "http-proxy-agent": "^7.0.1", "https-proxy-agent": "^7.0.3", "lru-cache": "^7.14.1", "pac-proxy-agent": "^7.0.1", "proxy-from-env": "^1.1.0", "socks-proxy-agent": "^8.0.2" } }, "sha512-u0piLU+nCOHMgGjRbimiXmA9kM/L9EHh3zL81xCdp7m+Y2pHIsnmbdDoEDoAz5geaonNR6q6+yOPQs6n4T6sBQ=="], @@ -2924,6 +2935,10 @@ "are-we-there-yet/readable-stream": ["readable-stream@3.6.2", "", { "dependencies": { "inherits": "^2.0.3", "string_decoder": "^1.1.1", "util-deprecate": "^1.0.1" } }, "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA=="], + "args/camelcase": ["camelcase@5.0.0", "", {}, "sha512-faqwZqnWxbxn+F1d399ygeamQNy3lPp/H9H6rNrqYh4FSVCtcY+3cub1MxA8o9mDd55mM8Aghuu/kuyYA6VTsA=="], + + "args/chalk": ["chalk@2.4.2", "", { "dependencies": { "ansi-styles": "^3.2.1", "escape-string-regexp": "^1.0.5", "supports-color": "^5.3.0" } }, "sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ=="], + "astro/acorn": ["acorn@8.14.1", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-OvQ/2pUDKmgfCg++xsTX1wGxfTaszcHVcTctW4UJB4hibJx2HXxxO5UmVgyjMa+ZDsiaf5wWLXYpRWMmBI0QHg=="], "astro/debug": ["debug@4.4.0", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA=="], @@ -3088,6 +3103,8 @@ "make-fetch-happen/proc-log": ["proc-log@4.2.0", "", {}, "sha512-g8+OnU/L2v+wyiVK+D5fA34J7EH8jZ8DDlvwhRCMxmMj7UCBvxiO1mGeN+36JXIKF4zevU4kRBd8lVgG9vLelA=="], + "mdast-util-find-and-replace/escape-string-regexp": ["escape-string-regexp@5.0.0", "", {}, "sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw=="], + "micromark/debug": ["debug@4.4.0", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA=="], "micromatch/picomatch": ["picomatch@2.3.1", "", {}, "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA=="], @@ -3158,6 +3175,8 @@ "prompts/kleur": ["kleur@3.0.3", "", {}, "sha512-eTIzlVOSUR+JxdDFepEYcBMtZ9Qqdef+rnzWdRZuMbOywu5tO2w2N7rqjoANZ5k9vywhL6Br1VRjUIgTQx4E8w=="], + "proxy/debug": ["debug@4.4.0", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-6WTZ/IxCY/T6BALoZHaE4ctp9xm+Z5kY/pzYaCHRFeyVhojxlrm+46y68HA6hr0TcwEssoxNiDEUJQjfPZ/RYA=="], + "proxy-agent/debug": ["debug@4.3.7", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ=="], "proxy-agent/lru-cache": ["lru-cache@7.18.3", "", {}, "sha512-jumlc0BIUrS3qJGgIkWZsyfAM7NCWiBcCDhnd+3NNM5KbBmLTgHVfWBcg6W+rLUsIpzpERPsvwUP7CckAQSOoA=="], @@ -3532,6 +3551,8 @@ "ansi-align/string-width/strip-ansi": ["strip-ansi@6.0.1", "", { "dependencies": { "ansi-regex": "^5.0.1" } }, "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A=="], + "args/chalk/ansi-styles": ["ansi-styles@3.2.1", "", { "dependencies": { "color-convert": "^1.9.0" } }, "sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA=="], + "astro/esbuild/@esbuild/android-arm": ["@esbuild/android-arm@0.25.1", "", { "os": "android", "cpu": "arm" }, "sha512-dp+MshLYux6j/JjdqVLnMglQlFu+MuVeNrmT5nk6q07wNhCdSnB7QZj+7G8VMUGh1q+vj2Bq8kRsuyA00I/k+Q=="], "astro/esbuild/@esbuild/android-arm64": ["@esbuild/android-arm64@0.25.1", "", { "os": "android", "cpu": "arm64" }, "sha512-50tM0zCJW5kGqgG7fQ7IHvQOcAn9TKiVRuQ/lN0xR+T2lzEFvAi1ZcS8DiksFcEpf1t/GYOeOfCAgDHFpkiSmA=="], @@ -3974,6 +3995,8 @@ "ansi-align/string-width/strip-ansi/ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], + "args/chalk/ansi-styles/color-convert": ["color-convert@1.9.3", "", { "dependencies": { "color-name": "1.1.3" } }, "sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg=="], + "astro/sharp/@img/sharp-wasm32/@emnapi/runtime": ["@emnapi/runtime@1.4.0", "", { "dependencies": { "tslib": "^2.4.0" } }, "sha512-64WYIf4UYcdLnbKn/umDlNjQDSS8AgZrI/R9+x5ilkUVFxXcA1Ebl+gQLc/6mERA4407Xof0R7wEyEuj091CVw=="], "cli-highlight/chalk/supports-color/has-flag": ["has-flag@4.0.0", "", {}, "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ=="], @@ -4152,6 +4175,8 @@ "yargs/string-width/strip-ansi/ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], + "args/chalk/ansi-styles/color-convert/color-name": ["color-name@1.1.3", "", {}, "sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw=="], + "astro/sharp/@img/sharp-wasm32/@emnapi/runtime/tslib": ["tslib@2.8.1", "", {}, "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w=="], "cli-highlight/yargs/cliui/strip-ansi/ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], diff --git a/test/internal/ban-limits.json b/test/internal/ban-limits.json index 78029a9f56..6980821328 100644 --- a/test/internal/ban-limits.json +++ b/test/internal/ban-limits.json @@ -4,7 +4,7 @@ " catch bun.outOfMemory()": 0, "!= alloc.ptr": 0, "!= allocator.ptr": 0, - ".arguments_old(": 266, + ".arguments_old(": 265, ".jsBoolean(false)": 0, ".jsBoolean(true)": 0, ".stdDir()": 41, diff --git a/test/js/node/http/node-http-connect.node.mts b/test/js/node/http/node-http-connect.node.mts new file mode 100644 index 0000000000..2bbc7303a8 --- /dev/null +++ b/test/js/node/http/node-http-connect.node.mts @@ -0,0 +1,413 @@ +/** + * All new tests in this file should also run in Node.js. + * + * Do not add any tests that only run in Bun. + * + * A handful of older tests do not run in Node in this file. These tests should be updated to run in Node, or deleted. + */ + +import { describe, test } from "node:test"; +import assert from "node:assert"; + +function expect(value: any) { + return { + toBe: (expected: any) => { + assert.strictEqual(value, expected); + }, + toContain: (expected: any) => { + assert.ok(value.includes(expected)); + }, + toBeInstanceOf: (expected: any) => { + assert.ok(value instanceof expected); + }, + toBeGreaterThan: (expected: any) => { + assert.ok(value > expected); + }, + toBeLessThan: (expected: any) => { + assert.ok(value < expected); + }, + toEqual: (expected: any) => { + assert.deepStrictEqual(value, expected); + }, + not: { + toBe: (expected: any) => { + assert.notStrictEqual(value, expected); + }, + toContain: (expected: any) => { + assert.ok(!value.includes(expected)); + }, + toBeInstanceOf: (expected: any) => { + assert.ok(!(value instanceof expected)); + }, + toBeGreaterThan: (expected: any) => { + assert.ok(!(value > expected)); + }, + toBeLessThan: (expected: any) => { + assert.ok(!(value < expected)); + }, + toEqual: (expected: any) => { + assert.notDeepStrictEqual(value, expected); + }, + }, + }; +} +import http from "http"; +import { createProxy } from "proxy"; + +import { once } from "node:events"; +import type { AddressInfo } from "node:net"; +import net from "node:net"; + +function connectClient(proxyAddress: AddressInfo, targetAddress: AddressInfo, add_http_prefix: boolean) { + const client = net.connect({ port: proxyAddress.port, host: proxyAddress.address }, () => { + client.write( + `CONNECT ${add_http_prefix ? "http://" : ""}${targetAddress.address}:${targetAddress.port} HTTP/1.1\r\nHost: ${targetAddress.address}:${targetAddress.port}\r\nProxy-Authorization: Basic dXNlcjpwYXNzd29yZA==\r\n\r\n`, + ); + }); + + const received: string[] = []; + const { promise, resolve, reject } = Promise.withResolvers(); + + client.on("data", data => { + if (data.toString().includes("200 Connection established")) { + client.write("GET / HTTP/1.1\r\nHost: www.example.com:80\r\nConnection: close\r\n\r\n"); + } + received.push(data.toString()); + }); + client.on("error", reject); + + client.on("end", () => { + resolve(received.join("")); + }); + return promise; +} + +const BIG_DATA = Buffer.alloc(1024 * 64, "bun").toString(); + +describe("HTTP server CONNECT", () => { + test("should work with proxy package", async () => { + await using targetServer = http.createServer((req, res) => { + res.end("Hello World from target server"); + }); + await using proxyServer = createProxy(http.createServer()); + let proxyHeaders = {}; + proxyServer.authenticate = req => { + proxyHeaders = req.headers; + return true; + }; + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + await once(targetServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + const targetAddress = targetServer.address() as AddressInfo; + + { + // server should support http prefix but the proxy package it self does not + // this behavior is consistent with node.js + const response = await connectClient(proxyAddress, targetAddress, true); + expect(proxyHeaders["proxy-authorization"]).toBe("Basic dXNlcjpwYXNzd29yZA=="); + expect(response).toContain("HTTP/1.1 404 Not Found"); + } + + { + proxyHeaders = {}; + const response = await connectClient(proxyAddress, targetAddress, false); + expect(proxyHeaders["proxy-authorization"]).toBe("Basic dXNlcjpwYXNzd29yZA=="); + expect(response).toContain("HTTP/1.1 200 OK"); + expect(response).toContain("Hello World from target server"); + } + }); + + test("should work with raw sockets", async () => { + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + await using targetServer = http.createServer((req, res) => { + res.end("Hello World from target server"); + }); + let proxyHeaders = {}; + proxyServer.on("connect", (req, socket, head) => { + proxyHeaders = req.headers; + const [host, port] = req.url?.split(":") ?? []; + + const serverSocket = net.connect(parseInt(port), host, () => { + socket.write(`HTTP/1.1 200 Connection established\r\nConnection: close\r\n\r\n`); + serverSocket.pipe(socket); + socket.pipe(serverSocket); + }); + serverSocket.on("error", err => { + socket.end("HTTP/1.1 502 Bad Gateway\r\n\r\n"); + }); + socket.on("error", err => { + serverSocket.destroy(); + }); + + socket.on("end", () => serverSocket.end()); + serverSocket.on("end", () => socket.end()); + }); + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + await once(targetServer.listen(0, "127.0.0.1"), "listening"); + const targetAddress = targetServer.address() as AddressInfo; + + { + const response = await connectClient(proxyAddress, targetAddress, false); + expect(proxyHeaders["proxy-authorization"]).toBe("Basic dXNlcjpwYXNzd29yZA=="); + expect(response).toContain("HTTP/1.1 200 OK"); + expect(response).toContain("Hello World from target server"); + } + }); + + test("should handle multiple concurrent CONNECT requests", async () => { + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + + await using targetServer = http.createServer((req, res) => { + res.end(`Response for ${req.url}`); + }); + + let connectionCount = 0; + proxyServer.on("connect", (req, socket, head) => { + connectionCount++; + const [host, port] = req.url?.split(":") ?? []; + + const serverSocket = net.connect(parseInt(port), host, () => { + socket.write(`HTTP/1.1 200 Connection established\r\n\r\n`); + serverSocket.pipe(socket); + socket.pipe(serverSocket); + }); + + serverSocket.on("error", () => socket.end("HTTP/1.1 502 Bad Gateway\r\n\r\n")); + socket.on("error", () => serverSocket.destroy()); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + await once(targetServer.listen(0, "127.0.0.1"), "listening"); + + const proxyAddress = proxyServer.address() as AddressInfo; + const targetAddress = targetServer.address() as AddressInfo; + + // Create 5 concurrent connections + const promises = Array.from({ length: 5 }, (_, i) => connectClient(proxyAddress, targetAddress, false)); + + const results = await Promise.all(promises); + expect(connectionCount).toBe(5); + results.forEach(result => { + expect(result).toContain("HTTP/1.1 200 OK"); + }); + }); + + test("should handle CONNECT with invalid target", async () => { + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + + proxyServer.on("connect", (req, socket, head) => { + const [host, port] = req.url?.split(":") ?? []; + + const serverSocket = net.connect(parseInt(port) || 80, host, () => { + socket.write(`HTTP/1.1 200 Connection established\r\n\r\n`); + serverSocket.pipe(socket); + socket.pipe(serverSocket); + }); + + serverSocket.on("error", err => { + socket.write("HTTP/1.1 502 Bad Gateway\r\n\r\n"); + socket.end(); + }); + + socket.on("error", () => serverSocket.destroy()); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + client.write("CONNECT invalid.host.that.does.not.exist:9999 HTTP/1.1\r\nHost: invalid.host:9999\r\n\r\n"); + }); + + const { promise, resolve } = Promise.withResolvers(); + const received: string[] = []; + + client.on("data", data => { + received.push(data.toString()); + }); + + client.on("end", () => { + resolve(received.join("")); + }); + + const response = await promise; + expect(response).toContain("502 Bad Gateway"); + }); + + test("should handle CONNECT with authentication failure", async () => { + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + + proxyServer.on("connect", (req, socket, head) => { + const auth = req.headers["proxy-authorization"]; + if (!auth || auth !== "Basic dXNlcjpwYXNzd29yZA==") { + socket.write("HTTP/1.1 407 Proxy Authentication Required\r\n"); + socket.write('Proxy-Authenticate: Basic realm="Proxy"\r\n\r\n'); + socket.end(); + return; + } + + socket.write("HTTP/1.1 200 Connection established\r\n\r\n"); + socket.end(); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + // Test without authentication + const client1 = net.connect(proxyAddress.port, proxyAddress.address, () => { + client1.write("CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\n\r\n"); + }); + + const { promise: promise1, resolve: resolve1 } = Promise.withResolvers(); + const received1: string[] = []; + + client1.on("data", data => { + received1.push(data.toString()); + }); + + client1.on("end", () => { + resolve1(received1.join("")); + }); + + const response1 = await promise1; + expect(response1).toContain("407 Proxy Authentication Required"); + + // Test with correct authentication + const client2 = net.connect(proxyAddress.port, proxyAddress.address, () => { + client2.write( + "CONNECT example.com:80 HTTP/1.1\r\nHost: example.com:80\r\nProxy-Authorization: Basic dXNlcjpwYXNzd29yZA==\r\n\r\n", + ); + }); + + const { promise: promise2, resolve: resolve2 } = Promise.withResolvers(); + const received2: string[] = []; + + client2.on("data", data => { + received2.push(data.toString()); + }); + + client2.on("end", () => { + resolve2(received2.join("")); + }); + + const response2 = await promise2; + expect(response2).toContain("200 Connection established"); + }); + + test("should handle partial writes and buffering", async () => { + await using proxyServer = http.createServer(); + let bufferReceived = ""; + + proxyServer.on("connect", (req, socket, head) => { + socket.on("data", chunk => { + bufferReceived += chunk.toString(); + }); + + // Send response in small chunks + socket.write("HTTP/1.1 "); + setTimeout(() => socket.write("200 "), 10); + setTimeout(() => socket.write("Connection "), 20); + setTimeout(() => socket.write("established\r\n\r\n"), 30); + setTimeout(() => { + socket.write("Test data"); + socket.end(); + }, 40); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + // Send request in chunks + client.write("CONNECT example.com:80 "); + setTimeout(() => client.write("HTTP/1.1\r\n"), 5); + setTimeout(() => client.write("Host: example.com\r\n\r\n"), 10); + setTimeout(() => client.write("Client data"), 35); + }); + + const { promise, resolve } = Promise.withResolvers(); + const received: string[] = []; + + client.on("data", data => { + received.push(data.toString()); + }); + + client.on("end", () => { + resolve(received.join("")); + }); + + const response = await promise; + expect(response).toContain("200 Connection established"); + expect(response).toContain("Test data"); + expect(bufferReceived).toContain("Client data"); + }); + + test("should handle keep-alive connections", async () => { + await using proxyServer = http.createServer(); + await using targetServer = http.createServer((req, res) => { + res.writeHead(200, { "Content-Length": "5" }); + res.end("Hello"); + }); + + proxyServer.on("connect", (req, socket, head) => { + const [host, port] = req.url?.split(":") ?? []; + + const serverSocket = net.connect(parseInt(port), host, () => { + socket.write("HTTP/1.1 200 Connection established\r\n\r\n"); + serverSocket.pipe(socket); + socket.pipe(serverSocket); + }); + + serverSocket.on("error", () => socket.end()); + socket.on("error", () => serverSocket.destroy()); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + await once(targetServer.listen(0, "127.0.0.1"), "listening"); + + const proxyAddress = proxyServer.address() as AddressInfo; + const targetAddress = targetServer.address() as AddressInfo; + + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + client.write( + `CONNECT ${targetAddress.address}:${targetAddress.port} HTTP/1.1\r\nHost: ${targetAddress.address}:${targetAddress.port}\r\n\r\n`, + ); + }); + + const { promise, resolve } = Promise.withResolvers(); + const responses: string[] = []; + let requestCount = 0; + + client.on("data", data => { + const str = data.toString(); + responses.push(str); + + if (str.includes("200 Connection established") && requestCount === 0) { + // Send first request + client.write("GET /first HTTP/1.1\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n"); + requestCount++; + } else if (str.includes("Hello") && requestCount === 1) { + // Send second request on same connection + client.write("GET /second HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"); + requestCount++; + } else if (str.includes("Hello") && requestCount === 2) { + client.end(); + resolve(responses); + } + }); + + const allResponses = await promise; + const combined = allResponses.join(""); + expect(combined).toContain("200 Connection established"); + expect(combined.match(/Hello/g)?.length).toBe(2); + }); +}); diff --git a/test/js/node/http/node-http-connect.test.ts b/test/js/node/http/node-http-connect.test.ts new file mode 100644 index 0000000000..5e486e4888 --- /dev/null +++ b/test/js/node/http/node-http-connect.test.ts @@ -0,0 +1,464 @@ +import { describe, expect, test } from "bun:test"; +import { bunEnv, bunExe, nodeExe } from "harness"; +import http from "http"; + +import { once } from "node:events"; +import type { AddressInfo } from "node:net"; +import net from "node:net"; +import { join } from "node:path"; +function connectClient(proxyAddress: AddressInfo, targetAddress: AddressInfo, add_http_prefix: boolean) { + const client = net.connect({ port: proxyAddress.port, host: proxyAddress.address }, () => { + client.write( + `CONNECT ${add_http_prefix ? "http://" : ""}${targetAddress.address}:${targetAddress.port} HTTP/1.1\r\nHost: ${targetAddress.address}:${targetAddress.port}\r\nProxy-Authorization: Basic dXNlcjpwYXNzd29yZA==\r\n\r\n`, + ); + }); + + const received: string[] = []; + const { promise, resolve, reject } = Promise.withResolvers(); + + client.on("data", data => { + if (data.toString().includes("200 Connection established")) { + client.write("GET / HTTP/1.1\r\nHost: www.example.com:80\r\nConnection: close\r\n\r\n"); + } + received.push(data.toString()); + }); + client.on("error", reject); + + client.on("end", () => { + resolve(received.join("")); + }); + return promise; +} + +const BIG_DATA = Buffer.alloc(1024 * 1024 * 64, "bun").toString(); +describe("HTTP server CONNECT", () => { + test("should handle backpressure", async () => { + const responseHeader = "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n"; + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + await using targetServer = net.createServer(socket => { + socket.write(responseHeader, () => { + socket.write(BIG_DATA, () => { + //TODO: is this a net bug? on windows the connection is closed before everything is sended + Bun.sleep(100).then(() => { + socket.end(); + }); + }); + }); + }); + let proxyHeaders = {}; + proxyServer.on("connect", (req, socket, head) => { + proxyHeaders = req.headers; + const [host, port] = req.url?.split(":") ?? []; + + const serverSocket = net.connect(parseInt(port), host, async () => { + socket.write(`HTTP/1.1 200 Connection established\r\nConnection: close\r\n\r\n`); + serverSocket.pipe(socket); + socket.pipe(serverSocket); + }); + serverSocket.on("error", err => { + socket.end("HTTP/1.1 502 Bad Gateway\r\n\r\n"); + }); + socket.on("error", err => { + serverSocket.destroy(); + }); + + socket.on("end", () => serverSocket.end()); + serverSocket.on("end", () => socket.end()); + }); + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + await once(targetServer.listen(0, "127.0.0.1"), "listening"); + const targetAddress = targetServer.address() as AddressInfo; + + { + const response = await connectClient(proxyAddress, targetAddress, false); + expect(proxyHeaders["proxy-authorization"]).toBe("Basic dXNlcjpwYXNzd29yZA=="); + expect(response).toContain("HTTP/1.1 200 OK"); + expect(response.length).toBeGreaterThan(responseHeader.length + BIG_DATA.length); + expect(response).toContain(BIG_DATA); + } + }); + + test("should handle data, drain, end and close events", async () => { + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + let data_received: string[] = []; + let client_data_received: string[] = []; + let proxy_drain_received = false; + let proxy_end_received = false; + + const { promise, resolve, reject } = Promise.withResolvers(); + + const { promise: clientPromise, resolve: clientResolve, reject: clientReject } = Promise.withResolvers(); + const clientSocket = net.connect(proxyAddress.port, proxyAddress.address, () => { + clientSocket.on("error", clientReject); + clientSocket.on("data", chunk => { + client_data_received.push(chunk?.toString()); + }); + clientSocket.on("end", () => { + clientSocket.end(); + clientResolve(client_data_received.join("")); + }); + + clientSocket.write("CONNECT localhost:80 HTTP/1.1\r\nHost: localhost:80\r\nConnection: close\r\n\r\n"); + }); + + proxyServer.on("connect", (req, socket, head) => { + expect(head).toBeInstanceOf(Buffer); + socket.on("data", chunk => { + data_received.push(chunk?.toString()); + }); + socket.on("end", () => { + proxy_end_received = true; + }); + socket.on("close", () => { + resolve(data_received.join("")); + }); + socket.on("drain", () => { + proxy_drain_received = true; + socket.end(); + }); + socket.on("error", reject); + proxy_drain_received = false; + // write until backpressure + while (socket.write(BIG_DATA)) {} + clientSocket.write("Hello World"); + }); + + expect(await promise).toContain("Hello World"); + expect(await clientPromise).toContain(BIG_DATA); + expect(proxy_drain_received).toBe(true); + expect(proxy_end_received).toBe(true); + }); + + test("should handle CONNECT with invalid target", async () => { + await using proxyServer = http.createServer((req, res) => { + res.end("Hello World from proxy server"); + }); + + proxyServer.on("connect", (req, socket, head) => { + const [host, port] = req.url?.split(":") ?? []; + + const serverSocket = net.connect(parseInt(port) || 80, host, () => { + socket.write(`HTTP/1.1 200 Connection established\r\n\r\n`); + serverSocket.pipe(socket); + socket.pipe(serverSocket); + }); + + serverSocket.on("error", err => { + socket.write("HTTP/1.1 502 Bad Gateway\r\n\r\n"); + socket.end(); + }); + + socket.on("error", () => serverSocket.destroy()); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + client.write("CONNECT invalid.host.that.does.not.exist:9999 HTTP/1.1\r\nHost: invalid.host:9999\r\n\r\n"); + }); + + const { promise, resolve } = Promise.withResolvers(); + const received: string[] = []; + + client.on("data", data => { + received.push(data.toString()); + }); + + client.on("end", () => { + resolve(received.join("")); + }); + + const response = await promise; + expect(response).toContain("502 Bad Gateway"); + }); + + // TODO: timeout is not supported in bun socket yet + test.todo("should handle socket timeout", async () => { + await using proxyServer = http.createServer(); + let timeoutFired = false; + + proxyServer.on("connect", (req, socket, head) => { + socket.setTimeout(100); + socket.on("timeout", () => { + timeoutFired = true; + socket.write("HTTP/1.1 408 Request Timeout\r\n\r\n"); + socket.end(); + }); + + // Don't send any response immediately + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + client.write("CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n\r\n"); + }); + + const { promise, resolve } = Promise.withResolvers(); + const received: string[] = []; + + client.on("data", data => { + received.push(data.toString()); + }); + + client.on("end", () => { + resolve(received.join("")); + }); + + const response = await promise; + expect(timeoutFired).toBe(true); + expect(response).toContain("408 Request Timeout"); + }); + + //TODO pause and resume only not supported in bun socket yet + test.todo("should handle socket pause and resume", async () => { + await using proxyServer = http.createServer(); + let pauseCount = 0; + let resumeCount = 0; + + proxyServer.on("connect", (req, socket, head) => { + socket.write("HTTP/1.1 200 Connection established\r\n\r\n"); + + // Simulate backpressure scenario + const interval = setInterval(() => { + const canWrite = socket.write("X".repeat(1024)); + if (!canWrite) { + pauseCount++; + socket.pause(); + setTimeout(() => { + resumeCount++; + socket.resume(); + }, 50); + } + }, 10); + + socket.on("end", () => { + clearInterval(interval); + socket.end(); + }); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + client.write("CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n\r\n"); + + setTimeout(() => client.end(), 200); + }); + + const { promise, resolve } = Promise.withResolvers(); + let bytesReceived = 0; + + client.on("data", data => { + bytesReceived += data.length; + }); + + client.on("end", () => { + resolve(bytesReceived); + }); + + const totalBytes = await promise; + expect(totalBytes).toBeGreaterThan(0); + expect(pauseCount).toBeGreaterThan(0); + expect(resumeCount).toBeGreaterThan(0); + }); + + test("should handle malformed CONNECT requests", async () => { + await using proxyServer = http.createServer(); + + proxyServer.on("connect", (req, socket, head) => { + // This shouldn't be reached for malformed requests + socket.write("HTTP/1.1 200 Connection established\r\n\r\n"); + socket.end(); + }); + + await once(proxyServer.listen(0, "127.0.0.1"), "listening"); + const proxyAddress = proxyServer.address() as AddressInfo; + + // Test various malformed requests + const malformedRequests = [ + "CONNECT\r\n\r\n", // Missing target + "CONNECT example.com HTTP/1.1\r\n\r\n", // Missing port + "CONNECT :80 HTTP/1.1\r\n\r\n", // Missing host + "CONNEC example.com:80 HTTP/1.1\r\n\r\n", // Typo in method + "CONNECT example.com:80\r\n\r\n", // Missing HTTP version + ]; + + for (const request of malformedRequests) { + const client = net.connect(proxyAddress.port, proxyAddress.address, () => { + client.write(request); + }); + + const { promise, resolve } = Promise.withResolvers(); + const received: string[] = []; + + client.on("data", data => { + received.push(data.toString()); + }); + + client.on("end", () => { + resolve(received.join("")); + }); + + client.on("error", () => { + resolve("CONNECTION_ERROR"); + }); + + setTimeout(() => { + client.end(); + resolve(received.join("") || "TIMEOUT"); + }, 100); + + const response = await promise; + // Should either get an error response or timeout/connection error + expect(response).not.toContain("200 Connection established"); + } + }); +}); + +/** + * Test variations using normal HTTP requests and res.socket + * These tests should run in both Node.js and Bun + */ + +describe("HTTP server socket access via normal requests", () => { + //TODO: right now http server socket dont emit error event + test.todo("should handle socket errors during normal requests", async () => { + let errorHandled = false; + + await using server = http.createServer((req, res) => { + const socket = res.socket!; + + socket.on("error", err => { + errorHandled = true; + }); + + // Simulate an error condition + setTimeout(() => { + socket.destroy(new Error("Simulated error")); + }, 50); + }); + + await once(server.listen(0, "127.0.0.1"), "listening"); + const serverAddress = server.address() as AddressInfo; + + const client = net.connect(serverAddress.port, serverAddress.address, () => { + client.write("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"); + }); + + const { promise, resolve } = Promise.withResolvers(); + + client.on("error", () => { + resolve(true); + }); + + client.on("close", () => { + resolve(false); + }); + + await promise; + expect(errorHandled).toBe(true); + }); + + test.todo("should handle socket pause/resume during request", async () => { + const largeData = Buffer.alloc(1024 * 1024, "x").toString(); + let pauseCount = 0; + let resumeCount = 0; + + await using server = http.createServer((req, res) => { + const socket = res.socket!; + + // Monitor socket state + const originalPause = socket.pause.bind(socket); + const originalResume = socket.resume.bind(socket); + + socket.pause = function () { + pauseCount++; + return originalPause(); + }; + + socket.resume = function () { + resumeCount++; + return originalResume(); + }; + + // Send large response to trigger backpressure + res.writeHead(200, { "Content-Type": "text/plain" }); + + const sendData = () => { + let ok = true; + while (ok) { + ok = res.write(largeData); + if (!ok) { + // Wait for drain event + res.once("drain", sendData); + break; + } + } + }; + + sendData(); + + setTimeout(() => res.end(), 100); + }); + + await once(server.listen(0, "127.0.0.1"), "listening"); + const serverAddress = server.address() as AddressInfo; + + const client = net.connect(serverAddress.port, serverAddress.address, () => { + client.write("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"); + }); + + const { promise, resolve } = Promise.withResolvers(); + let bytesReceived = 0; + + // Slow reader to trigger backpressure + client.on("data", chunk => { + bytesReceived += chunk.length; + client.pause(); + setTimeout(() => client.resume(), 10); + }); + + client.on("end", () => { + resolve(bytesReceived); + }); + + const total = await promise; + expect(total).toBeGreaterThan(0); + }); +}); + +describe("Should be compatible with node.js", () => { + test("tests should run on node.js", async () => { + const process = Bun.spawn({ + cmd: [nodeExe(), "--test", join(import.meta.dir, "node-http-connect.node.mts")], + stdout: "inherit", + stderr: "inherit", + stdin: "ignore", + env: bunEnv, + }); + expect(await process.exited).toBe(0); + }); + test("tests should run on bun", async () => { + const process = Bun.spawn({ + cmd: [bunExe(), "test", join(import.meta.dir, "node-http-connect.node.mts")], + stdout: "inherit", + stderr: "inherit", + stdin: "ignore", + env: bunEnv, + }); + expect(await process.exited).toBe(0); + }); +}); diff --git a/test/js/node/http/node-http.test.ts b/test/js/node/http/node-http.test.ts index 50f8146f0f..d882b44d30 100644 --- a/test/js/node/http/node-http.test.ts +++ b/test/js/node/http/node-http.test.ts @@ -2757,11 +2757,11 @@ test("chunked encoding must be valid after flushHeaders", async () => { res.end(); }); - server.listen(3000); + server.listen(0); await once(server, "listening"); - const socket = connect(3000, () => { - socket.write("GET / HTTP/1.1\r\nHost: localhost:3000\r\nConnection: close\r\n\r\n"); + const socket = connect(server.address().port, () => { + socket.write(`GET / HTTP/1.1\r\nHost: localhost:${server.address().port}\r\nConnection: close\r\n\r\n`); }); const chunks = []; @@ -2840,11 +2840,11 @@ test("chunked encoding must be valid using minimal code", async () => { res.end("chunk 2"); }); - server.listen(3000); + server.listen(0); await once(server, "listening"); - const socket = connect(3000, () => { - socket.write("GET / HTTP/1.1\r\nHost: localhost:3000\r\nConnection: close\r\n\r\n"); + const socket = connect(server.address().port, () => { + socket.write(`GET / HTTP/1.1\r\nHost: localhost:${server.address().port}\r\nConnection: close\r\n\r\n`); }); const chunks = []; @@ -2929,11 +2929,11 @@ test("chunked encoding must be valid after without flushHeaders", async () => { res.end(); }); - server.listen(3000); + server.listen(0); await once(server, "listening"); - const socket = connect(3000, () => { - socket.write("GET / HTTP/1.1\r\nHost: localhost:3000\r\nConnection: close\r\n\r\n"); + const socket = connect(server.address().port, () => { + socket.write(`GET / HTTP/1.1\r\nHost: localhost:${server.address().port}\r\nConnection: close\r\n\r\n`); }); const chunks = []; @@ -3011,7 +3011,9 @@ test("should accept received and send blank headers", async () => { await once(server, "listening"); const socket = createConnection((server.address() as AddressInfo).port, "localhost", () => { - socket.write("GET / HTTP/1.1\r\nHost: localhost:3000\r\nConnection: close\r\nEmpty-Header:\r\n\r\n"); + socket.write( + `GET / HTTP/1.1\r\nHost: localhost:${server.address().port}\r\nConnection: close\r\nEmpty-Header:\r\n\r\n`, + ); }); socket.on("data", data => { @@ -3044,7 +3046,7 @@ test("should handle header overflow", async () => { const socket = createConnection((server.address() as AddressInfo).port, "localhost", () => { socket.write( - "GET / HTTP/1.1\r\nHost: localhost:3000\r\nConnection: close\r\nBig-Header: " + + `GET / HTTP/1.1\r\nHost: localhost:${server.address().port}\r\nConnection: close\r\nBig-Header: ` + "a".repeat(http.maxHeaderSize) + // will overflow because of host and connection headers "\r\n\r\n", ); @@ -3069,7 +3071,7 @@ test("should handle invalid method", async () => { const socket = createConnection((server.address() as AddressInfo).port, "localhost", () => { socket.write( - "BUN / HTTP/1.1\r\nHost: localhost:3000\r\nConnection: close\r\nBig-Header: " + + `BUN / HTTP/1.1\r\nHost: localhost:${server.address().port}\r\nConnection: close\r\nBig-Header: ` + "a".repeat(http.maxHeaderSize) + // will overflow because of host and connection headers "\r\n\r\n", ); @@ -3349,4 +3351,57 @@ describe("HTTP Server Security Tests - Advanced", () => { expect(mockHandler).not.toHaveBeenCalled(); }); }); + + test("Server should not crash in clientError is emitted when calling destroy", async () => { + await using server = http.createServer(async (req, res) => { + res.end("Hello World"); + }); + + const clientErrors: Promise[] = []; + server.on("clientError", (err, socket) => { + clientErrors.push( + Bun.sleep(10).then(() => { + socket.destroy(); + }), + ); + }); + await once(server.listen(), "listening"); + const address = server.address() as AddressInfo; + + async function doRequests(address: AddressInfo) { + const client = connect(address.port, address.address, () => { + client.write("GET / HTTP/1.1\r\nHost: localhost:3000\r\nContent-Length: 0\r\n\r\n"); + }); + { + const { promise, resolve, reject } = Promise.withResolvers(); + client.on("data", resolve); + client.on("error", reject); + client.on("end", resolve); + await promise; + } + { + const { promise, resolve, reject } = Promise.withResolvers(); + client.write("GET / HTTP/1.1\r\nContent-Length: 0\r\n\r\n"); + client.on("error", reject); + client.on("end", resolve); + await promise; + } + } + + async function doInvalidRequests(address: AddressInfo) { + const client = connect(address.port, address.address, () => { + client.write("GET / HTTP/1.1\r\nContent-Length: 0\r\n\r\n"); + }); + const { promise, resolve, reject } = Promise.withResolvers(); + client.on("error", reject); + client.on("close", resolve); + await promise; + } + + await doRequests(address); + await Promise.all(clientErrors); + clientErrors.length = 0; + await doInvalidRequests(address); + await Promise.all(clientErrors); + }); }); diff --git a/test/package.json b/test/package.json index 8b302e35dc..6effaadca1 100644 --- a/test/package.json +++ b/test/package.json @@ -76,6 +76,7 @@ "postgres": "3.3.5", "prisma": "5.1.1", "prompts": "2.4.2", + "proxy": "2.2.0", "react": "file:../node_modules/react", "react-dom": "18.3.1", "reflect-metadata": "0.2.2",