diff --git a/packages/bun-uws/src/App.h b/packages/bun-uws/src/App.h index 6840c23fed..9c73e64dba 100644 --- a/packages/bun-uws/src/App.h +++ b/packages/bun-uws/src/App.h @@ -641,6 +641,10 @@ public: httpContext->getSocketContextData()->onClientError = std::move(onClientError); } + void setOnSocketUpgraded(HttpContextData::OnSocketUpgradedCallback onUpgraded) { + httpContext->getSocketContextData()->onSocketUpgraded = onUpgraded; + } + TemplatedApp &&run() { uWS::run(); return std::move(*this); diff --git a/packages/bun-uws/src/HttpContextData.h b/packages/bun-uws/src/HttpContextData.h index a595927d56..538537c92c 100644 --- a/packages/bun-uws/src/HttpContextData.h +++ b/packages/bun-uws/src/HttpContextData.h @@ -43,11 +43,11 @@ struct alignas(16) HttpContextData { template friend struct TemplatedApp; private: std::vector *, int)>> filterHandlers; - using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); using OnSocketDataCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket, const char *data, int length, bool last); using OnSocketDrainCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); + using OnSocketUpgradedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); using OnClientErrorCallback = MoveOnlyFunction; - + using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); MoveOnlyFunction missingServerNameHandler; @@ -66,6 +66,7 @@ private: OnSocketClosedCallback onSocketClosed = nullptr; OnSocketDrainCallback onSocketDrain = nullptr; OnSocketDataCallback onSocketData = nullptr; + OnSocketUpgradedCallback onSocketUpgraded = nullptr; OnClientErrorCallback onClientError = nullptr; uint64_t maxHeaderSize = 0; // 0 means no limit @@ -78,6 +79,7 @@ private: } public: + HttpFlags flags; }; diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h index 209e0e79df..974a4a95f6 100644 --- a/packages/bun-uws/src/HttpResponse.h +++ b/packages/bun-uws/src/HttpResponse.h @@ -316,14 +316,20 @@ public: HttpContext *httpContext = (HttpContext *) us_socket_context(SSL, (struct us_socket_t *) this); /* Move any backpressure out of HttpResponse */ - BackPressure backpressure(std::move(((AsyncSocketData *) getHttpResponseData())->buffer)); - + auto* responseData = getHttpResponseData(); + BackPressure backpressure(std::move(((AsyncSocketData *) responseData)->buffer)); + + auto* socketData = responseData->socketData; + HttpContextData *httpContextData = httpContext->getSocketContextData(); + /* Destroy HttpResponseData */ - getHttpResponseData()->~HttpResponseData(); + responseData->~HttpResponseData(); /* Before we adopt and potentially change socket, check if we are corked */ bool wasCorked = Super::isCorked(); + + /* Adopting a socket invalidates it, do not rely on it directly to carry any data */ us_socket_t *usSocket = us_socket_context_adopt_socket(SSL, (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData)); WebSocket *webSocket = (WebSocket *) usSocket; @@ -334,10 +340,12 @@ public: } /* Initialize websocket with any moved backpressure intact */ - webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure)); + webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure), socketData, httpContextData->onSocketClosed); + if (httpContextData->onSocketUpgraded) { + httpContextData->onSocketUpgraded(socketData, SSL, usSocket); + } /* We should only mark this if inside the parser; if upgrading "async" we cannot set this */ - HttpContextData *httpContextData = httpContext->getSocketContextData(); if (httpContextData->flags.isParsingHttp) { /* We need to tell the Http parser that we changed socket */ httpContextData->upgradedWebSocket = webSocket; @@ -351,7 +359,6 @@ public: /* Move construct the UserData right before calling open handler */ new (webSocket->getUserData()) UserData(std::forward(userData)); - /* Emit open event and start the timeout */ if (webSocketContextData->openHandler) { diff --git a/packages/bun-uws/src/WebSocket.h b/packages/bun-uws/src/WebSocket.h index 6b5efc81f7..5871cacb61 100644 --- a/packages/bun-uws/src/WebSocket.h +++ b/packages/bun-uws/src/WebSocket.h @@ -34,8 +34,8 @@ struct WebSocket : AsyncSocket { private: typedef AsyncSocket Super; - void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) { - new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure)); + void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, WebSocketData::OnSocketClosedCallback onSocketClosed) { + new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure), socketData, onSocketClosed); return this; } public: diff --git a/packages/bun-uws/src/WebSocketContext.h b/packages/bun-uws/src/WebSocketContext.h index 16d8092fb0..1c31050010 100644 --- a/packages/bun-uws/src/WebSocketContext.h +++ b/packages/bun-uws/src/WebSocketContext.h @@ -256,6 +256,9 @@ private: /* For whatever reason, if we already have emitted close event, do not emit it again */ WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s)); + if (webSocketData->socketData && webSocketData->onSocketClosed) { + webSocketData->onSocketClosed(webSocketData->socketData, SSL, (us_socket_t *) s); + } if (!webSocketData->isShuttingDown) { /* Emit close event */ auto *webSocketContextData = (WebSocketContextData *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s)); diff --git a/packages/bun-uws/src/WebSocketContextData.h b/packages/bun-uws/src/WebSocketContextData.h index c016be49c4..b675f65dc9 100644 --- a/packages/bun-uws/src/WebSocketContextData.h +++ b/packages/bun-uws/src/WebSocketContextData.h @@ -52,7 +52,6 @@ struct WebSocketContextData { private: public: - /* This one points to the App's shared topicTree */ TopicTree *topicTree; diff --git a/packages/bun-uws/src/WebSocketData.h b/packages/bun-uws/src/WebSocketData.h index 21e96a72d9..f9139341d1 100644 --- a/packages/bun-uws/src/WebSocketData.h +++ b/packages/bun-uws/src/WebSocketData.h @@ -38,6 +38,7 @@ private: unsigned int controlTipLength = 0; bool isShuttingDown = 0; bool hasTimedOut = false; + enum CompressionStatus : char { DISABLED, ENABLED, @@ -52,7 +53,12 @@ private: /* We could be a subscriber */ Subscriber *subscriber = nullptr; public: - WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) : AsyncSocketData(std::move(backpressure)), WebSocketState() { + using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket); + void *socketData = nullptr; + /* node http compatibility callbacks */ + OnSocketClosedCallback onSocketClosed = nullptr; + + WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, OnSocketClosedCallback onSocketClosed) : AsyncSocketData(std::move(backpressure)), WebSocketState() { compressionStatus = perMessageDeflate ? ENABLED : DISABLED; /* Initialize the dedicated sliding window(s) */ @@ -64,6 +70,8 @@ public: inflationStream = new InflationStream(compressOptions); } } + this->socketData = socketData; + this->onSocketClosed = onSocketClosed; } ~WebSocketData() { diff --git a/src/bun.js/bindings/NodeHTTP.cpp b/src/bun.js/bindings/NodeHTTP.cpp index daf35b9078..3b939e5929 100644 --- a/src/bun.js/bindings/NodeHTTP.cpp +++ b/src/bun.js/bindings/NodeHTTP.cpp @@ -139,6 +139,7 @@ public: us_socket_t* socket = nullptr; unsigned is_ssl : 1 = 0; unsigned ended : 1 = 0; + unsigned upgraded : 1 = 0; JSC::Strong strongThis = {}; static JSNodeHTTPServerSocket* create(JSC::VM& vm, JSC::Structure* structure, us_socket_t* socket, bool is_ssl, WebCore::JSNodeHTTPResponse* response) @@ -160,10 +161,15 @@ public: } template - static void clearSocketData(us_socket_t* socket) + static void clearSocketData(bool upgraded, us_socket_t* socket) { - auto* httpResponseData = (uWS::HttpResponseData*)us_socket_ext(SSL, socket); - httpResponseData->socketData = nullptr; + if (upgraded) { + auto* webSocket = (uWS::WebSocketData*)us_socket_ext(SSL, socket); + webSocket->socketData = nullptr; + } else { + auto* httpResponseData = (uWS::HttpResponseData*)us_socket_ext(SSL, socket); + httpResponseData->socketData = nullptr; + } } void close() @@ -195,9 +201,9 @@ public: { if (socket) { if (is_ssl) { - clearSocketData(socket); + clearSocketData(this->upgraded, socket); } else { - clearSocketData(socket); + clearSocketData(this->upgraded, socket); } } us_socket_free_stream_buffer(&streamBuffer); @@ -1143,6 +1149,12 @@ static void assignOnNodeJSCompat(uWS::TemplatedApp* app) ASSERT(rawSocket == socket->socket || socket->socket == nullptr); socket->onData(data, length, last); }); + app->setOnSocketUpgraded([](void* socketData, int is_ssl, struct us_socket_t* rawSocket) -> void { + auto* socket = reinterpret_cast(socketData); + // the socket is adopted and might not be the same as the rawSocket + socket->socket = rawSocket; + socket->upgraded = true; + }); } extern "C" void NodeHTTP_assignOnNodeJSCompat(bool is_ssl, void* uws_app) diff --git a/src/js/node/_http_server.ts b/src/js/node/_http_server.ts index 2981ae75ca..61cc987bc5 100644 --- a/src/js/node/_http_server.ts +++ b/src/js/node/_http_server.ts @@ -901,7 +901,6 @@ const NodeHTTPServerSocket = class Socket extends Duplex { req.destroy(); } } - this.emit("close"); } #onCloseForDestroy(closeCallback) { this.#onClose(); diff --git a/test/js/node/http/node-http-with-ws.test.ts b/test/js/node/http/node-http-with-ws.test.ts new file mode 100644 index 0000000000..f4cc63d242 --- /dev/null +++ b/test/js/node/http/node-http-with-ws.test.ts @@ -0,0 +1,50 @@ +import { expect, test } from "bun:test"; +import { tls as options } from "harness"; +import https from "https"; +import type { AddressInfo } from "node:net"; +import tls from "tls"; +import { WebSocketServer } from "ws"; +test("should not crash when closing sockets after upgrade", async () => { + const { promise, resolve } = Promise.withResolvers(); + let http_sockets: tls.TLSSocket[] = []; + + const server = https.createServer(options, (req, res) => { + http_sockets.push(res.socket as tls.TLSSocket); + res.writeHead(200, { "Content-Type": "text/plain", "Connection": "Keep-Alive" }); + res.end("okay"); + res.detachSocket(res.socket!); + }); + + server.listen(0, "127.0.0.1", () => { + const wsServer = new WebSocketServer({ server }); + wsServer.on("connection", socket => {}); + + const port = (server.address() as AddressInfo).port; + const socket = tls.connect({ port, ca: options.cert }, () => { + // normal request keep the socket alive + socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`); + socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`); + socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`); + // upgrade to websocket + socket.write( + `GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\r\n`, + ); + }); + socket.on("data", data => { + const isWebSocket = data?.toString().includes("Upgrade: websocket"); + if (isWebSocket) { + socket.destroy(); + setTimeout(() => { + http_sockets.forEach(http_socket => { + http_socket?.destroy(); + }); + server.closeAllConnections(); + resolve(); + }, 10); + } + }); + }); + + await promise; + expect().pass(); +});