fix(node:http) fix closing socket after upgraded to websocket (#23150)

### What does this PR do?
handle socket upgrade in NodeHTTP.cpp
### How did you verify your code works?
Run the test added with asan it should catch the bug

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ciro Spaciari
2025-10-02 14:55:28 -07:00
committed by GitHub
parent 2caa5dc8f2
commit 76545140af
10 changed files with 102 additions and 18 deletions

View File

@@ -641,6 +641,10 @@ public:
httpContext->getSocketContextData()->onClientError = std::move(onClientError);
}
void setOnSocketUpgraded(HttpContextData<SSL>::OnSocketUpgradedCallback onUpgraded) {
httpContext->getSocketContextData()->onSocketUpgraded = onUpgraded;
}
TemplatedApp &&run() {
uWS::run();
return std::move(*this);

View File

@@ -43,11 +43,11 @@ struct alignas(16) HttpContextData {
template <bool> friend struct TemplatedApp;
private:
std::vector<MoveOnlyFunction<void(HttpResponse<SSL> *, int)>> filterHandlers;
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
using OnSocketDataCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket, const char *data, int length, bool last);
using OnSocketDrainCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
using OnSocketUpgradedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
using OnClientErrorCallback = MoveOnlyFunction<void(int is_ssl, struct us_socket_t *rawSocket, uWS::HttpParserError errorCode, char *rawPacket, int rawPacketLength)>;
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
MoveOnlyFunction<void(const char *hostname)> missingServerNameHandler;
@@ -66,6 +66,7 @@ private:
OnSocketClosedCallback onSocketClosed = nullptr;
OnSocketDrainCallback onSocketDrain = nullptr;
OnSocketDataCallback onSocketData = nullptr;
OnSocketUpgradedCallback onSocketUpgraded = nullptr;
OnClientErrorCallback onClientError = nullptr;
uint64_t maxHeaderSize = 0; // 0 means no limit
@@ -78,6 +79,7 @@ private:
}
public:
HttpFlags flags;
};

View File

@@ -316,14 +316,20 @@ public:
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) us_socket_context(SSL, (struct us_socket_t *) this);
/* Move any backpressure out of HttpResponse */
BackPressure backpressure(std::move(((AsyncSocketData<SSL> *) getHttpResponseData())->buffer));
auto* responseData = getHttpResponseData();
BackPressure backpressure(std::move(((AsyncSocketData<SSL> *) responseData)->buffer));
auto* socketData = responseData->socketData;
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
/* Destroy HttpResponseData */
getHttpResponseData()->~HttpResponseData();
responseData->~HttpResponseData();
/* Before we adopt and potentially change socket, check if we are corked */
bool wasCorked = Super::isCorked();
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
us_socket_t *usSocket = us_socket_context_adopt_socket(SSL, (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData));
WebSocket<SSL, true, UserData> *webSocket = (WebSocket<SSL, true, UserData> *) usSocket;
@@ -334,10 +340,12 @@ public:
}
/* Initialize websocket with any moved backpressure intact */
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure));
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure), socketData, httpContextData->onSocketClosed);
if (httpContextData->onSocketUpgraded) {
httpContextData->onSocketUpgraded(socketData, SSL, usSocket);
}
/* We should only mark this if inside the parser; if upgrading "async" we cannot set this */
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
if (httpContextData->flags.isParsingHttp) {
/* We need to tell the Http parser that we changed socket */
httpContextData->upgradedWebSocket = webSocket;
@@ -351,7 +359,6 @@ public:
/* Move construct the UserData right before calling open handler */
new (webSocket->getUserData()) UserData(std::forward<UserData>(userData));
/* Emit open event and start the timeout */
if (webSocketContextData->openHandler) {

View File

@@ -34,8 +34,8 @@ struct WebSocket : AsyncSocket<SSL> {
private:
typedef AsyncSocket<SSL> Super;
void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) {
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure));
void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, WebSocketData::OnSocketClosedCallback onSocketClosed) {
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure), socketData, onSocketClosed);
return this;
}
public:

View File

@@ -256,6 +256,9 @@ private:
/* For whatever reason, if we already have emitted close event, do not emit it again */
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
if (webSocketData->socketData && webSocketData->onSocketClosed) {
webSocketData->onSocketClosed(webSocketData->socketData, SSL, (us_socket_t *) s);
}
if (!webSocketData->isShuttingDown) {
/* Emit close event */
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));

View File

@@ -52,7 +52,6 @@ struct WebSocketContextData {
private:
public:
/* This one points to the App's shared topicTree */
TopicTree<TopicTreeMessage, TopicTreeBigMessage> *topicTree;

View File

@@ -38,6 +38,7 @@ private:
unsigned int controlTipLength = 0;
bool isShuttingDown = 0;
bool hasTimedOut = false;
enum CompressionStatus : char {
DISABLED,
ENABLED,
@@ -52,7 +53,12 @@ private:
/* We could be a subscriber */
Subscriber *subscriber = nullptr;
public:
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
void *socketData = nullptr;
/* node http compatibility callbacks */
OnSocketClosedCallback onSocketClosed = nullptr;
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, OnSocketClosedCallback onSocketClosed) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
compressionStatus = perMessageDeflate ? ENABLED : DISABLED;
/* Initialize the dedicated sliding window(s) */
@@ -64,6 +70,8 @@ public:
inflationStream = new InflationStream(compressOptions);
}
}
this->socketData = socketData;
this->onSocketClosed = onSocketClosed;
}
~WebSocketData() {

View File

@@ -139,6 +139,7 @@ public:
us_socket_t* socket = nullptr;
unsigned is_ssl : 1 = 0;
unsigned ended : 1 = 0;
unsigned upgraded : 1 = 0;
JSC::Strong<JSNodeHTTPServerSocket> strongThis = {};
static JSNodeHTTPServerSocket* create(JSC::VM& vm, JSC::Structure* structure, us_socket_t* socket, bool is_ssl, WebCore::JSNodeHTTPResponse* response)
@@ -160,10 +161,15 @@ public:
}
template<bool SSL>
static void clearSocketData(us_socket_t* socket)
static void clearSocketData(bool upgraded, us_socket_t* socket)
{
auto* httpResponseData = (uWS::HttpResponseData<SSL>*)us_socket_ext(SSL, socket);
httpResponseData->socketData = nullptr;
if (upgraded) {
auto* webSocket = (uWS::WebSocketData*)us_socket_ext(SSL, socket);
webSocket->socketData = nullptr;
} else {
auto* httpResponseData = (uWS::HttpResponseData<SSL>*)us_socket_ext(SSL, socket);
httpResponseData->socketData = nullptr;
}
}
void close()
@@ -195,9 +201,9 @@ public:
{
if (socket) {
if (is_ssl) {
clearSocketData<true>(socket);
clearSocketData<true>(this->upgraded, socket);
} else {
clearSocketData<false>(socket);
clearSocketData<false>(this->upgraded, socket);
}
}
us_socket_free_stream_buffer(&streamBuffer);
@@ -1143,6 +1149,12 @@ static void assignOnNodeJSCompat(uWS::TemplatedApp<isSSL>* app)
ASSERT(rawSocket == socket->socket || socket->socket == nullptr);
socket->onData(data, length, last);
});
app->setOnSocketUpgraded([](void* socketData, int is_ssl, struct us_socket_t* rawSocket) -> void {
auto* socket = reinterpret_cast<JSNodeHTTPServerSocket*>(socketData);
// the socket is adopted and might not be the same as the rawSocket
socket->socket = rawSocket;
socket->upgraded = true;
});
}
extern "C" void NodeHTTP_assignOnNodeJSCompat(bool is_ssl, void* uws_app)

View File

@@ -901,7 +901,6 @@ const NodeHTTPServerSocket = class Socket extends Duplex {
req.destroy();
}
}
this.emit("close");
}
#onCloseForDestroy(closeCallback) {
this.#onClose();

View File

@@ -0,0 +1,50 @@
import { expect, test } from "bun:test";
import { tls as options } from "harness";
import https from "https";
import type { AddressInfo } from "node:net";
import tls from "tls";
import { WebSocketServer } from "ws";
test("should not crash when closing sockets after upgrade", async () => {
const { promise, resolve } = Promise.withResolvers();
let http_sockets: tls.TLSSocket[] = [];
const server = https.createServer(options, (req, res) => {
http_sockets.push(res.socket as tls.TLSSocket);
res.writeHead(200, { "Content-Type": "text/plain", "Connection": "Keep-Alive" });
res.end("okay");
res.detachSocket(res.socket!);
});
server.listen(0, "127.0.0.1", () => {
const wsServer = new WebSocketServer({ server });
wsServer.on("connection", socket => {});
const port = (server.address() as AddressInfo).port;
const socket = tls.connect({ port, ca: options.cert }, () => {
// normal request keep the socket alive
socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`);
socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`);
socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`);
// upgrade to websocket
socket.write(
`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\r\n`,
);
});
socket.on("data", data => {
const isWebSocket = data?.toString().includes("Upgrade: websocket");
if (isWebSocket) {
socket.destroy();
setTimeout(() => {
http_sockets.forEach(http_socket => {
http_socket?.destroy();
});
server.closeAllConnections();
resolve();
}, 10);
}
});
});
await promise;
expect().pass();
});