mirror of
https://github.com/oven-sh/bun
synced 2026-02-02 15:08:46 +00:00
fix(node:http) allow CONNECT in node http/https servers (#22756)
### What does this PR do? Fixes https://github.com/oven-sh/bun/issues/22755 Fixes https://github.com/oven-sh/bun/issues/19790 Fixes https://github.com/oven-sh/bun/issues/16372 ### How did you verify your code works? --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -627,9 +627,15 @@ public:
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
void setOnClose(HttpContextData<SSL>::OnSocketClosedCallback onClose) {
|
||||
void setOnSocketClosed(HttpContextData<SSL>::OnSocketClosedCallback onClose) {
|
||||
httpContext->getSocketContextData()->onSocketClosed = onClose;
|
||||
}
|
||||
void setOnSocketDrain(HttpContextData<SSL>::OnSocketDrainCallback onDrain) {
|
||||
httpContext->getSocketContextData()->onSocketDrain = onDrain;
|
||||
}
|
||||
void setOnSocketData(HttpContextData<SSL>::OnSocketDataCallback onData) {
|
||||
httpContext->getSocketContextData()->onSocketData = onData;
|
||||
}
|
||||
|
||||
void setOnClientError(HttpContextData<SSL>::OnClientErrorCallback onClientError) {
|
||||
httpContext->getSocketContextData()->onClientError = std::move(onClientError);
|
||||
|
||||
@@ -193,23 +193,32 @@ private:
|
||||
auto *httpResponseData = reinterpret_cast<HttpResponseData<SSL> *>(us_socket_ext(SSL, s));
|
||||
|
||||
|
||||
|
||||
/* Call filter */
|
||||
HttpContextData<SSL> *httpContextData = getSocketContextDataS(s);
|
||||
|
||||
if(httpResponseData && httpResponseData->isConnectRequest) {
|
||||
if (httpResponseData->socketData && httpContextData->onSocketData) {
|
||||
httpContextData->onSocketData(httpResponseData->socketData, SSL, s, "", 0, true);
|
||||
}
|
||||
if(httpResponseData->inStream) {
|
||||
httpResponseData->inStream(reinterpret_cast<HttpResponse<SSL> *>(s), "", 0, true, httpResponseData->userData);
|
||||
httpResponseData->inStream = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for (auto &f : httpContextData->filterHandlers) {
|
||||
f((HttpResponse<SSL> *) s, -1);
|
||||
}
|
||||
|
||||
if (httpResponseData->socketData && httpContextData->onSocketClosed) {
|
||||
httpContextData->onSocketClosed(httpResponseData->socketData, SSL, s);
|
||||
}
|
||||
/* Signal broken HTTP request only if we have a pending request */
|
||||
if (httpResponseData->onAborted != nullptr && httpResponseData->userData != nullptr) {
|
||||
httpResponseData->onAborted((HttpResponse<SSL> *)s, httpResponseData->userData);
|
||||
}
|
||||
|
||||
if (httpResponseData->socketData && httpContextData->onSocketClosed) {
|
||||
httpContextData->onSocketClosed(httpResponseData->socketData, SSL, s);
|
||||
}
|
||||
|
||||
/* Destruct socket ext */
|
||||
httpResponseData->~HttpResponseData<SSL>();
|
||||
@@ -254,7 +263,9 @@ private:
|
||||
|
||||
/* The return value is entirely up to us to interpret. The HttpParser cares only for whether the returned value is DIFFERENT from passed user */
|
||||
|
||||
auto result = httpResponseData->consumePostPadded(httpContextData->maxHeaderSize, httpContextData->flags.requireHostHeader,httpContextData->flags.useStrictMethodValidation, data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * {
|
||||
auto result = httpResponseData->consumePostPadded(httpContextData->maxHeaderSize, httpResponseData->isConnectRequest, httpContextData->flags.requireHostHeader,httpContextData->flags.useStrictMethodValidation, data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * {
|
||||
|
||||
|
||||
/* For every request we reset the timeout and hang until user makes action */
|
||||
/* Warning: if we are in shutdown state, resetting the timer is a security issue! */
|
||||
us_socket_timeout(SSL, (us_socket_t *) s, 0);
|
||||
@@ -330,7 +341,12 @@ private:
|
||||
/* Continue parsing */
|
||||
return s;
|
||||
|
||||
}, [httpResponseData](void *user, std::string_view data, bool fin) -> void * {
|
||||
}, [httpResponseData, httpContextData](void *user, std::string_view data, bool fin) -> void * {
|
||||
|
||||
|
||||
if (httpResponseData->isConnectRequest && httpResponseData->socketData && httpContextData->onSocketData) {
|
||||
httpContextData->onSocketData(httpResponseData->socketData, SSL, (struct us_socket_t *) user, data.data(), data.length(), fin);
|
||||
}
|
||||
/* We always get an empty chunk even if there is no data */
|
||||
if (httpResponseData->inStream) {
|
||||
|
||||
@@ -449,7 +465,7 @@ private:
|
||||
us_socket_context_on_writable(SSL, getSocketContext(), [](us_socket_t *s) {
|
||||
auto *asyncSocket = reinterpret_cast<AsyncSocket<SSL> *>(s);
|
||||
auto *httpResponseData = reinterpret_cast<HttpResponseData<SSL> *>(asyncSocket->getAsyncSocketData());
|
||||
|
||||
|
||||
/* Attempt to drain the socket buffer before triggering onWritable callback */
|
||||
size_t bufferedAmount = asyncSocket->getBufferedAmount();
|
||||
if (bufferedAmount > 0) {
|
||||
@@ -470,6 +486,12 @@ private:
|
||||
*/
|
||||
}
|
||||
|
||||
auto *httpContextData = getSocketContextDataS(s);
|
||||
|
||||
|
||||
if (httpResponseData->isConnectRequest && httpResponseData->socketData && httpContextData->onSocketDrain) {
|
||||
httpContextData->onSocketDrain(httpResponseData->socketData, SSL, (struct us_socket_t *) s);
|
||||
}
|
||||
/* Ask the developer to write data and return success (true) or failure (false), OR skip sending anything and return success (true). */
|
||||
if (httpResponseData->onWritable) {
|
||||
/* We are now writable, so hang timeout again, the user does not have to do anything so we should hang until end or tryEnd rearms timeout */
|
||||
@@ -514,6 +536,7 @@ private:
|
||||
us_socket_context_on_end(SSL, getSocketContext(), [](us_socket_t *s) {
|
||||
auto *asyncSocket = reinterpret_cast<AsyncSocket<SSL> *>(s);
|
||||
asyncSocket->uncorkWithoutSending();
|
||||
|
||||
/* We do not care for half closed sockets */
|
||||
return asyncSocket->close();
|
||||
});
|
||||
|
||||
@@ -44,7 +44,10 @@ struct alignas(16) HttpContextData {
|
||||
private:
|
||||
std::vector<MoveOnlyFunction<void(HttpResponse<SSL> *, int)>> filterHandlers;
|
||||
using OnSocketClosedCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnSocketDataCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket, const char *data, int length, bool last);
|
||||
using OnSocketDrainCallback = void (*)(void* userData, int is_ssl, struct us_socket_t *rawSocket);
|
||||
using OnClientErrorCallback = MoveOnlyFunction<void(int is_ssl, struct us_socket_t *rawSocket, uWS::HttpParserError errorCode, char *rawPacket, int rawPacketLength)>;
|
||||
|
||||
|
||||
MoveOnlyFunction<void(const char *hostname)> missingServerNameHandler;
|
||||
|
||||
@@ -61,6 +64,8 @@ private:
|
||||
void *upgradedWebSocket = nullptr;
|
||||
/* Used to simulate Node.js socket events. */
|
||||
OnSocketClosedCallback onSocketClosed = nullptr;
|
||||
OnSocketDrainCallback onSocketDrain = nullptr;
|
||||
OnSocketDataCallback onSocketData = nullptr;
|
||||
OnClientErrorCallback onClientError = nullptr;
|
||||
|
||||
uint64_t maxHeaderSize = 0; // 0 means no limit
|
||||
|
||||
@@ -117,18 +117,19 @@ namespace uWS
|
||||
struct ConsumeRequestLineResult {
|
||||
char *position;
|
||||
bool isAncientHTTP;
|
||||
bool isConnect;
|
||||
HTTPHeaderParserError headerParserError;
|
||||
public:
|
||||
static ConsumeRequestLineResult error(HTTPHeaderParserError error) {
|
||||
return ConsumeRequestLineResult{nullptr, false, error};
|
||||
return ConsumeRequestLineResult{nullptr, false, false, error};
|
||||
}
|
||||
|
||||
static ConsumeRequestLineResult success(char *position, bool isAncientHTTP = false) {
|
||||
return ConsumeRequestLineResult{position, isAncientHTTP, HTTP_HEADER_PARSER_ERROR_NONE};
|
||||
static ConsumeRequestLineResult success(char *position, bool isAncientHTTP = false, bool isConnect = false) {
|
||||
return ConsumeRequestLineResult{position, isAncientHTTP, isConnect, HTTP_HEADER_PARSER_ERROR_NONE};
|
||||
}
|
||||
|
||||
static ConsumeRequestLineResult shortRead(bool isAncientHTTP = false) {
|
||||
return ConsumeRequestLineResult{nullptr, isAncientHTTP, HTTP_HEADER_PARSER_ERROR_NONE};
|
||||
static ConsumeRequestLineResult shortRead(bool isAncientHTTP = false, bool isConnect = false) {
|
||||
return ConsumeRequestLineResult{nullptr, isAncientHTTP, isConnect, HTTP_HEADER_PARSER_ERROR_NONE};
|
||||
}
|
||||
|
||||
bool isErrorOrShortRead() {
|
||||
@@ -551,7 +552,10 @@ namespace uWS
|
||||
return ConsumeRequestLineResult::shortRead();
|
||||
}
|
||||
|
||||
if (data[0] == 32 && (__builtin_expect(data[1] == '/', 1) || isHTTPorHTTPSPrefixForProxies(data + 1, end) == 1)) [[likely]] {
|
||||
|
||||
bool isHTTPMethod = (__builtin_expect(data[1] == '/', 1));
|
||||
bool isConnect = !isHTTPMethod && (isHTTPorHTTPSPrefixForProxies(data + 1, end) == 1 || ((data - start) == 7 && memcmp(start, "CONNECT", 7) == 0));
|
||||
if (isHTTPMethod || isConnect) [[likely]] {
|
||||
header.key = {start, (size_t) (data - start)};
|
||||
data++;
|
||||
if(!isValidMethod(header.key, useStrictMethodValidation)) {
|
||||
@@ -577,22 +581,22 @@ namespace uWS
|
||||
if (nextPosition >= end) {
|
||||
/* Whatever we have must be part of the version string */
|
||||
if (memcmp(" HTTP/1.1\r\n", data, std::min<unsigned int>(11, (unsigned int) (end - data))) == 0) {
|
||||
return ConsumeRequestLineResult::shortRead();
|
||||
return ConsumeRequestLineResult::shortRead(false, isConnect);
|
||||
} else if (memcmp(" HTTP/1.0\r\n", data, std::min<unsigned int>(11, (unsigned int) (end - data))) == 0) {
|
||||
/*Indicates that the request line is ancient HTTP*/
|
||||
return ConsumeRequestLineResult::shortRead(true);
|
||||
return ConsumeRequestLineResult::shortRead(true, isConnect);
|
||||
}
|
||||
return ConsumeRequestLineResult::error(HTTP_HEADER_PARSER_ERROR_INVALID_HTTP_VERSION);
|
||||
}
|
||||
if (memcmp(" HTTP/1.1\r\n", data, 11) == 0) {
|
||||
return ConsumeRequestLineResult::success(nextPosition);
|
||||
return ConsumeRequestLineResult::success(nextPosition, false, isConnect);
|
||||
} else if (memcmp(" HTTP/1.0\r\n", data, 11) == 0) {
|
||||
/*Indicates that the request line is ancient HTTP*/
|
||||
return ConsumeRequestLineResult::success(nextPosition, true);
|
||||
return ConsumeRequestLineResult::success(nextPosition, true, isConnect);
|
||||
}
|
||||
/* If we stand at the post padded CR, we have fragmented input so try again later */
|
||||
if (data[0] == '\r') {
|
||||
return ConsumeRequestLineResult::shortRead();
|
||||
return ConsumeRequestLineResult::shortRead(false, isConnect);
|
||||
}
|
||||
/* This is an error */
|
||||
return ConsumeRequestLineResult::error(HTTP_HEADER_PARSER_ERROR_INVALID_HTTP_VERSION);
|
||||
@@ -602,14 +606,14 @@ namespace uWS
|
||||
|
||||
/* If we stand at the post padded CR, we have fragmented input so try again later */
|
||||
if (data[0] == '\r') {
|
||||
return ConsumeRequestLineResult::shortRead();
|
||||
return ConsumeRequestLineResult::shortRead(false, isConnect);
|
||||
}
|
||||
|
||||
if (data[0] == 32) {
|
||||
switch (isHTTPorHTTPSPrefixForProxies(data + 1, end)) {
|
||||
// If we haven't received enough data to check if it's http:// or https://, let's try again later
|
||||
case -1:
|
||||
return ConsumeRequestLineResult::shortRead();
|
||||
return ConsumeRequestLineResult::shortRead(false, isConnect);
|
||||
// Otherwise, if it's not http:// or https://, return 400
|
||||
default:
|
||||
return ConsumeRequestLineResult::error(HTTP_HEADER_PARSER_ERROR_INVALID_REQUEST);
|
||||
@@ -635,7 +639,7 @@ namespace uWS
|
||||
}
|
||||
|
||||
/* End is only used for the proxy parser. The HTTP parser recognizes "\ra" as invalid "\r\n" scan and breaks. */
|
||||
static HttpParserResult getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved, bool &isAncientHTTP, bool useStrictMethodValidation, uint64_t maxHeaderSize) {
|
||||
static HttpParserResult getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved, bool &isAncientHTTP, bool &isConnectRequest, bool useStrictMethodValidation, uint64_t maxHeaderSize) {
|
||||
char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer;
|
||||
#ifdef UWS_WITH_PROXY
|
||||
/* ProxyParser is passed as reserved parameter */
|
||||
@@ -689,6 +693,9 @@ namespace uWS
|
||||
if(requestLineResult.isAncientHTTP) {
|
||||
isAncientHTTP = true;
|
||||
}
|
||||
if(requestLineResult.isConnect) {
|
||||
isConnectRequest = true;
|
||||
}
|
||||
/* No request headers found */
|
||||
const char * headerStart = (headers[0].key.length() > 0) ? headers[0].key.data() : end;
|
||||
|
||||
@@ -798,7 +805,7 @@ namespace uWS
|
||||
|
||||
/* This is the only caller of getHeaders and is thus the deepest part of the parser. */
|
||||
template <bool ConsumeMinimally>
|
||||
HttpParserResult fenceAndConsumePostPadded(uint64_t maxHeaderSize, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction<void *(void *, HttpRequest *)> &requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &dataHandler) {
|
||||
HttpParserResult fenceAndConsumePostPadded(uint64_t maxHeaderSize, bool& isConnectRequest, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction<void *(void *, HttpRequest *)> &requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &dataHandler) {
|
||||
|
||||
/* How much data we CONSUMED (to throw away) */
|
||||
unsigned int consumedTotal = 0;
|
||||
@@ -809,7 +816,7 @@ namespace uWS
|
||||
data[length + 1] = 'a'; /* Anything that is not \n, to trigger "invalid request" */
|
||||
req->ancientHttp = false;
|
||||
for (;length;) {
|
||||
auto result = getHeaders(data, data + length, req->headers, reserved, req->ancientHttp, useStrictMethodValidation, maxHeaderSize);
|
||||
auto result = getHeaders(data, data + length, req->headers, reserved, req->ancientHttp, isConnectRequest, useStrictMethodValidation, maxHeaderSize);
|
||||
if(result.isError()) {
|
||||
return result;
|
||||
}
|
||||
@@ -916,6 +923,10 @@ namespace uWS
|
||||
length -= emittable;
|
||||
consumedTotal += emittable;
|
||||
}
|
||||
} else if(isConnectRequest) {
|
||||
// This only server to mark that the connect request read all headers
|
||||
// and can starting emitting data
|
||||
remainingStreamingBytes = STATE_IS_CHUNKED;
|
||||
} else {
|
||||
/* If we came here without a body; emit an empty data chunk to signal no data */
|
||||
dataHandler(user, {}, true);
|
||||
@@ -931,15 +942,16 @@ namespace uWS
|
||||
}
|
||||
|
||||
public:
|
||||
HttpParserResult consumePostPadded(uint64_t maxHeaderSize, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction<void *(void *, HttpRequest *)> &&requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &&dataHandler) {
|
||||
|
||||
HttpParserResult consumePostPadded(uint64_t maxHeaderSize, bool& isConnectRequest, bool requireHostHeader, bool useStrictMethodValidation, char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction<void *(void *, HttpRequest *)> &&requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &&dataHandler) {
|
||||
/* This resets BloomFilter by construction, but later we also reset it again.
|
||||
* Optimize this to skip resetting twice (req could be made global) */
|
||||
HttpRequest req;
|
||||
if (remainingStreamingBytes) {
|
||||
|
||||
/* It's either chunked or with a content-length */
|
||||
if (isParsingChunkedEncoding(remainingStreamingBytes)) {
|
||||
if (isConnectRequest) {
|
||||
dataHandler(user, std::string_view(data, length), false);
|
||||
return HttpParserResult::success(0, user);
|
||||
} else if (isParsingChunkedEncoding(remainingStreamingBytes)) {
|
||||
/* It's either chunked or with a content-length */
|
||||
std::string_view dataToConsume(data, length);
|
||||
for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) {
|
||||
dataHandler(user, chunk, chunk.length() == 0);
|
||||
@@ -950,6 +962,7 @@ public:
|
||||
data = (char *) dataToConsume.data();
|
||||
length = (unsigned int) dataToConsume.length();
|
||||
} else {
|
||||
|
||||
// this is exactly the same as below!
|
||||
// todo: refactor this
|
||||
if (remainingStreamingBytes >= length) {
|
||||
@@ -980,7 +993,7 @@ public:
|
||||
fallback.append(data, maxCopyDistance);
|
||||
|
||||
// break here on break
|
||||
HttpParserResult consumed = fenceAndConsumePostPadded<true>(maxHeaderSize, requireHostHeader, useStrictMethodValidation, fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler);
|
||||
HttpParserResult consumed = fenceAndConsumePostPadded<true>(maxHeaderSize, isConnectRequest, requireHostHeader, useStrictMethodValidation, fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler);
|
||||
/* Return data will be different than user if we are upgraded to WebSocket or have an error */
|
||||
if (consumed.returnedData != user) {
|
||||
return consumed;
|
||||
@@ -997,8 +1010,11 @@ public:
|
||||
length -= consumedBytes - had;
|
||||
|
||||
if (remainingStreamingBytes) {
|
||||
/* It's either chunked or with a content-length */
|
||||
if (isParsingChunkedEncoding(remainingStreamingBytes)) {
|
||||
if(isConnectRequest) {
|
||||
dataHandler(user, std::string_view(data, length), false);
|
||||
return HttpParserResult::success(0, user);
|
||||
} else if (isParsingChunkedEncoding(remainingStreamingBytes)) {
|
||||
/* It's either chunked or with a content-length */
|
||||
std::string_view dataToConsume(data, length);
|
||||
for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) {
|
||||
dataHandler(user, chunk, chunk.length() == 0);
|
||||
@@ -1037,7 +1053,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
HttpParserResult consumed = fenceAndConsumePostPadded<false>(maxHeaderSize, requireHostHeader, useStrictMethodValidation, data, length, user, reserved, &req, requestHandler, dataHandler);
|
||||
HttpParserResult consumed = fenceAndConsumePostPadded<false>(maxHeaderSize, isConnectRequest, requireHostHeader, useStrictMethodValidation, data, length, user, reserved, &req, requestHandler, dataHandler);
|
||||
/* Return data will be different than user if we are upgraded to WebSocket or have an error */
|
||||
if (consumed.returnedData != user) {
|
||||
return consumed;
|
||||
|
||||
@@ -243,7 +243,7 @@ public:
|
||||
/* Manually upgrade to WebSocket. Typically called in upgrade handler. Immediately calls open handler.
|
||||
* NOTE: Will invalidate 'this' as socket might change location in memory. Throw away after use. */
|
||||
template <typename UserData>
|
||||
us_socket_t *upgrade(UserData &&userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol,
|
||||
us_socket_t *upgrade(UserData&& userData, std::string_view secWebSocketKey, std::string_view secWebSocketProtocol,
|
||||
std::string_view secWebSocketExtensions,
|
||||
struct us_socket_context_t *webSocketContext) {
|
||||
|
||||
@@ -350,7 +350,8 @@ public:
|
||||
us_socket_timeout(SSL, (us_socket_t *) webSocket, webSocketContextData->idleTimeoutComponents.first);
|
||||
|
||||
/* Move construct the UserData right before calling open handler */
|
||||
new (webSocket->getUserData()) UserData(std::move(userData));
|
||||
new (webSocket->getUserData()) UserData(std::forward<UserData>(userData));
|
||||
|
||||
|
||||
/* Emit open event and start the timeout */
|
||||
if (webSocketContextData->openHandler) {
|
||||
@@ -741,6 +742,10 @@ public:
|
||||
|
||||
return httpResponseData->socketData;
|
||||
}
|
||||
bool isConnectRequest() {
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
return httpResponseData->isConnectRequest;
|
||||
}
|
||||
|
||||
void setWriteOffset(uint64_t offset) {
|
||||
HttpResponseData<SSL> *httpResponseData = getHttpResponseData();
|
||||
|
||||
@@ -108,6 +108,7 @@ struct HttpResponseData : AsyncSocketData<SSL>, HttpParser {
|
||||
uint8_t state = 0;
|
||||
uint8_t idleTimeout = 10; // default HTTP_TIMEOUT 10 seconds
|
||||
bool fromAncientRequest = false;
|
||||
bool isConnectRequest = false;
|
||||
bool isIdle = true;
|
||||
bool shouldCloseOnceIdle = false;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user