mirror of
https://github.com/oven-sh/bun
synced 2026-02-02 15:08:46 +00:00
fix(node:http) fix closing socket after upgraded to websocket (#23150)
### What does this PR do? handle socket upgrade in NodeHTTP.cpp ### How did you verify your code works? Run the test added with asan it should catch the bug --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -641,6 +641,10 @@ public:
|
||||
httpContext->getSocketContextData()->onClientError = std::move(onClientError);
|
||||
}
|
||||
|
||||
void setOnSocketUpgraded(HttpContextData<SSL>::OnSocketUpgradedCallback onUpgraded) {
|
||||
httpContext->getSocketContextData()->onSocketUpgraded = onUpgraded;
|
||||
}
|
||||
|
||||
TemplatedApp &&run() {
|
||||
uWS::run();
|
||||
return std::move(*this);
|
||||
|
||||
@@ -43,11 +43,11 @@ struct alignas(16) HttpContextData {
|
||||
template <bool> friend struct TemplatedApp;
|
||||
private:
|
||||
std::vector<MoveOnlyFunction<void(HttpResponse<SSL> *, int)>> filterHandlers;
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnSocketDataCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket, const char *data, int length, bool last);
|
||||
using OnSocketDrainCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnSocketUpgradedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnClientErrorCallback = MoveOnlyFunction<void(int is_ssl, struct us_socket_t *rawSocket, uWS::HttpParserError errorCode, char *rawPacket, int rawPacketLength)>;
|
||||
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
|
||||
MoveOnlyFunction<void(const char *hostname)> missingServerNameHandler;
|
||||
|
||||
@@ -66,6 +66,7 @@ private:
|
||||
OnSocketClosedCallback onSocketClosed = nullptr;
|
||||
OnSocketDrainCallback onSocketDrain = nullptr;
|
||||
OnSocketDataCallback onSocketData = nullptr;
|
||||
OnSocketUpgradedCallback onSocketUpgraded = nullptr;
|
||||
OnClientErrorCallback onClientError = nullptr;
|
||||
|
||||
uint64_t maxHeaderSize = 0; // 0 means no limit
|
||||
@@ -78,6 +79,7 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
HttpFlags flags;
|
||||
};
|
||||
|
||||
|
||||
@@ -316,14 +316,20 @@ public:
|
||||
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) us_socket_context(SSL, (struct us_socket_t *) this);
|
||||
|
||||
/* Move any backpressure out of HttpResponse */
|
||||
BackPressure backpressure(std::move(((AsyncSocketData<SSL> *) getHttpResponseData())->buffer));
|
||||
|
||||
auto* responseData = getHttpResponseData();
|
||||
BackPressure backpressure(std::move(((AsyncSocketData<SSL> *) responseData)->buffer));
|
||||
|
||||
auto* socketData = responseData->socketData;
|
||||
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
|
||||
|
||||
/* Destroy HttpResponseData */
|
||||
getHttpResponseData()->~HttpResponseData();
|
||||
responseData->~HttpResponseData();
|
||||
|
||||
/* Before we adopt and potentially change socket, check if we are corked */
|
||||
bool wasCorked = Super::isCorked();
|
||||
|
||||
|
||||
|
||||
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
|
||||
us_socket_t *usSocket = us_socket_context_adopt_socket(SSL, (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData));
|
||||
WebSocket<SSL, true, UserData> *webSocket = (WebSocket<SSL, true, UserData> *) usSocket;
|
||||
@@ -334,10 +340,12 @@ public:
|
||||
}
|
||||
|
||||
/* Initialize websocket with any moved backpressure intact */
|
||||
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure));
|
||||
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure), socketData, httpContextData->onSocketClosed);
|
||||
if (httpContextData->onSocketUpgraded) {
|
||||
httpContextData->onSocketUpgraded(socketData, SSL, usSocket);
|
||||
}
|
||||
|
||||
/* We should only mark this if inside the parser; if upgrading "async" we cannot set this */
|
||||
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
|
||||
if (httpContextData->flags.isParsingHttp) {
|
||||
/* We need to tell the Http parser that we changed socket */
|
||||
httpContextData->upgradedWebSocket = webSocket;
|
||||
@@ -351,7 +359,6 @@ public:
|
||||
|
||||
/* Move construct the UserData right before calling open handler */
|
||||
new (webSocket->getUserData()) UserData(std::forward<UserData>(userData));
|
||||
|
||||
|
||||
/* Emit open event and start the timeout */
|
||||
if (webSocketContextData->openHandler) {
|
||||
|
||||
@@ -34,8 +34,8 @@ struct WebSocket : AsyncSocket<SSL> {
|
||||
private:
|
||||
typedef AsyncSocket<SSL> Super;
|
||||
|
||||
void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) {
|
||||
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure));
|
||||
void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, WebSocketData::OnSocketClosedCallback onSocketClosed) {
|
||||
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure), socketData, onSocketClosed);
|
||||
return this;
|
||||
}
|
||||
public:
|
||||
|
||||
@@ -256,6 +256,9 @@ private:
|
||||
|
||||
/* For whatever reason, if we already have emitted close event, do not emit it again */
|
||||
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
|
||||
if (webSocketData->socketData && webSocketData->onSocketClosed) {
|
||||
webSocketData->onSocketClosed(webSocketData->socketData, SSL, (us_socket_t *) s);
|
||||
}
|
||||
if (!webSocketData->isShuttingDown) {
|
||||
/* Emit close event */
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
|
||||
@@ -52,7 +52,6 @@ struct WebSocketContextData {
|
||||
private:
|
||||
|
||||
public:
|
||||
|
||||
/* This one points to the App's shared topicTree */
|
||||
TopicTree<TopicTreeMessage, TopicTreeBigMessage> *topicTree;
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ private:
|
||||
unsigned int controlTipLength = 0;
|
||||
bool isShuttingDown = 0;
|
||||
bool hasTimedOut = false;
|
||||
|
||||
enum CompressionStatus : char {
|
||||
DISABLED,
|
||||
ENABLED,
|
||||
@@ -52,7 +53,12 @@ private:
|
||||
/* We could be a subscriber */
|
||||
Subscriber *subscriber = nullptr;
|
||||
public:
|
||||
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
void *socketData = nullptr;
|
||||
/* node http compatibility callbacks */
|
||||
OnSocketClosedCallback onSocketClosed = nullptr;
|
||||
|
||||
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, OnSocketClosedCallback onSocketClosed) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
|
||||
compressionStatus = perMessageDeflate ? ENABLED : DISABLED;
|
||||
|
||||
/* Initialize the dedicated sliding window(s) */
|
||||
@@ -64,6 +70,8 @@ public:
|
||||
inflationStream = new InflationStream(compressOptions);
|
||||
}
|
||||
}
|
||||
this->socketData = socketData;
|
||||
this->onSocketClosed = onSocketClosed;
|
||||
}
|
||||
|
||||
~WebSocketData() {
|
||||
|
||||
@@ -139,6 +139,7 @@ public:
|
||||
us_socket_t* socket = nullptr;
|
||||
unsigned is_ssl : 1 = 0;
|
||||
unsigned ended : 1 = 0;
|
||||
unsigned upgraded : 1 = 0;
|
||||
JSC::Strong<JSNodeHTTPServerSocket> strongThis = {};
|
||||
|
||||
static JSNodeHTTPServerSocket* create(JSC::VM& vm, JSC::Structure* structure, us_socket_t* socket, bool is_ssl, WebCore::JSNodeHTTPResponse* response)
|
||||
@@ -160,10 +161,15 @@ public:
|
||||
}
|
||||
|
||||
template<bool SSL>
|
||||
static void clearSocketData(us_socket_t* socket)
|
||||
static void clearSocketData(bool upgraded, us_socket_t* socket)
|
||||
{
|
||||
auto* httpResponseData = (uWS::HttpResponseData<SSL>*)us_socket_ext(SSL, socket);
|
||||
httpResponseData->socketData = nullptr;
|
||||
if (upgraded) {
|
||||
auto* webSocket = (uWS::WebSocketData*)us_socket_ext(SSL, socket);
|
||||
webSocket->socketData = nullptr;
|
||||
} else {
|
||||
auto* httpResponseData = (uWS::HttpResponseData<SSL>*)us_socket_ext(SSL, socket);
|
||||
httpResponseData->socketData = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void close()
|
||||
@@ -195,9 +201,9 @@ public:
|
||||
{
|
||||
if (socket) {
|
||||
if (is_ssl) {
|
||||
clearSocketData<true>(socket);
|
||||
clearSocketData<true>(this->upgraded, socket);
|
||||
} else {
|
||||
clearSocketData<false>(socket);
|
||||
clearSocketData<false>(this->upgraded, socket);
|
||||
}
|
||||
}
|
||||
us_socket_free_stream_buffer(&streamBuffer);
|
||||
@@ -1143,6 +1149,12 @@ static void assignOnNodeJSCompat(uWS::TemplatedApp<isSSL>* app)
|
||||
ASSERT(rawSocket == socket->socket || socket->socket == nullptr);
|
||||
socket->onData(data, length, last);
|
||||
});
|
||||
app->setOnSocketUpgraded([](void* socketData, int is_ssl, struct us_socket_t* rawSocket) -> void {
|
||||
auto* socket = reinterpret_cast<JSNodeHTTPServerSocket*>(socketData);
|
||||
// the socket is adopted and might not be the same as the rawSocket
|
||||
socket->socket = rawSocket;
|
||||
socket->upgraded = true;
|
||||
});
|
||||
}
|
||||
|
||||
extern "C" void NodeHTTP_assignOnNodeJSCompat(bool is_ssl, void* uws_app)
|
||||
|
||||
@@ -901,7 +901,6 @@ const NodeHTTPServerSocket = class Socket extends Duplex {
|
||||
req.destroy();
|
||||
}
|
||||
}
|
||||
this.emit("close");
|
||||
}
|
||||
#onCloseForDestroy(closeCallback) {
|
||||
this.#onClose();
|
||||
|
||||
50
test/js/node/http/node-http-with-ws.test.ts
Normal file
50
test/js/node/http/node-http-with-ws.test.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { expect, test } from "bun:test";
|
||||
import { tls as options } from "harness";
|
||||
import https from "https";
|
||||
import type { AddressInfo } from "node:net";
|
||||
import tls from "tls";
|
||||
import { WebSocketServer } from "ws";
|
||||
test("should not crash when closing sockets after upgrade", async () => {
|
||||
const { promise, resolve } = Promise.withResolvers();
|
||||
let http_sockets: tls.TLSSocket[] = [];
|
||||
|
||||
const server = https.createServer(options, (req, res) => {
|
||||
http_sockets.push(res.socket as tls.TLSSocket);
|
||||
res.writeHead(200, { "Content-Type": "text/plain", "Connection": "Keep-Alive" });
|
||||
res.end("okay");
|
||||
res.detachSocket(res.socket!);
|
||||
});
|
||||
|
||||
server.listen(0, "127.0.0.1", () => {
|
||||
const wsServer = new WebSocketServer({ server });
|
||||
wsServer.on("connection", socket => {});
|
||||
|
||||
const port = (server.address() as AddressInfo).port;
|
||||
const socket = tls.connect({ port, ca: options.cert }, () => {
|
||||
// normal request keep the socket alive
|
||||
socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`);
|
||||
socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`);
|
||||
socket.write(`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Keep-Alive\r\nContent-Length: 0\r\n\r\n`);
|
||||
// upgrade to websocket
|
||||
socket.write(
|
||||
`GET / HTTP/1.1\r\nHost: localhost:${port}\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\r\n`,
|
||||
);
|
||||
});
|
||||
socket.on("data", data => {
|
||||
const isWebSocket = data?.toString().includes("Upgrade: websocket");
|
||||
if (isWebSocket) {
|
||||
socket.destroy();
|
||||
setTimeout(() => {
|
||||
http_sockets.forEach(http_socket => {
|
||||
http_socket?.destroy();
|
||||
});
|
||||
server.closeAllConnections();
|
||||
resolve();
|
||||
}, 10);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
await promise;
|
||||
expect().pass();
|
||||
});
|
||||
Reference in New Issue
Block a user