diff --git a/packages/bun-uws/src/App.h b/packages/bun-uws/src/App.h index 0a44054bf9..a13c888c1c 100644 --- a/packages/bun-uws/src/App.h +++ b/packages/bun-uws/src/App.h @@ -16,8 +16,7 @@ * limitations under the License. */ // clang-format off -#ifndef UWS_APP_H -#define UWS_APP_H + #include #include @@ -619,4 +618,3 @@ typedef TemplatedApp<true> SSLApp; } -#endif // UWS_APP_H \ No newline at end of file diff --git a/packages/bun-uws/src/HttpContext.h b/packages/bun-uws/src/HttpContext.h index 0081779bda..daa40cb442 100644 --- a/packages/bun-uws/src/HttpContext.h +++ b/packages/bun-uws/src/HttpContext.h @@ -16,8 +16,7 @@ * limitations under the License. */ -#ifndef UWS_HTTPCONTEXT_H -#define UWS_HTTPCONTEXT_H +#pragma once /* This class defines the main behavior of HTTP and emits various events */ @@ -27,6 +26,8 @@ #include "AsyncSocket.h" #include "WebSocketData.h" +#include <map> +#include <string> #include #include #include "MoveOnlyFunction.h" @@ -171,7 +172,7 @@ private: #endif /* The return value is entirely up to us to interpret. The HttpParser only cares for whether the returned value is DIFFERENT or not from passed user */ - void *returnedSocket = httpResponseData->consumePostPadded(data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * { + auto [err, returnedSocket] = httpResponseData->consumePostPadded(data, (unsigned int) length, s, proxyParser, [httpContextData](void *s, HttpRequest *httpRequest) -> void * { /* For every request we reset the timeout and hang until user makes action */ /* Warning: if we are in shutdown state, resetting the timer is a security issue! */ us_socket_timeout(SSL, (us_socket_t *) s, 0); @@ -180,7 +181,9 @@ private: HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, (us_socket_t *) s); httpResponseData->offset = 0; - /* Are we not ready for another request yet? Terminate the connection. */ + /* Are we not ready for another request yet? Terminate the connection. + * Important for denying async pipelining until, if ever, we want to support it. + * Otherwise requests can get mixed up on the same connection. We still support sync pipelining. */ if (httpResponseData->state & HttpResponseData<SSL>::HTTP_RESPONSE_PENDING) { us_socket_close(SSL, (us_socket_t *) s, 0, nullptr); return nullptr; } @@ -280,10 +283,6 @@ private: } } return user; - }, [](void *user) { - /* Close any socket on HTTP errors */ - us_socket_close(SSL, (us_socket_t *) user, 0, nullptr); - return nullptr; }); /* Mark that we are no longer parsing Http */ @@ -291,6 +290,9 @@ private: /* If we got fullptr that means the parser wants us to close the socket from error (same as calling the errorHandler) */ if (returnedSocket == FULLPTR) { + /* For errors, we only deliver them "at most once". We don't care if they get halfway delivered or not. */ + us_socket_write(SSL, s, httpErrorResponses[err].data(), (int) httpErrorResponses[err].length(), false); + us_socket_shutdown(SSL, s); /* Close any socket on HTTP errors */ us_socket_close(SSL, s, 0, nullptr); /* This just makes the following code act as if the socket was closed from error inside the parser.
*/ @@ -299,9 +301,8 @@ private: /* We need to uncork in all cases, except for nullptr (closed socket, or upgraded socket) */ if (returnedSocket != nullptr) { - us_socket_t* returnedSocketPtr = (us_socket_t*) returnedSocket; /* We don't want open sockets to keep the event loop alive between HTTP requests */ - us_socket_unref(returnedSocketPtr); + us_socket_unref((us_socket_t *) returnedSocket); /* Timeout on uncork failure */ auto [written, failed] = ((AsyncSocket<SSL> *) returnedSocket)->uncork(); @@ -321,7 +322,7 @@ } } } - return returnedSocketPtr; + return (us_socket_t *) returnedSocket; } /* If we upgraded, check here (differ between nullptr close and nullptr upgrade) */ @@ -483,10 +484,27 @@ public: return; } - httpContextData->currentRouter->add(methods, pattern, [handler = std::move(handler)](auto *r) mutable { + /* Record this route's parameter offsets */ + std::map<std::string, unsigned short, std::less<>> parameterOffsets; + unsigned short offset = 0; + for (unsigned int i = 0; i < pattern.length(); i++) { + if (pattern[i] == ':') { + i++; + unsigned int start = i; + while (i < pattern.length() && pattern[i] != '/') { + i++; + } + parameterOffsets[std::string(pattern.data() + start, i - start)] = offset; + //std::cout << "<" << std::string(pattern.data() + start, i - start) << "> is offset " << offset; + offset++; + } + } + + httpContextData->currentRouter->add(methods, pattern, [handler = std::move(handler), parameterOffsets = std::move(parameterOffsets)](auto *r) mutable { auto user = r->getUserData(); user.httpRequest->setYield(false); user.httpRequest->setParameters(r->getParameters()); + user.httpRequest->setParameterOffsets(&parameterOffsets); /* Middleware? Automatically respond to expectations */ std::string_view expect = user.httpRequest->getHeader("expect"); @@ -528,4 +546,4 @@ public: } -#endif // UWS_HTTPCONTEXT_H + diff --git a/packages/bun-uws/src/HttpError.h b/packages/bun-uws/src/HttpError.h new file mode 100644 index 0000000000..a17a1c7377 --- /dev/null +++ b/packages/bun-uws/src/HttpError.h @@ -0,0 +1,53 @@ +/* + * Authored by Alex Hultman, 2018-2023. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef UWS_HTTP_ERRORS +#define UWS_HTTP_ERRORS + +#include <string_view> + +namespace uWS { +/* Possible errors from http parsing */ +enum HttpError { + HTTP_ERROR_505_HTTP_VERSION_NOT_SUPPORTED = 1, + HTTP_ERROR_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 2, + HTTP_ERROR_400_BAD_REQUEST = 3 +}; + +#ifndef UWS_HTTPRESPONSE_NO_WRITEMARK + +/* Returned parser errors match this LUT. */ +static const std::string_view httpErrorResponses[] = { + "", /* Zeroth place is no error so don't use it */ + "HTTP/1.1 505 HTTP Version Not Supported\r\nConnection: close\r\n\r\n

<html><body><h1>HTTP Version Not Supported</h1><p>This server does not support HTTP/1.0.</p><hr><i>uWebSockets/20 Server</i></body></html>", + "HTTP/1.1 431 Request Header Fields Too Large\r\nConnection: close\r\n\r\n<html><body><h1>Request Header Fields Too Large</h1><hr><i>uWebSockets/20 Server</i></body></html>", + "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n<html><body><h1>Bad Request</h1><hr><i>uWebSockets/20 Server</i></body></html>", +}; + +#else +/* Anonymized pages */ +static const std::string_view httpErrorResponses[] = { + "", /* Zeroth place is no error so don't use it */ + "HTTP/1.1 505 HTTP Version Not Supported\r\nConnection: close\r\n\r\n", + "HTTP/1.1 431 Request Header Fields Too Large\r\nConnection: close\r\n\r\n", + "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n" +}; +#endif + +} + +#endif \ No newline at end of file diff --git a/packages/bun-uws/src/HttpErrors.h b/packages/bun-uws/src/HttpErrors.h new file mode 100644 index 0000000000..704ba0b43a --- /dev/null +++ b/packages/bun-uws/src/HttpErrors.h @@ -0,0 +1,42 @@ +/* + * Authored by Alex Hultman, 2018-2023. + * Intellectual property of third-party. + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string_view> + +namespace uWS { +/* Possible errors from http parsing */ +enum HttpError { + HTTP_ERROR_505_HTTP_VERSION_NOT_SUPPORTED = 1, + HTTP_ERROR_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 2, + HTTP_ERROR_400_BAD_REQUEST = 3 +}; + + +/* Anonymized pages */ +static const std::string_view httpErrorResponses[] = { + "", /* Zeroth place is no error so don't use it */ + "HTTP/1.1 505 HTTP Version Not Supported\r\nConnection: close\r\n\r\n", + "HTTP/1.1 431 Request Header Fields Too Large\r\nConnection: close\r\n\r\n", + "HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n" +}; + + + +} + diff --git a/packages/bun-uws/src/HttpParser.h b/packages/bun-uws/src/HttpParser.h index 59bfd0f769..457b665572 100644 --- a/packages/bun-uws/src/HttpParser.h +++ b/packages/bun-uws/src/HttpParser.h @@ -15,8 +15,7 @@ * limitations under the License. */ -#ifndef UWS_HTTPPARSER_H -#define UWS_HTTPPARSER_H +#pragma once #ifndef UWS_HTTP_MAX_HEADERS_COUNT #define UWS_HTTP_MAX_HEADERS_COUNT 100 @@ -38,6 +37,7 @@ #include "BloomFilter.h" #include "ProxyParser.h" #include "QueryParser.h" +#include "HttpErrors.h" extern "C" size_t BUN_DEFAULT_MAX_HTTP_HEADER_SIZE; @@ -63,6 +63,7 @@ namespace uWS bool didYield; BloomFilter bf; std::pair<int, std::string_view *> currentParameters; + std::map<std::string, unsigned short, std::less<>> *currentParameterOffsets = nullptr; public: bool isAncient() @@ -188,14 +189,26 @@ namespace uWS currentParameters = parameters; } - std::string_view getParameter(unsigned short index) + void setParameterOffsets(std::map<std::string, unsigned short, std::less<>> *offsets) { - if (currentParameters.first < (int)index) - { - return {}; + currentParameterOffsets = offsets; + } + + std::string_view getParameter(std::string_view name) { + if (!currentParameterOffsets) { + return {nullptr, 0}; } - else - { + auto it = currentParameterOffsets->find(name); + if (it == currentParameterOffsets->end()) { + return {nullptr, 0}; + } + return getParameter(it->second); + } + + std::string_view getParameter(unsigned short index) { + if (currentParameters.first < (int)index) { + return {}; } else { return currentParameters.second[index]; } } @@ -211,46 +224,42 @@ namespace uWS const size_t MAX_FALLBACK_SIZE = BUN_DEFAULT_MAX_HTTP_HEADER_SIZE; - /* Returns UINT_MAX on error. Maximum 999999999 is allowed.
*/ + /* Returns UINT64_MAX on error. Maximum 999999999 is allowed. */ static uint64_t toUnsignedInteger(std::string_view str) { /* We assume at least 64-bit integer giving us safely 999999999999999999 (18 number of 9s) */ if (str.length() > 18) { - return UINT_MAX; + return UINT64_MAX; } uint64_t unsignedIntegerValue = 0; for (char c : str) { /* As long as the letter is 0-9 we cannot overflow. */ if (c < '0' || c > '9') { - return UINT_MAX; + return UINT64_MAX; } unsignedIntegerValue = unsignedIntegerValue * 10ull + ((unsigned int) c - (unsigned int) '0'); } return unsignedIntegerValue; } - - static inline uint64_t hasLess(uint64_t x, uint64_t n) - { - return (((x) - ~0ULL / 255 * (n)) & ~(x) & ~0ULL / 255 * 128); + + static inline uint64_t hasLess(uint64_t x, uint64_t n) { + return (((x)-~0ULL/255*(n))&~(x)&~0ULL/255*128); } - static inline uint64_t hasMore(uint64_t x, uint64_t n) - { - return ((((x) + ~0ULL / 255 * (127 - (n))) | (x)) & ~0ULL / 255 * 128); + static inline uint64_t hasMore(uint64_t x, uint64_t n) { + return (( ((x)+~0ULL/255*(127-(n))) |(x))&~0ULL/255*128); } - static inline uint64_t hasBetween(uint64_t x, uint64_t m, uint64_t n) - { - return (((~0ULL / 255 * (127 + (n)) - ((x) & ~0ULL / 255 * 127)) & ~(x) & (((x) & ~0ULL / 255 * 127) + ~0ULL / 255 * (127 - (m)))) & ~0ULL / 255 * 128); + static inline uint64_t hasBetween(uint64_t x, uint64_t m, uint64_t n) { + return (( (~0ULL/255*(127+(n))-((x)&~0ULL/255*127)) &~(x)& (((x)&~0ULL/255*127)+~0ULL/255*(127-(m))) )&~0ULL/255*128); } - static inline bool notFieldNameWord(uint64_t x) - { + static inline bool notFieldNameWord(uint64_t x) { return hasLess(x, '-') | - hasBetween(x, '-', '0') | - hasBetween(x, '9', 'A') | - hasBetween(x, 'Z', 'a') | - hasMore(x, 'z'); + hasBetween(x, '-', '0') | + hasBetween(x, '9', 'A') | + hasBetween(x, 'Z', 'a') | + hasMore(x, 'z'); } /* RFC 9110 5.6.2. Tokens */ @@ -305,392 +314,479 @@ namespace uWS return (void *)p; } + static inline int isHTTPorHTTPSPrefixForProxies(char *data, char *end) { + // We can check 8 because: + // 1. If it's "http://" that's 7 bytes, and it's supposed to at least have a trailing slash. + // 2. If it's "https://" that's 8 bytes exactly. 
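+ // For example, after the method's trailing SP in "GET http://example.com/ HTTP/1.1" the eight bytes read below are "http://e", and for "GET https://example.com/ HTTP/1.1" they are exactly "https://".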
+ if (data + 8 >= end) [[unlikely]] { + // if it's not at least 8 bytes, let's try again later + return -1; + } + + uint64_t http; + __builtin_memcpy(&http, data, sizeof(uint64_t)); + + uint32_t first_four_bytes = http & static_cast<uint32_t>(0xFFFFFFFF); + // check if any of the first four bytes are non-ASCII + if ((first_four_bytes & 0x80808080) != 0) [[unlikely]] { + return 0; + } + first_four_bytes |= 0x20202020; // Lowercase the first four bytes + + static constexpr char http_lowercase_bytes[4] = {'h', 't', 't', 'p'}; + static constexpr uint32_t http_lowercase_bytes_int = __builtin_bit_cast(uint32_t, http_lowercase_bytes); + if (first_four_bytes == http_lowercase_bytes_int) [[likely]] { + if (__builtin_memcmp(reinterpret_cast<const char *>(&http) + 4, "://", 3) == 0) [[likely]] { + return 1; + } + + static constexpr char s_colon_slash_slash[4] = {'s', ':', '/', '/'}; + static constexpr uint32_t s_colon_slash_slash_int = __builtin_bit_cast(uint32_t, s_colon_slash_slash); + + static constexpr char S_colon_slash_slash[4] = {'S', ':', '/', '/'}; + static constexpr uint32_t S_colon_slash_slash_int = __builtin_bit_cast(uint32_t, S_colon_slash_slash); + + // Extract the last four bytes from the uint64_t + const uint32_t last_four_bytes = (http >> 32) & static_cast<uint32_t>(0xFFFFFFFF); + return (last_four_bytes == s_colon_slash_slash_int) || (last_four_bytes == S_colon_slash_slash_int); + } + + return 0; + } + /* Puts method as key, target as value and returns non-null (or nullptr on error). */ - static inline char *consumeRequestLine(char *data, HttpRequest::Header &header, bool *isAncientHttp) - { - /* Scan until single SP, assume next is not SP (origin request) */ + static inline char *consumeRequestLine(char *data, char *end, HttpRequest::Header &header, bool &isAncientHTTP) { + /* Scan until single SP, assume next is / (origin request) */ char *start = data; /* This catches the post padded CR and fails */ - while (data[0] > 32) - data++; - if (data[0] == 32 && data[1] != 32) - { - header.key = {start, (size_t)(data - start)}; + while (data[0] > 32) data++; + if (&data[1] == end) [[unlikely]] { + return nullptr; + } + + if (data[0] == 32 && (__builtin_expect(data[1] == '/', 1) || isHTTPorHTTPSPrefixForProxies(data + 1, end) == 1)) [[likely]] { + header.key = {start, (size_t) (data - start)}; data++; /* Scan for less than 33 (catches post padded CR and fails) */ start = data; - for (; true; data += 8) - { + for (; true; data += 8) { uint64_t word; memcpy(&word, data, sizeof(uint64_t)); - if (hasLess(word, 33)) - { - while (*(unsigned char *)data > 32) - data++; + if (hasLess(word, 33)) { + while (*(unsigned char *)data > 32) data++; /* Now we stand on space */ - header.value = {start, (size_t)(data - start)}; + header.value = {start, (size_t) (data - start)}; /* Check that the following is http 1.1 */ - if (memcmp(" HTTP/1.1\r\n", data, 11) == 0) - { - *isAncientHttp = false; + if (data + 11 >= end) { + /* Whatever we have must be part of the version string */ + if (memcmp(" HTTP/1.1\r\n", data, std::min<unsigned int>(11, (unsigned int) (end - data))) == 0) { + return nullptr; + } else if (memcmp(" HTTP/1.0\r\n", data, std::min<unsigned int>(11, (unsigned int) (end - data))) == 0) { + isAncientHTTP = true; + return data + 11; + } + return (char *) 0x1; + } + if (memcmp(" HTTP/1.1\r\n", data, 11) == 0) { + return data + 11; + } else if (memcmp(" HTTP/1.0\r\n", data, 11) == 0) { + isAncientHTTP = true; + return data + 11; } - /* Check that the following is ancient http 1.0 */ - if (memcmp(" HTTP/1.0\r\n", data, 11) == 0) - { - *isAncientHttp = true;
- return data + 11; + /* If we stand at the post padded CR, we have fragmented input so try again later */ + if (data[0] == '\r') { + return nullptr; } - return nullptr; + /* This is an error */ + return (char *) 0x1; } } } - return nullptr; + + /* If we stand at the post padded CR, we have fragmented input so try again later */ + if (data[0] == '\r') { + return nullptr; + } + + if (data[0] == 32) { + switch (isHTTPorHTTPSPrefixForProxies(data + 1, end)) { + // If we haven't received enough data to check if it's http:// or https://, let's try again later + case -1: + return nullptr; + // Otherwise, if it's not http:// or https://, return 400 + default: + return (char *) 0x2; + } + } + + return (char *) 0x1; } /* RFC 9110: 5.5 Field Values (TLDR; anything above 31 is allowed; htab (9) is also allowed) - * Field values are usually constrained to the range of US-ASCII characters [...] - * Field values containing CR, LF, or NUL characters are invalid and dangerous [...] - * Field values containing other CTL characters are also invalid. */ - static inline void *tryConsumeFieldValue(char *p) - { - for (; true; p += 8) - { + * Field values are usually constrained to the range of US-ASCII characters [...] + * Field values containing CR, LF, or NUL characters are invalid and dangerous [...] + * Field values containing other CTL characters are also invalid. */ + static inline void *tryConsumeFieldValue(char *p) { + for (; true; p += 8) { uint64_t word; memcpy(&word, p, sizeof(uint64_t)); - if (hasLess(word, 32)) - { - while (*(unsigned char *)p > 31) - p++; + if (hasLess(word, 32)) { + while (*(unsigned char *)p > 31) p++; return (void *)p; } } } + /* End is only used for the proxy parser. The HTTP parser recognizes "\ra" as invalid "\r\n" scan and breaks. */ + static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved, unsigned int &err, bool &isAncientHTTP) { + char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer; - /* End is only used for the proxy parser. The HTTP parser recognizes "\ra" as invalid "\r\n" scan and breaks. */ - static unsigned int getHeaders(char *postPaddedBuffer, char *end, struct HttpRequest::Header *headers, void *reserved, bool*isAncientHttp) { - char *preliminaryKey, *preliminaryValue, *start = postPaddedBuffer; + #ifdef UWS_WITH_PROXY + /* ProxyParser is passed as reserved parameter */ + ProxyParser *pp = (ProxyParser *) reserved; -#ifdef UWS_WITH_PROXY - /* ProxyParser is passed as reserved parameter */ - ProxyParser *pp = (ProxyParser *)reserved; - - /* Parse PROXY protocol */ - auto [done, offset] = pp->parse({start, (size_t)(end - postPaddedBuffer)}); - if (!done) - { - /* We do not reset the ProxyParser (on filure) since it is tied to this - * connection, which is really only supposed to ever get one PROXY frame - * anyways. We do however allow multiple PROXY frames to be sent (overwrites former). */ - return 0; - } - else - { - /* We have consumed this data so skip it */ - start += offset; - } -#else - /* This one is unused */ - (void)reserved; - (void)end; -#endif + /* Parse PROXY protocol */ + auto [done, offset] = pp->parse({postPaddedBuffer, (size_t) (end - postPaddedBuffer)}); + if (!done) { + /* We do not reset the ProxyParser (on filure) since it is tied to this + * connection, which is really only supposed to ever get one PROXY frame + * anyways. We do however allow multiple PROXY frames to be sent (overwrites former). 
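+ * (The PROXY protocol prepends a single binary frame carrying the original client address ahead of the first HTTP bytes; pp->parse() consumes exactly that frame and reports how many bytes to skip.)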
*/ + return 0; + } else { + /* We have consumed this data so skip it */ + postPaddedBuffer += offset; + } + #else + /* This one is unused */ + (void) reserved; + (void) end; + #endif /* It is critical for fallback buffering logic that we only return with success - * if we managed to parse a complete HTTP request (minus data). Returning success - * for PROXY means we can end up succeeding, yet leaving bytes in the fallback buffer - * which is then removed, and our counters to flip due to overflow and we end up with a crash */ + * if we managed to parse a complete HTTP request (minus data). Returning success + * for PROXY means we can end up succeeding, yet leaving bytes in the fallback buffer + * which is then removed, and our counters to flip due to overflow and we end up with a crash */ /* The request line is different from the field names / field values */ - if (!(postPaddedBuffer = consumeRequestLine(postPaddedBuffer, headers[0], isAncientHttp))) - { + if ((char *) 3 > (postPaddedBuffer = consumeRequestLine(postPaddedBuffer, end, headers[0], isAncientHTTP))) { /* Error - invalid request line */ - + /* Assuming it is 505 HTTP Version Not Supported */ + switch (reinterpret_cast<uintptr_t>(postPaddedBuffer)) { + case 0x1: + err = HTTP_ERROR_505_HTTP_VERSION_NOT_SUPPORTED; + break; + case 0x2: + err = HTTP_ERROR_400_BAD_REQUEST; + break; + default: { + err = 0; + break; + } + } return 0; } headers++; - for (unsigned int i = 1; i < UWS_HTTP_MAX_HEADERS_COUNT - 1; i++) - { + for (unsigned int i = 1; i < UWS_HTTP_MAX_HEADERS_COUNT - 1; i++) { /* Lower case and consume the field name */ preliminaryKey = postPaddedBuffer; - postPaddedBuffer = (char *)consumeFieldName(postPaddedBuffer); - headers->key = std::string_view(preliminaryKey, (size_t)(postPaddedBuffer - preliminaryKey)); + postPaddedBuffer = (char *) consumeFieldName(postPaddedBuffer); + headers->key = std::string_view(preliminaryKey, (size_t) (postPaddedBuffer - preliminaryKey)); /* We should not accept whitespace between key and colon, so colon must follow immediately */ - if (postPaddedBuffer[0] != ':') - { + if (postPaddedBuffer[0] != ':') { + /* If we stand at the end, we are fragmented */ + if (postPaddedBuffer == end) { + return 0; + } /* Error: invalid chars in field name */ + err = HTTP_ERROR_400_BAD_REQUEST; return 0; } postPaddedBuffer++; preliminaryValue = postPaddedBuffer; /* The goal of this call is to find next "\r\n", or any invalid field value chars, fast */ - while (true) - { - postPaddedBuffer = (char *)tryConsumeFieldValue(postPaddedBuffer); + while (true) { + postPaddedBuffer = (char *) tryConsumeFieldValue(postPaddedBuffer); /* If this is not CR then we caught some stinky invalid char on the way */ - if (postPaddedBuffer[0] != '\r') - { + if (postPaddedBuffer[0] != '\r') { /* If TAB then keep searching */ - if (postPaddedBuffer[0] == '\t') - { + if (postPaddedBuffer[0] == '\t') { postPaddedBuffer++; continue; } /* Error - invalid chars in field value */ + err = HTTP_ERROR_400_BAD_REQUEST; return 0; } break; } /* We fence end[0] with \r, followed by end[1] being something that is "not \n", to signify "not found". - * This way we can have this one single check to see if we found \r\n WITHIN our allowed search space.
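 * Concretely: the fence writes data[length] = '\r' and data[length + 1] = 'a', so a scan that overruns the real input always sees "\ra", which can never match a genuine "\r\n".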
*/ + if (postPaddedBuffer[1] == '\n') { /* Store this header, it is valid */ - headers->value = std::string_view(preliminaryValue, (size_t)(postPaddedBuffer - preliminaryValue)); + headers->value = std::string_view(preliminaryValue, (size_t) (postPaddedBuffer - preliminaryValue)); postPaddedBuffer += 2; /* Trim trailing whitespace (SP, HTAB) */ - while (headers->value.length() && headers->value.back() < 33) - { + while (headers->value.length() && headers->value.back() < 33) { headers->value.remove_suffix(1); } /* Trim initial whitespace (SP, HTAB) */ - while (headers->value.length() && headers->value.front() < 33) - { + while (headers->value.length() && headers->value.front() < 33) { headers->value.remove_prefix(1); } - + headers++; /* We definitely have at least one header (or request line), so check if we are done */ - if (*postPaddedBuffer == '\r') - { - if (postPaddedBuffer[1] == '\n') - { + if (*postPaddedBuffer == '\r') { + if (postPaddedBuffer[1] == '\n') { /* This can take the very last header space */ headers->key = std::string_view(nullptr, 0); - return (unsigned int)((postPaddedBuffer + 2) - start); - } - else - { + return (unsigned int) ((postPaddedBuffer + 2) - start); + } else { /* \r\n\r plus non-\n letter is malformed request, or simply out of search space */ + if (postPaddedBuffer + 1 < end) { + err = HTTP_ERROR_400_BAD_REQUEST; + } return 0; } } - } - else - { + } else { /* We are either out of search space or this is a malformed request */ return 0; } } /* We ran out of header space, too large request */ + err = HTTP_ERROR_431_REQUEST_HEADER_FIELDS_TOO_LARGE; return 0; } + /* This is the only caller of getHeaders and is thus the deepest part of the parser. + * From here we return either [consumed, user] for "keep going", + * or [consumed, nullptr] for "break; I am closed or upgraded to websocket" + * or [whatever, fullptr] for "break and close me, I am a parser error!" */ + template <bool CONSUME_MINIMALLY> + std::pair<unsigned int, void *> fenceAndConsumePostPadded(char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction<void *(void *, HttpRequest *)> &requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &dataHandler) { - /* This is the only caller of getHeaders and is thus the deepest part of the parser. - * From here we return either [consumed, user] for "keep going", - * or [consumed, nullptr] for "break; I am closed or upgraded to websocket" - * or [whatever, fullptr] for "break and close me, I am a parser error!" */ - template <bool CONSUME_MINIMALLY> - std::pair<unsigned int, void *> fenceAndConsumePostPadded(char *data, unsigned int length, void *user, void *reserved, HttpRequest *req, MoveOnlyFunction<void *(void *, HttpRequest *)> &requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &dataHandler) - { + /* How much data we CONSUMED (to throw away) */ + unsigned int consumedTotal = 0; + unsigned int err = 0; - /* How much data we CONSUMED (to throw away) */ - unsigned int consumedTotal = 0; + /* Fence two bytes past end of our buffer (buffer has post padded margins). + * This is to always catch scan for \r but not for \r\n. */ + data[length] = '\r'; + data[length + 1] = 'a'; /* Anything that is not \n, to trigger "invalid request" */ + bool isAncientHTTP = false; - /* Fence two bytes past end of our buffer (buffer has post padded margins). - * This is to always catch scan for \r but not for \r\n.
*/ - data[length] = '\r'; - data[length + 1] = 'a'; /* Anything that is not \n, to trigger "invalid request" */ - bool isAncientHttp = false; - for (unsigned int consumed; length && (consumed = getHeaders(data, data + length, req->headers, reserved, &isAncientHttp));) - { - data += consumed; - length -= consumed; - consumedTotal += consumed; + for (unsigned int consumed; length && (consumed = getHeaders(data, data + length, req->headers, reserved, err, isAncientHTTP)); ) { + data += consumed; + length -= consumed; + consumedTotal += consumed; - /* Even if we could parse it, check for length here as well */ - if (consumed > MAX_FALLBACK_SIZE) { - return {0, FULLPTR}; - } - - /* Store HTTP version (ancient 1.0 or 1.1) */ - req->ancientHttp = isAncientHttp; - - /* Add all headers to bloom filter */ - req->bf.reset(); - for (HttpRequest::Header *h = req->headers; (++h)->key.length();) - { - req->bf.add(h->key); - } - - /* Break if no host header (but we can have empty string which is different from nullptr) */ - if (!req->getHeader("host").data()) - { - return {0, FULLPTR}; - } - - /* RFC 9112 6.3 - * If a message is received with both a Transfer-Encoding and a Content-Length header field, - * the Transfer-Encoding overrides the Content-Length. Such a message might indicate an attempt - * to perform request smuggling (Section 11.2) or response splitting (Section 11.1) and - * ought to be handled as an error. */ - std::string_view transferEncodingString = req->getHeader("transfer-encoding"); - std::string_view contentLengthString = req->getHeader("content-length"); - if (transferEncodingString.length() && contentLengthString.length()) - { - /* Returning fullptr is the same as calling the errorHandler */ - /* We could be smart and set an error in the context along with this, to indicate what - * http error response we might want to return */ - return {0, FULLPTR}; - } - - /* Parse query */ - const char *querySeparatorPtr = (const char *)memchr(req->headers->value.data(), '?', req->headers->value.length()); - req->querySeparator = (unsigned int)((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data()); - - /* If returned socket is not what we put in we need - * to break here as we either have upgraded to - * WebSockets or otherwise closed the socket. */ - void *returnedUser = requestHandler(user, req); - if (returnedUser != user) - { - /* We are upgraded to WebSocket or otherwise broken */ - return {consumedTotal, returnedUser}; - } - - /* The rules at play here according to RFC 9112 for requests are essentially: - * If both content-length and transfer-encoding then invalid message; must break. - * If has transfer-encoding then must be chunked regardless of value. - * If content-length then fixed length even if 0. - * If none of the above then fixed length is 0. */ - - /* RFC 9112 6.3 - * If a message is received with both a Transfer-Encoding and a Content-Length header field, - * the Transfer-Encoding overrides the Content-Length. */ - if (transferEncodingString.length()) - { - - /* If a proxy sent us the transfer-encoding header that 100% means it must be chunked or else the proxy is - * not RFC 9112 compliant. Therefore it is always better to assume this is the case, since that entirely eliminates - * all forms of transfer-encoding obfuscation tricks. We just rely on the header. 
*/ - - /* RFC 9112 6.3 - * If a Transfer-Encoding header field is present in a request and the chunked transfer coding is not the - * final encoding, the message body length cannot be determined reliably; the server MUST respond with the - * 400 (Bad Request) status code and then close the connection. */ - - /* In this case we fail later by having the wrong interpretation (assuming chunked). - * This could be made stricter but makes no difference either way, unless forwarding the identical message as a proxy. */ - - remainingStreamingBytes = STATE_IS_CHUNKED; - /* If consume minimally, we do not want to consume anything but we want to mark this as being chunked */ - if (!CONSUME_MINIMALLY) - { - /* Go ahead and parse it (todo: better heuristics for emitting FIN to the app level) */ - std::string_view dataToConsume(data, length); - for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) - { - dataHandler(user, chunk, chunk.length() == 0); - } - if (isParsingInvalidChunkedEncoding(remainingStreamingBytes)) - { - return {0, FULLPTR}; - } - unsigned int consumed = (length - (unsigned int)dataToConsume.length()); - data = (char *)dataToConsume.data(); - length = (unsigned int)dataToConsume.length(); - consumedTotal += consumed; - } - } - else if (contentLengthString.length()) - { - remainingStreamingBytes = toUnsignedInteger(contentLengthString); - if (remainingStreamingBytes == UINT_MAX) - { - /* Parser error */ - return {0, FULLPTR}; - } - - if (!CONSUME_MINIMALLY) - { - unsigned int emittable = (unsigned int) std::min(remainingStreamingBytes, length); - dataHandler(user, std::string_view(data, emittable), emittable == remainingStreamingBytes); - remainingStreamingBytes -= emittable; - - data += emittable; - length -= emittable; - consumedTotal += emittable; - } - } - else - { - /* If we came here without a body; emit an empty data chunk to signal no data */ - dataHandler(user, {}, true); - } - - /* Consume minimally should break as easrly as possible */ - if (CONSUME_MINIMALLY) - { - break; - } + /* Even if we could parse it, check for length here as well */ + if (consumed > MAX_FALLBACK_SIZE) { + return {HTTP_ERROR_431_REQUEST_HEADER_FIELDS_TOO_LARGE, FULLPTR}; } - return {consumedTotal, user}; - } - public: - void *consumePostPadded(char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction &&requestHandler, MoveOnlyFunction &&dataHandler, MoveOnlyFunction &&errorHandler) - { - /* This resets BloomFilter by construction, but later we also reset it again. - * Optimize this to skip resetting twice (req could be made global) */ - HttpRequest req; + /* Store HTTP version (ancient 1.0 or 1.1) */ + req->ancientHttp = isAncientHTTP; - if (remainingStreamingBytes) - { + /* Add all headers to bloom filter */ + req->bf.reset(); + for (HttpRequest::Header *h = req->headers; (++h)->key.length(); ) { + req->bf.add(h->key); + } + + /* Break if no host header (but we can have empty string which is different from nullptr) */ + if (!req->getHeader("host").data()) { + return {HTTP_ERROR_400_BAD_REQUEST, FULLPTR}; + } - /* It's either chunked or with a content-length */ - if (isParsingChunkedEncoding(remainingStreamingBytes)) - { + /* RFC 9112 6.3 + * If a message is received with both a Transfer-Encoding and a Content-Length header field, + * the Transfer-Encoding overrides the Content-Length. Such a message might indicate an attempt + * to perform request smuggling (Section 11.2) or response splitting (Section 11.1) and + * ought to be handled as an error. 
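+ * Accordingly, the code below rejects the combination outright with 400 Bad Request.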
*/ + std::string_view transferEncodingString = req->getHeader("transfer-encoding"); + std::string_view contentLengthString = req->getHeader("content-length"); + if (transferEncodingString.length() && contentLengthString.length()) { + /* Returning fullptr is the same as calling the errorHandler */ + /* We could be smart and set an error in the context along with this, to indicate what + * http error response we might want to return */ + return {HTTP_ERROR_400_BAD_REQUEST, FULLPTR}; + } + + /* Parse query */ + const char *querySeparatorPtr = (const char *) memchr(req->headers->value.data(), '?', req->headers->value.length()); + req->querySeparator = (unsigned int) ((querySeparatorPtr ? querySeparatorPtr : req->headers->value.data() + req->headers->value.length()) - req->headers->value.data()); + + /* If returned socket is not what we put in we need + * to break here as we either have upgraded to + * WebSockets or otherwise closed the socket. */ + void *returnedUser = requestHandler(user, req); + if (returnedUser != user) { + /* We are upgraded to WebSocket or otherwise broken */ + return {consumedTotal, returnedUser}; + } + + /* The rules at play here according to RFC 9112 for requests are essentially: + * If both content-length and transfer-encoding then invalid message; must break. + * If has transfer-encoding then must be chunked regardless of value. + * If content-length then fixed length even if 0. + * If none of the above then fixed length is 0. */ + + /* RFC 9112 6.3 + * If a message is received with both a Transfer-Encoding and a Content-Length header field, + * the Transfer-Encoding overrides the Content-Length. */ + if (transferEncodingString.length()) { + + /* If a proxy sent us the transfer-encoding header that 100% means it must be chunked or else the proxy is + * not RFC 9112 compliant. Therefore it is always better to assume this is the case, since that entirely eliminates + * all forms of transfer-encoding obfuscation tricks. We just rely on the header. */ + + /* RFC 9112 6.3 + * If a Transfer-Encoding header field is present in a request and the chunked transfer coding is not the + * final encoding, the message body length cannot be determined reliably; the server MUST respond with the + * 400 (Bad Request) status code and then close the connection. */ + + /* In this case we fail later by having the wrong interpretation (assuming chunked). + * This could be made stricter but makes no difference either way, unless forwarding the identical message as a proxy. 
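+ * (E.g. "Transfer-Encoding: gzip, chunked" is still interpreted as chunked here; a request whose final coding is not chunked will then fail chunk parsing and get a 400.)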
*/ + + remainingStreamingBytes = STATE_IS_CHUNKED; + /* If consume minimally, we do not want to consume anything but we want to mark this as being chunked */ + if (!CONSUME_MINIMALLY) { + /* Go ahead and parse it (todo: better heuristics for emitting FIN to the app level) */ + std::string_view dataToConsume(data, length); + for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) { + dataHandler(user, chunk, chunk.length() == 0); + } + if (isParsingInvalidChunkedEncoding(remainingStreamingBytes)) { + return {HTTP_ERROR_400_BAD_REQUEST, FULLPTR}; + } + unsigned int consumed = (length - (unsigned int) dataToConsume.length()); + data = (char *) dataToConsume.data(); + length = (unsigned int) dataToConsume.length(); + consumedTotal += consumed; + } + } else if (contentLengthString.length()) { + remainingStreamingBytes = toUnsignedInteger(contentLengthString); + if (remainingStreamingBytes == UINT64_MAX) { + /* Parser error */ + return {HTTP_ERROR_400_BAD_REQUEST, FULLPTR}; + } + + if (!CONSUME_MINIMALLY) { + unsigned int emittable = (unsigned int) std::min<uint64_t>(remainingStreamingBytes, length); + dataHandler(user, std::string_view(data, emittable), emittable == remainingStreamingBytes); + remainingStreamingBytes -= emittable; + + data += emittable; + length -= emittable; + consumedTotal += emittable; + } + } else { + /* If we came here without a body; emit an empty data chunk to signal no data */ + dataHandler(user, {}, true); + } + + /* Consume minimally should break as early as possible */ + if (CONSUME_MINIMALLY) { + break; + } + } + /* Whenever we return FULLPTR, the interpretation of "consumed" should be the HttpError enum. */ + if (err) { + return {err, FULLPTR}; + } + return {consumedTotal, user}; + } + +public: + std::pair<unsigned int, void *> consumePostPadded(char *data, unsigned int length, void *user, void *reserved, MoveOnlyFunction<void *(void *, HttpRequest *)> &&requestHandler, MoveOnlyFunction<void *(void *, std::string_view, bool)> &&dataHandler) { + + /* This resets BloomFilter by construction, but later we also reset it again. + * Optimize this to skip resetting twice (req could be made global) */ + HttpRequest req; + if (remainingStreamingBytes) { + + /* It's either chunked or with a content-length */ + if (isParsingChunkedEncoding(remainingStreamingBytes)) { + std::string_view dataToConsume(data, length); + for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) { + dataHandler(user, chunk, chunk.length() == 0); + } + if (isParsingInvalidChunkedEncoding(remainingStreamingBytes)) { + return {HTTP_ERROR_400_BAD_REQUEST, FULLPTR}; + } + data = (char *) dataToConsume.data(); + length = (unsigned int) dataToConsume.length(); + } else { + // this is exactly the same as below!
+ // todo: refactor this + if (remainingStreamingBytes >= length) { + void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == length); + remainingStreamingBytes -= length; + return {0, returnedUser}; + } else { + void *returnedUser = dataHandler(user, std::string_view(data, remainingStreamingBytes), true); + + data += (unsigned int) remainingStreamingBytes; + length -= (unsigned int) remainingStreamingBytes; + + remainingStreamingBytes = 0; + + if (returnedUser != user) { + return {0, returnedUser}; } - else - { - void *returnedUser = dataHandler(user, std::string_view(data, remainingStreamingBytes), true); + } + } - data += (unsigned int) remainingStreamingBytes; - length -= (unsigned int) remainingStreamingBytes; + } else if (fallback.length()) { + unsigned int had = (unsigned int) fallback.length(); - remainingStreamingBytes = 0; + size_t maxCopyDistance = std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t) length); - if (returnedUser != user) - { - return returnedUser; + /* We don't want fallback to be short string optimized, since we want to move it */ + fallback.reserve(fallback.length() + maxCopyDistance + std::max(MINIMUM_HTTP_POST_PADDING, sizeof(std::string))); + fallback.append(data, maxCopyDistance); + + // break here on break + std::pair consumed = fenceAndConsumePostPadded(fallback.data(), (unsigned int) fallback.length(), user, reserved, &req, requestHandler, dataHandler); + if (consumed.second != user) { + return consumed; + } + + if (consumed.first) { + + /* This logic assumes that we consumed everything in fallback buffer. + * This is critically important, as we will get an integer overflow in case + * of "had" being larger than what we consumed, and that we would drop data */ + fallback.clear(); + data += consumed.first - had; + length -= consumed.first - had; + + if (remainingStreamingBytes) { + /* It's either chunked or with a content-length */ + if (isParsingChunkedEncoding(remainingStreamingBytes)) { + std::string_view dataToConsume(data, length); + for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) { + dataHandler(user, chunk, chunk.length() == 0); + } + if (isParsingInvalidChunkedEncoding(remainingStreamingBytes)) { + return {HTTP_ERROR_400_BAD_REQUEST, FULLPTR}; + } + data = (char *) dataToConsume.data(); + length = (unsigned int) dataToConsume.length(); + } else { + // this is exactly the same as above! 
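+ // (remainingStreamingBytes either holds the STATE_IS_CHUNKED marker or counts the fixed-length body bytes still owed to dataHandler from a previous read.)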
+ if (remainingStreamingBytes >= (unsigned int) length) { + void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int) length); + remainingStreamingBytes -= length; + return {0, returnedUser}; } else { void *returnedUser = dataHandler(user, std::string_view(data, remainingStreamingBytes), true); @@ -700,120 +796,40 @@ namespace uWS remainingStreamingBytes = 0; if (returnedUser != user) { - return returnedUser; + return {0, returnedUser}; } } } } + + } else { + if (fallback.length() == MAX_FALLBACK_SIZE) { + return {HTTP_ERROR_431_REQUEST_HEADER_FIELDS_TOO_LARGE, FULLPTR}; + } + return {0, user}; } - else if (fallback.length()) - { - unsigned int had = (unsigned int)fallback.length(); - - size_t maxCopyDistance = std::min(MAX_FALLBACK_SIZE - fallback.length(), (size_t)length); - - /* We don't want fallback to be short string optimized, since we want to move it */ - fallback.reserve(fallback.length() + maxCopyDistance + std::max(MINIMUM_HTTP_POST_PADDING, sizeof(std::string))); - fallback.append(data, maxCopyDistance); - - // break here on break - std::pair consumed = fenceAndConsumePostPadded(fallback.data(), (unsigned int)fallback.length(), user, reserved, &req, requestHandler, dataHandler); - if (consumed.second != user) - { - return consumed.second; - } - - if (consumed.first) - { - - /* This logic assumes that we consumed everything in fallback buffer. - * This is critically important, as we will get an integer overflow in case - * of "had" being larger than what we consumed, and that we would drop data */ - fallback.clear(); - data += consumed.first - had; - length -= consumed.first - had; - - if (remainingStreamingBytes) - { - /* It's either chunked or with a content-length */ - if (isParsingChunkedEncoding(remainingStreamingBytes)) - { - std::string_view dataToConsume(data, length); - for (auto chunk : uWS::ChunkIterator(&dataToConsume, &remainingStreamingBytes)) - { - dataHandler(user, chunk, chunk.length() == 0); - } - if (isParsingInvalidChunkedEncoding(remainingStreamingBytes)) - { - return FULLPTR; - } - data = (char *)dataToConsume.data(); - length = (unsigned int)dataToConsume.length(); - } - else - { - // this is exactly the same as above! - if (remainingStreamingBytes >= (unsigned int)length) - { - void *returnedUser = dataHandler(user, std::string_view(data, length), remainingStreamingBytes == (unsigned int)length); - remainingStreamingBytes -= length; - return returnedUser; - } - else - { - void *returnedUser = dataHandler(user, std::string_view(data, remainingStreamingBytes), true); - - data += remainingStreamingBytes; - length -= remainingStreamingBytes; - - remainingStreamingBytes = 0; - - if (returnedUser != user) - { - return returnedUser; - } - } - } - } - } - else - { - if (fallback.length() == MAX_FALLBACK_SIZE) - { - // note: you don't really need error handler, just return something strange! - // we could have it return a constant pointer to denote error! 
- return errorHandler(user); - } - return user; - } - } - - std::pair consumed = fenceAndConsumePostPadded(data, length, user, reserved, &req, requestHandler, dataHandler); - if (consumed.second != user) - { - return consumed.second; - } - - data += consumed.first; - length -= consumed.first; - - if (length) - { - if (length < MAX_FALLBACK_SIZE) - { - fallback.append(data, length); - } - else - { - return errorHandler(user); - } - } - - // added for now - return user; } - }; + + std::pair consumed = fenceAndConsumePostPadded(data, length, user, reserved, &req, requestHandler, dataHandler); + if (consumed.second != user) { + return consumed; + } + + data += consumed.first; + length -= consumed.first; + + if (length) { + if (length < MAX_FALLBACK_SIZE) { + fallback.append(data, length); + } else { + return {HTTP_ERROR_431_REQUEST_HEADER_FIELDS_TOO_LARGE, FULLPTR}; + } + } + + // added for now + return {0, user}; + } +}; } -#endif // UWS_HTTPPARSER_H diff --git a/packages/bun-uws/src/HttpResponseData.h b/packages/bun-uws/src/HttpResponseData.h index 9613e84fe4..b4c3195f26 100644 --- a/packages/bun-uws/src/HttpResponseData.h +++ b/packages/bun-uws/src/HttpResponseData.h @@ -15,8 +15,7 @@ * limitations under the License. */ // clang-format off -#ifndef UWS_HTTPRESPONSEDATA_H -#define UWS_HTTPRESPONSEDATA_H +#pragma once /* This data belongs to the HttpResponse */ @@ -106,4 +105,4 @@ struct HttpResponseData : AsyncSocketData, HttpParser { } -#endif // UWS_HTTPRESPONSEDATA_H + diff --git a/packages/bun-uws/src/TopicTree.h b/packages/bun-uws/src/TopicTree.h index 215afd28b2..69855cf9fc 100644 --- a/packages/bun-uws/src/TopicTree.h +++ b/packages/bun-uws/src/TopicTree.h @@ -15,9 +15,7 @@ * limitations under the License. */ -#ifndef UWS_TOPICTREE_H -#define UWS_TOPICTREE_H - +#pragma once #include #include #include @@ -366,4 +364,4 @@ public: } -#endif + diff --git a/test/harness.ts b/test/harness.ts index 99803988c9..652d99b1f9 100644 --- a/test/harness.ts +++ b/test/harness.ts @@ -307,174 +307,174 @@ const binaryTypes = { "float32array": Float32Array, "float64array": Float64Array, } as const; - -expect.extend({ - toHaveTestTimedOutAfter(actual: any, expected: number) { - if (typeof actual !== "string") { - return { - pass: false, - message: () => `Expected ${actual} to be a string`, - }; - } - - const preStartI = actual.indexOf("timed out after "); - if (preStartI === -1) { - return { - pass: false, - message: () => `Expected ${actual} to contain "timed out after "`, - }; - } - const startI = preStartI + "timed out after ".length; - const endI = actual.indexOf("ms", startI); - if (endI === -1) { - return { - pass: false, - message: () => `Expected ${actual} to contain "ms" after "timed out after "`, - }; - } - const int = parseInt(actual.slice(startI, endI)); - if (!Number.isSafeInteger(int)) { - return { - pass: false, - message: () => `Expected ${int} to be a safe integer`, - }; - } - - return { - pass: int >= expected, - message: () => `Expected ${int} to be >= ${expected}`, - }; - }, - toBeBinaryType(actual: any, expected: keyof typeof binaryTypes) { - switch (expected) { - case "buffer": +if (expect.extend) + expect.extend({ + toHaveTestTimedOutAfter(actual: any, expected: number) { + if (typeof actual !== "string") { return { - pass: Buffer.isBuffer(actual), - message: () => `Expected ${actual} to be buffer`, + pass: false, + message: () => `Expected ${actual} to be a string`, }; - case "arraybuffer": + } + + const preStartI = actual.indexOf("timed out after "); + if (preStartI === -1) { 
return { - pass: actual instanceof ArrayBuffer, - message: () => `Expected ${actual} to be ArrayBuffer`, + pass: false, + message: () => `Expected ${actual} to contain "timed out after "`, }; - default: { - const ctor = binaryTypes[expected]; - if (!ctor) { + } + const startI = preStartI + "timed out after ".length; + const endI = actual.indexOf("ms", startI); + if (endI === -1) { + return { + pass: false, + message: () => `Expected ${actual} to contain "ms" after "timed out after "`, + }; + } + const int = parseInt(actual.slice(startI, endI)); + if (!Number.isSafeInteger(int)) { + return { + pass: false, + message: () => `Expected ${int} to be a safe integer`, + }; + } + + return { + pass: int >= expected, + message: () => `Expected ${int} to be >= ${expected}`, + }; + }, + toBeBinaryType(actual: any, expected: keyof typeof binaryTypes) { + switch (expected) { + case "buffer": + return { + pass: Buffer.isBuffer(actual), + message: () => `Expected ${actual} to be buffer`, + }; + case "arraybuffer": + return { + pass: actual instanceof ArrayBuffer, + message: () => `Expected ${actual} to be ArrayBuffer`, + }; + default: { + const ctor = binaryTypes[expected]; + if (!ctor) { + return { + pass: false, + message: () => `Expected ${expected} to be a binary type`, + }; + } + + return { + pass: actual instanceof ctor, + message: () => `Expected ${actual} to be ${expected}`, + }; + } + } + }, + toRun(cmds: string[], optionalStdout?: string, expectedCode: number = 0) { + const result = Bun.spawnSync({ + cmd: [bunExe(), ...cmds], + env: bunEnv, + stdio: ["inherit", "pipe", "inherit"], + }); + + if (result.exitCode !== expectedCode) { + return { + pass: false, + message: () => `Command ${cmds.join(" ")} failed:` + "\n" + result.stdout.toString("utf-8"), + }; + } + + if (optionalStdout != null) { + return { + pass: result.stdout.toString("utf-8") === optionalStdout, + message: () => + `Expected ${cmds.join(" ")} to output ${optionalStdout} but got ${result.stdout.toString("utf-8")}`, + }; + } + + return { + pass: true, + message: () => `Expected ${cmds.join(" ")} to fail`, + }; + }, + toThrowWithCode(fn: CallableFunction, cls: CallableFunction, code: string) { + try { + fn(); + return { + pass: false, + message: () => `Received function did not throw`, + }; + } catch (e) { + // expect(e).toBeInstanceOf(cls); + if (!(e instanceof cls)) { return { pass: false, - message: () => `Expected ${expected} to be a binary type`, + message: () => `Expected error to be instanceof ${cls.name}; got ${e.__proto__.constructor.name}`, + }; + } + + // expect(e).toHaveProperty("code"); + if (!("code" in e)) { + return { + pass: false, + message: () => `Expected error to have property 'code'; got ${e}`, + }; + } + + // expect(e.code).toEqual(code); + if (e.code !== code) { + return { + pass: false, + message: () => `Expected error to have code '${code}'; got ${e.code}`, }; } return { - pass: actual instanceof ctor, - message: () => `Expected ${actual} to be ${expected}`, + pass: true, }; } - } - }, - toRun(cmds: string[], optionalStdout?: string, expectedCode: number = 0) { - const result = Bun.spawnSync({ - cmd: [bunExe(), ...cmds], - env: bunEnv, - stdio: ["inherit", "pipe", "inherit"], - }); - - if (result.exitCode !== expectedCode) { - return { - pass: false, - message: () => `Command ${cmds.join(" ")} failed:` + "\n" + result.stdout.toString("utf-8"), - }; - } - - if (optionalStdout != null) { - return { - pass: result.stdout.toString("utf-8") === optionalStdout, - message: () => - `Expected ${cmds.join(" ")} to 
output ${optionalStdout} but got ${result.stdout.toString("utf-8")}`, - }; - } - - return { - pass: true, - message: () => `Expected ${cmds.join(" ")} to fail`, - }; - }, - toThrowWithCode(fn: CallableFunction, cls: CallableFunction, code: string) { - try { - fn(); - return { - pass: false, - message: () => `Received function did not throw`, - }; - } catch (e) { - // expect(e).toBeInstanceOf(cls); - if (!(e instanceof cls)) { + }, + async toThrowWithCodeAsync(fn: CallableFunction, cls: CallableFunction, code: string) { + try { + await fn(); return { pass: false, - message: () => `Expected error to be instanceof ${cls.name}; got ${e.__proto__.constructor.name}`, + message: () => `Received function did not throw`, }; - } + } catch (e) { + // expect(e).toBeInstanceOf(cls); + if (!(e instanceof cls)) { + return { + pass: false, + message: () => `Expected error to be instanceof ${cls.name}; got ${e.__proto__.constructor.name}`, + }; + } + + // expect(e).toHaveProperty("code"); + if (!("code" in e)) { + return { + pass: false, + message: () => `Expected error to have property 'code'; got ${e}`, + }; + } + + // expect(e.code).toEqual(code); + if (e.code !== code) { + return { + pass: false, + message: () => `Expected error to have code '${code}'; got ${e.code}`, + }; + } - // expect(e).toHaveProperty("code"); - if (!("code" in e)) { return { - pass: false, - message: () => `Expected error to have property 'code'; got ${e}`, + pass: true, }; } - - // expect(e.code).toEqual(code); - if (e.code !== code) { - return { - pass: false, - message: () => `Expected error to have code '${code}'; got ${e.code}`, - }; - } - - return { - pass: true, - }; - } - }, - async toThrowWithCodeAsync(fn: CallableFunction, cls: CallableFunction, code: string) { - try { - await fn(); - return { - pass: false, - message: () => `Received function did not throw`, - }; - } catch (e) { - // expect(e).toBeInstanceOf(cls); - if (!(e instanceof cls)) { - return { - pass: false, - message: () => `Expected error to be instanceof ${cls.name}; got ${e.__proto__.constructor.name}`, - }; - } - - // expect(e).toHaveProperty("code"); - if (!("code" in e)) { - return { - pass: false, - message: () => `Expected error to have property 'code'; got ${e}`, - }; - } - - // expect(e.code).toEqual(code); - if (e.code !== code) { - return { - pass: false, - message: () => `Expected error to have code '${code}'; got ${e.code}`, - }; - } - - return { - pass: true, - }; - } - }, -}); + }, + }); export function ospath(path: string) { if (isWindows) { @@ -1115,34 +1115,35 @@ String.prototype.isUTF16 = function () { return require("bun:internal-for-testing").jscInternals.isUTF16String(this); }; -expect.extend({ - toBeLatin1String(actual: unknown) { - if ((actual as string).isLatin1()) { +if (expect.extend) + expect.extend({ + toBeLatin1String(actual: unknown) { + if ((actual as string).isLatin1()) { + return { + pass: true, + message: () => `Expected ${actual} to be a Latin1 string`, + }; + } + return { - pass: true, + pass: false, message: () => `Expected ${actual} to be a Latin1 string`, }; - } + }, + toBeUTF16String(actual: unknown) { + if ((actual as string).isUTF16()) { + return { + pass: true, + message: () => `Expected ${actual} to be a UTF16 string`, + }; + } - return { - pass: false, - message: () => `Expected ${actual} to be a Latin1 string`, - }; - }, - toBeUTF16String(actual: unknown) { - if ((actual as string).isUTF16()) { return { - pass: true, + pass: false, message: () => `Expected ${actual} to be a UTF16 string`, }; - } - - return { - 
pass: false, - message: () => `Expected ${actual} to be a UTF16 string`, - }; - }, -}); + }, + }); interface BunHarnessTestMatchers { toBeLatin1String(): void; diff --git a/test/js/bun/http/hspec.test.ts b/test/js/bun/http/hspec.test.ts new file mode 100644 index 0000000000..b2715ff831 --- /dev/null +++ b/test/js/bun/http/hspec.test.ts @@ -0,0 +1,7 @@ +import { test, expect } from "bun:test"; +import { runTests } from "./http-spec.ts"; + +test("https://github.com/uNetworking/h1spec tests pass", async () => { + const passed = await runTests(); + expect(passed).toBe(true); +}); diff --git a/test/js/bun/http/http-spec.ts b/test/js/bun/http/http-spec.ts new file mode 100644 index 0000000000..f51126b2d7 --- /dev/null +++ b/test/js/bun/http/http-spec.ts @@ -0,0 +1,344 @@ +// https://github.com/uNetworking/h1spec +// https://github.com/oven-sh/bun/issues/14826 +// Thanks to Alex Hultman +import net from "net"; + +// Define test cases +interface TestCase { + request: string; + description: string; + expectedStatus: [number, number][]; + expectedTimeout?: boolean; +} + +const testCases: TestCase[] = [ + { + request: "G", + description: "Fragmented method", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET ", + description: "Fragmented URL 1", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello", + description: "Fragmented URL 2", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello ", + description: "Fragmented URL 3", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP", + description: "Fragmented HTTP version", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1", + description: "Fragmented request line", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r", + description: "Fragmented request line newline 1", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\n", + description: "Fragmented request line newline 2", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHos", + description: "Fragmented field name", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHost:", + description: "Fragmented field value 1", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHost: ", + description: "Fragmented field value 2", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHost: localhost", + description: "Fragmented field value 3", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHost: localhost\r", + description: "Fragmented field value 4", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHost: localhost\r\n", + description: "Fragmented request", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET /hello HTTP/1.1\r\nHost: localhost\r\n\r", + description: "Fragmented request termination", + expectedStatus: [[-1, -1]], + expectedTimeout: true, + }, + { + request: "GET / \r\n\r\n", + description: "Request without HTTP version", + expectedStatus: [[400, 599]], + }, + { + request: "GET / HTTP/1.1\r\nHost: example.com\r\nExpect: 100-continue\r\n\r\n", + description: "Request with Expect header", + expectedStatus: [ + [100, 100], 
+  {
+    request: "GET / \r\n\r\n",
+    description: "Request without HTTP version",
+    expectedStatus: [[400, 599]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nExpect: 100-continue\r\n\r\n",
+    description: "Request with Expect header",
+    expectedStatus: [
+      [100, 100],
+      [200, 299],
+    ],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Valid GET request",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET / HTTP/1.0\r\nHost: example.com\r\n\r\n",
+    description: "Valid GET request with HTTP/1.0",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Valid GET request for a proxy URL",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET https://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Valid GET request for an https proxy URL",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET HTTPS://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Valid GET request for an HTTPS proxy URL",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET HTTPZ://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Invalid GET request for a proxy URL (bad scheme HTTPZ)",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET H-TTP://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Invalid GET request for a proxy URL (bad scheme H-TTP)",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET HTTP://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Valid GET request for an HTTP proxy URL",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Invalid GET request target (space)",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET ^ HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Invalid GET request target (caret)",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nhoSt:\texample.com\r\nempty:\r\n\r\n",
+    description: "Valid GET request with edge cases",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nX-Invalid[]: test\r\n\r\n",
+    description: "Invalid header characters",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n",
+    description: "Missing Host header",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: -123456789123456789123456789\r\n\r\n",
+    description: "Overflowing negative Content-Length header",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: -1234\r\n\r\n",
+    description: "Negative Content-Length header",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: abc\r\n\r\n",
+    description: "Non-numeric Content-Length header",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nX-Empty-Header: \r\n\r\n",
+    description: "Empty header value",
+    expectedStatus: [[200, 299]],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nX-Bad-Control-Char: test\x07\r\n\r\n",
+    description: "Header containing invalid control character",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "GET / HTTP/9.9\r\nHost: example.com\r\n\r\n",
+    description: "Invalid HTTP version",
+    expectedStatus: [
+      [400, 499],
+      [500, 599],
+    ],
+  },
+  {
+    request: "Extra lineGET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
+    description: "Invalid prefix of request",
+    expectedStatus: [
+      [400, 499],
+      [500, 599],
+    ],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\n\rSome-Header: Test\r\n\r\n",
+    description: "Invalid line ending",
+    expectedStatus: [[400, 499]],
+  },
+  {
+    request: "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 5\r\n\r\nhello",
+    description: "Valid POST request with body",
+    expectedStatus: [
+      [200, 299],
+      [404, 404],
+    ],
+  },
+  {
+    request: "GET / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\nContent-Length: 5\r\n\r\n",
+    description: "Conflicting Transfer-Encoding and Content-Length",
+    expectedStatus: [[400, 499]],
+  },
+];
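+
+// A case passes when the response status falls inside any of its expectedStatus
+// ranges; expectedTimeout cases pass when the server keeps the connection open,
+// waiting for the rest of the request instead of answering a fragment.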
+
+// Run all test cases in parallel against the given server
+export async function runTestsStandalone(host: string, port: number) {
+  const results = await Promise.all(testCases.map(testCase => runTestCase(testCase, host, port)));
+
+  const passedCount = results.filter(result => result).length;
+  console.log(`\n${passedCount} out of ${testCases.length} tests passed.`);
+  return passedCount === testCases.length;
+}
+
+// Spin up a throwaway Bun server and run the whole suite against it
+export async function runTests() {
+  using server = Bun.serve({
+    port: 0,
+    fetch(req) {
+      return new Response("Hello, world!");
+    },
+  });
+
+  const host = server.url.hostname;
+  const port = parseInt(server.url.port, 10);
+  return await runTestsStandalone(host, port);
+}
+
+// Run a single test case with a 500 ms timeout on reading
+async function runTestCase(testCase: TestCase, host: string, port: number): Promise<boolean> {
+  try {
+    const conn = new Promise<net.Socket>((resolve, reject) => {
+      const client = net.createConnection({ host, port }, () => {
+        resolve(client);
+      });
+      client.on("error", reject);
+    });
+
+    const client: net.Socket = await conn;
+
+    // Send the request
+    client.write(Buffer.from(testCase.request));
+
+    // Set up a read timeout promise
+    const readTimeout = new Promise<boolean>(resolve => {
+      const timeoutId = setTimeout(() => {
+        if (testCase.expectedTimeout) {
+          console.log(`✅ ${testCase.description}: Server waited successfully`);
+          client.destroy(); // Ensure the connection is closed on timeout
+          resolve(true);
+        } else {
+          console.error(`❌ ${testCase.description}: Read operation timed out`);
+          client.destroy(); // Ensure the connection is closed on timeout
+          resolve(false);
+        }
+      }, 500);
+
+      client.on("data", data => {
+        // Clear the timeout if read completes
+        clearTimeout(timeoutId);
+        const response = data.toString();
+        const statusCode = parseStatusCode(response);
+
+        const isSuccess = testCase.expectedStatus.some(([min, max]) => statusCode >= min && statusCode <= max);
+        if (!isSuccess) {
+          console.log(JSON.stringify(response, null, 2));
+        }
+        console.log(
+          `${isSuccess ? "✅" : "❌"} ${
+            testCase.description
+          }: Response Status Code ${statusCode}, Expected ranges: ${JSON.stringify(testCase.expectedStatus)}`,
+        );
+        client.destroy();
+        resolve(isSuccess);
+      });
+
+      client.on("error", error => {
+        clearTimeout(timeoutId);
+        console.error(`Error in test "${testCase.description}":`, error);
+        resolve(false);
+      });
+    });
+
+    // Wait for the read operation or timeout
+    return await readTimeout;
+  } catch (error) {
+    console.error(`Error in test "${testCase.description}":`, error);
+    return false;
+  }
+}
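+
+// Example: "HTTP/1.1 431 Request Header Fields Too Large\r\n..." parses to 431;
+// a payload that does not start with an HTTP/1.x status line yields 0, which is
+// outside every expected range and therefore fails the case.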
+
+// Parse the HTTP status code from the response
+function parseStatusCode(response: string): number {
+  const statusLine = response.split("\r\n")[0];
+  const match = statusLine.match(/HTTP\/1\.\d (\d{3})/);
+  return match ? parseInt(match[1], 10) : 0;
+}
+
+if (import.meta.main) {
+  if (process.argv.length > 2) {
+    await runTestsStandalone(process.argv[2], parseInt(process.argv[3], 10));
+  } else {
+    await runTests();
+  }
+}
diff --git a/test/js/bun/http/serve.test.ts b/test/js/bun/http/serve.test.ts
index 12875cbf05..bfffb65c0e 100644
--- a/test/js/bun/http/serve.test.ts
+++ b/test/js/bun/http/serve.test.ts
@@ -2073,3 +2073,75 @@ it("allow custom timeout per request", async () => {
   expect(res.status).toBe(200);
   expect(res.text()).resolves.toBe("Hello, World!");
 }, 20_000);
+
+it("#6462", async () => {
+  let headers: [string, string | null][] = [];
+  using server = Bun.serve({
+    port: 0,
+    async fetch(request) {
+      for (const key of request.headers.keys()) {
+        headers = headers.concat([[key, request.headers.get(key)]]);
+      }
+      return new Response(
+        JSON.stringify({
+          "headers": headers,
+        }),
+        { status: 200 },
+      );
+    },
+  });
+
+  const bytes = Buffer.from(`GET / HTTP/1.1\r\nConnection: close\r\nHost: ${server.hostname}\r\nTest!: test\r\n\r\n`);
+  const { promise, resolve } = Promise.withResolvers();
+  await Bun.connect({
+    port: server.port,
+    hostname: server.hostname,
+    socket: {
+      open(socket) {
+        const wrote = socket.write(bytes);
+        console.log("wrote", wrote);
+      },
+      data(socket, data) {
+        console.log(data.toString("utf8"));
+      },
+      close(socket) {
+        resolve();
+      },
+    },
+  });
+  await promise;
+
+  expect(headers).toStrictEqual([
+    ["connection", "close"],
+    ["host", "localhost"],
+    ["test!", "test"],
+  ]);
+});
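+
+// #6583: a TLS ClientHello sent to a plain-HTTP server is not a valid HTTP
+// request, so the parser should drop the connection without ever invoking fetch,
+// which is what the final assertion below checks.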
"0", 10)) { throw new Error("BUN_HTTP_MAX_HEADER_SIZE is not set to the correct value"); } @@ -18,16 +18,20 @@ await fetch(`${server.url}/`, { }); try { - await fetch(`${server.url}/`, { + const response = await fetch(`${server.url}/`, { headers: { "Huge": Buffer.alloc(http.maxHeaderSize + 1024, "abc").toString(), }, }); - throw new Error("bad"); -} catch (e) { - if (e.message.includes("bad")) { - process.exit(1); + if (response.status === 431) { + throw new Error("good!!"); } - process.exit(0); + throw new Error("bad!"); +} catch (e) { + if (e instanceof Error && e.message.includes("good!!")) { + process.exit(0); + } + + throw e; } diff --git a/test/js/node/http/node-http-maxHeaderSize.test.ts b/test/js/node/http/node-http-maxHeaderSize.test.ts index 86eda1ee07..3a77e783bd 100644 --- a/test/js/node/http/node-http-maxHeaderSize.test.ts +++ b/test/js/node/http/node-http-maxHeaderSize.test.ts @@ -18,22 +18,23 @@ test("maxHeaderSize", async () => { }, }); - expect( - async () => - await fetch(`${server.url}/`, { - headers: { - "Huge": Buffer.alloc(8 * 1024, "abc").toString(), - }, - }), - ).toThrow(); - expect( - async () => - await fetch(`${server.url}/`, { - headers: { - "Huge": Buffer.alloc(512, "abc").toString(), - }, - }), - ).not.toThrow(); + { + const response = await fetch(`${server.url}/`, { + headers: { + "Huge": Buffer.alloc(8 * 1024, "abc").toString(), + }, + }); + expect(response.status).toBe(431); + } + + { + const response = await fetch(`${server.url}/`, { + headers: { + "Huge": Buffer.alloc(15 * 1024, "abc").toString(), + }, + }); + expect(response.status).toBe(431); + } } http.maxHeaderSize = 16 * 1024; { @@ -45,22 +46,23 @@ test("maxHeaderSize", async () => { }, }); - expect( - async () => - await fetch(`${server.url}/`, { - headers: { - "Huge": Buffer.alloc(15 * 1024, "abc").toString(), - }, - }), - ).not.toThrow(); - expect( - async () => - await fetch(`${server.url}/`, { - headers: { - "Huge": Buffer.alloc(17 * 1024, "abc").toString(), - }, - }), - ).toThrow(); + { + const response = await fetch(`${server.url}/`, { + headers: { + "Huge": Buffer.alloc(15 * 1024, "abc").toString(), + }, + }); + expect(response.status).toBe(200); + } + + { + const response = await fetch(`${server.url}/`, { + headers: { + "Huge": Buffer.alloc(17 * 1024, "abc").toString(), + }, + }); + expect(response.status).toBe(431); + } } http.maxHeaderSize = originalMaxHeaderSize; diff --git a/test/js/node/http/node-http-proxy.js b/test/js/node/http/node-http-proxy.js new file mode 100644 index 0000000000..410b030d5c --- /dev/null +++ b/test/js/node/http/node-http-proxy.js @@ -0,0 +1,79 @@ +import assert from "node:assert"; +import { createServer, request } from "node:http"; +import url from "node:url"; + +export async function run() { + const { promise, resolve, reject } = Promise.withResolvers(); + + const proxyServer = createServer(function (req, res) { + // Use URL object instead of deprecated url.parse + const parsedUrl = new URL(req.url, `http://${req.headers.host}`); + + const options = { + protocol: parsedUrl.protocol, + hostname: parsedUrl.hostname, + port: parsedUrl.port, + path: parsedUrl.pathname + parsedUrl.search, + method: req.method, + headers: req.headers, + }; + + const proxyRequest = request(options, function (proxyResponse) { + res.writeHead(proxyResponse.statusCode, proxyResponse.headers); + proxyResponse.pipe(res); // Use pipe instead of manual data handling + }); + + proxyRequest.on("error", error => { + console.error("Proxy Request Error:", error); + res.writeHead(500); + 
res.end("Proxy Error"); + }); + + req.pipe(proxyRequest); // Use pipe instead of manual data handling + }); + + proxyServer.listen(0, "localhost", async () => { + const address = proxyServer.address(); + + const options = { + protocol: "http:", + hostname: "localhost", + port: address.port, + path: "/", // Change path to / + headers: { + Host: "example.com", + "accept-encoding": "identity", + }, + }; + + const req = request(options, res => { + let data = ""; + res.on("data", chunk => { + data += chunk; + }); + res.on("end", () => { + try { + assert.strictEqual(res.statusCode, 200); + assert(data.length > 0); + assert(data.includes("This domain is for use in illustrative examples in documents")); + resolve(); + } catch (err) { + reject(err); + } + }); + }); + + req.on("error", err => { + reject(err); + }); + + req.end(); + }); + + await promise; + proxyServer.close(); +} + +if (import.meta.main) { + run().catch(console.error); +} diff --git a/test/js/node/http/node-http.test.ts b/test/js/node/http/node-http.test.ts index a3f4e2f767..af60305a8c 100644 --- a/test/js/node/http/node-http.test.ts +++ b/test/js/node/http/node-http.test.ts @@ -1,6 +1,11 @@ -// @ts-nocheck -import { bunExe } from "bun:harness"; -import { bunEnv, randomPort } from "harness"; +/** + * All new tests in this file should also run in Node.js. + * + * Do not add any tests that only run in Bun. + * + * A handful of older tests do not run in Node in this file. These tests should be updated to run in Node, or deleted. + */ +import { bunEnv, randomPort, bunExe } from "harness"; import { createTest } from "node-harness"; import { spawnSync } from "node:child_process"; import { EventEmitter, once } from "node:events"; @@ -23,10 +28,9 @@ import { tmpdir } from "node:os"; import * as path from "node:path"; import * as stream from "node:stream"; import { PassThrough } from "node:stream"; -import url from "node:url"; import * as zlib from "node:zlib"; +import { run as runHTTPProxyTest } from "./node-http-proxy.js"; const { describe, expect, it, beforeAll, afterAll, createDoneDotAll, mock, test } = createTest(import.meta.path); - function listen(server: Server, protocol: string = "http"): Promise { return new Promise((resolve, reject) => { const timeout = setTimeout(() => reject("Timed out"), 5000).unref(); @@ -772,62 +776,8 @@ describe("node:http", () => { }); }); - it("request via http proxy, issue#4295", done => { - const proxyServer = createServer(function (req, res) { - let option = url.parse(req.url); - option.host = req.headers.host; - option.headers = req.headers; - - const proxyRequest = request(option, function (proxyResponse) { - res.writeHead(proxyResponse.statusCode, proxyResponse.headers); - proxyResponse.on("data", function (chunk) { - res.write(chunk, "binary"); - }); - proxyResponse.on("end", function () { - res.end(); - }); - }); - req.on("data", function (chunk) { - proxyRequest.write(chunk, "binary"); - }); - req.on("end", function () { - proxyRequest.end(); - }); - }); - - proxyServer.listen({ port: 0 }, async (_err, hostname, port) => { - const options = { - protocol: "http:", - hostname: hostname, - port: port, - path: "http://example.com", - headers: { - Host: "example.com", - "accept-encoding": "identity", - }, - }; - - const req = request(options, res => { - let data = ""; - res.on("data", chunk => { - data += chunk; - }); - res.on("end", () => { - try { - expect(res.statusCode).toBe(200); - expect(data.length).toBeGreaterThan(0); - expect(data).toContain("This domain is for use in illustrative examples in 
documents"); - done(); - } catch (err) { - done(err); - } - }); - }); - req.on("error", err => { - done(err); - }); - req.end(); - }); + it("request via http proxy, issue#4295", async () => { + await runHTTPProxyTest(); }); it("should correctly stream a multi-chunk response #5320", async done => { diff --git a/test/preload.ts b/test/preload.ts index 5e472661a6..811af099b9 100644 --- a/test/preload.ts +++ b/test/preload.ts @@ -16,4 +16,4 @@ for (let key in harness.bunEnv) { process.env[key] = harness.bunEnv[key] + ""; } -Bun.$.env(process.env); +if (Bun.$?.env) Bun.$.env(process.env);