mirror of
https://github.com/oven-sh/bun
synced 2026-02-02 15:08:46 +00:00
fix(node:http) fix closing socket after upgraded to websocket (#23150)
### What does this PR do? handle socket upgrade in NodeHTTP.cpp ### How did you verify your code works? Run the test added with asan it should catch the bug --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -641,6 +641,10 @@ public:
|
||||
httpContext->getSocketContextData()->onClientError = std::move(onClientError);
|
||||
}
|
||||
|
||||
void setOnSocketUpgraded(HttpContextData<SSL>::OnSocketUpgradedCallback onUpgraded) {
|
||||
httpContext->getSocketContextData()->onSocketUpgraded = onUpgraded;
|
||||
}
|
||||
|
||||
TemplatedApp &&run() {
|
||||
uWS::run();
|
||||
return std::move(*this);
|
||||
|
||||
@@ -43,11 +43,11 @@ struct alignas(16) HttpContextData {
|
||||
template <bool> friend struct TemplatedApp;
|
||||
private:
|
||||
std::vector<MoveOnlyFunction<void(HttpResponse<SSL> *, int)>> filterHandlers;
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnSocketDataCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket, const char *data, int length, bool last);
|
||||
using OnSocketDrainCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnSocketUpgradedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnClientErrorCallback = MoveOnlyFunction<void(int is_ssl, struct us_socket_t *rawSocket, uWS::HttpParserError errorCode, char *rawPacket, int rawPacketLength)>;
|
||||
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
|
||||
MoveOnlyFunction<void(const char *hostname)> missingServerNameHandler;
|
||||
|
||||
@@ -66,6 +66,7 @@ private:
|
||||
OnSocketClosedCallback onSocketClosed = nullptr;
|
||||
OnSocketDrainCallback onSocketDrain = nullptr;
|
||||
OnSocketDataCallback onSocketData = nullptr;
|
||||
OnSocketUpgradedCallback onSocketUpgraded = nullptr;
|
||||
OnClientErrorCallback onClientError = nullptr;
|
||||
|
||||
uint64_t maxHeaderSize = 0; // 0 means no limit
|
||||
@@ -78,6 +79,7 @@ private:
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
HttpFlags flags;
|
||||
};
|
||||
|
||||
|
||||
@@ -316,14 +316,20 @@ public:
|
||||
HttpContext<SSL> *httpContext = (HttpContext<SSL> *) us_socket_context(SSL, (struct us_socket_t *) this);
|
||||
|
||||
/* Move any backpressure out of HttpResponse */
|
||||
BackPressure backpressure(std::move(((AsyncSocketData<SSL> *) getHttpResponseData())->buffer));
|
||||
|
||||
auto* responseData = getHttpResponseData();
|
||||
BackPressure backpressure(std::move(((AsyncSocketData<SSL> *) responseData)->buffer));
|
||||
|
||||
auto* socketData = responseData->socketData;
|
||||
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
|
||||
|
||||
/* Destroy HttpResponseData */
|
||||
getHttpResponseData()->~HttpResponseData();
|
||||
responseData->~HttpResponseData();
|
||||
|
||||
/* Before we adopt and potentially change socket, check if we are corked */
|
||||
bool wasCorked = Super::isCorked();
|
||||
|
||||
|
||||
|
||||
/* Adopting a socket invalidates it, do not rely on it directly to carry any data */
|
||||
us_socket_t *usSocket = us_socket_context_adopt_socket(SSL, (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData));
|
||||
WebSocket<SSL, true, UserData> *webSocket = (WebSocket<SSL, true, UserData> *) usSocket;
|
||||
@@ -334,10 +340,12 @@ public:
|
||||
}
|
||||
|
||||
/* Initialize websocket with any moved backpressure intact */
|
||||
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure));
|
||||
webSocket->init(perMessageDeflate, compressOptions, std::move(backpressure), socketData, httpContextData->onSocketClosed);
|
||||
if (httpContextData->onSocketUpgraded) {
|
||||
httpContextData->onSocketUpgraded(socketData, SSL, usSocket);
|
||||
}
|
||||
|
||||
/* We should only mark this if inside the parser; if upgrading "async" we cannot set this */
|
||||
HttpContextData<SSL> *httpContextData = httpContext->getSocketContextData();
|
||||
if (httpContextData->flags.isParsingHttp) {
|
||||
/* We need to tell the Http parser that we changed socket */
|
||||
httpContextData->upgradedWebSocket = webSocket;
|
||||
@@ -351,7 +359,6 @@ public:
|
||||
|
||||
/* Move construct the UserData right before calling open handler */
|
||||
new (webSocket->getUserData()) UserData(std::forward<UserData>(userData));
|
||||
|
||||
|
||||
/* Emit open event and start the timeout */
|
||||
if (webSocketContextData->openHandler) {
|
||||
|
||||
@@ -34,8 +34,8 @@ struct WebSocket : AsyncSocket<SSL> {
|
||||
private:
|
||||
typedef AsyncSocket<SSL> Super;
|
||||
|
||||
void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) {
|
||||
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure));
|
||||
void *init(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, WebSocketData::OnSocketClosedCallback onSocketClosed) {
|
||||
new (us_socket_ext(SSL, (us_socket_t *) this)) WebSocketData(perMessageDeflate, compressOptions, std::move(backpressure), socketData, onSocketClosed);
|
||||
return this;
|
||||
}
|
||||
public:
|
||||
|
||||
@@ -256,6 +256,9 @@ private:
|
||||
|
||||
/* For whatever reason, if we already have emitted close event, do not emit it again */
|
||||
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
|
||||
if (webSocketData->socketData && webSocketData->onSocketClosed) {
|
||||
webSocketData->onSocketClosed(webSocketData->socketData, SSL, (us_socket_t *) s);
|
||||
}
|
||||
if (!webSocketData->isShuttingDown) {
|
||||
/* Emit close event */
|
||||
auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
|
||||
|
||||
@@ -52,7 +52,6 @@ struct WebSocketContextData {
|
||||
private:
|
||||
|
||||
public:
|
||||
|
||||
/* This one points to the App's shared topicTree */
|
||||
TopicTree<TopicTreeMessage, TopicTreeBigMessage> *topicTree;
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ private:
|
||||
unsigned int controlTipLength = 0;
|
||||
bool isShuttingDown = 0;
|
||||
bool hasTimedOut = false;
|
||||
|
||||
enum CompressionStatus : char {
|
||||
DISABLED,
|
||||
ENABLED,
|
||||
@@ -52,7 +53,12 @@ private:
|
||||
/* We could be a subscriber */
|
||||
Subscriber *subscriber = nullptr;
|
||||
public:
|
||||
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
void *socketData = nullptr;
|
||||
/* node http compatibility callbacks */
|
||||
OnSocketClosedCallback onSocketClosed = nullptr;
|
||||
|
||||
WebSocketData(bool perMessageDeflate, CompressOptions compressOptions, BackPressure &&backpressure, void *socketData, OnSocketClosedCallback onSocketClosed) : AsyncSocketData<false>(std::move(backpressure)), WebSocketState<true>() {
|
||||
compressionStatus = perMessageDeflate ? ENABLED : DISABLED;
|
||||
|
||||
/* Initialize the dedicated sliding window(s) */
|
||||
@@ -64,6 +70,8 @@ public:
|
||||
inflationStream = new InflationStream(compressOptions);
|
||||
}
|
||||
}
|
||||
this->socketData = socketData;
|
||||
this->onSocketClosed = onSocketClosed;
|
||||
}
|
||||
|
||||
~WebSocketData() {
|
||||
|
||||
Reference in New Issue
Block a user