From 5b7fd9ed0ea42396ea8de49135d5d7a485f6da19 Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Thu, 4 Sep 2025 14:18:31 -0800 Subject: [PATCH 1/7] node:_http_server: implement Server.prototype.closeIdleConnections (#22234) --- packages/bun-types/bun.d.ts | 5 ++ packages/bun-usockets/src/context.c | 23 ++++--- packages/bun-usockets/src/internal/internal.h | 16 ++--- packages/bun-usockets/src/loop.c | 7 +-- packages/bun-usockets/src/socket.c | 2 +- packages/bun-uws/src/App.h | 16 +++++ packages/bun-uws/src/AsyncSocket.h | 3 + packages/bun-uws/src/HttpContext.h | 10 ++-- packages/bun-uws/src/HttpContextData.h | 7 +-- packages/bun-uws/src/HttpResponse.h | 9 ++- packages/bun-uws/src/HttpResponseData.h | 11 +++- src/bun.js/api/server.classes.ts | 4 ++ src/bun.js/api/server.zig | 60 +++++-------------- src/bun.js/bindings/NodeHTTP.cpp | 2 +- src/deps/libuwsockets.cpp | 25 ++++++-- src/deps/uws/App.zig | 6 ++ src/js/node/_http_server.ts | 4 +- ...client-keep-alive-release-before-finish.js | 39 ------------ .../test-http-flush-response-headers.js | 2 +- .../test/parallel/test-http-response-close.js | 2 +- 20 files changed, 119 insertions(+), 134 deletions(-) delete mode 100644 test/js/node/test/parallel/test-http-client-keep-alive-release-before-finish.js diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index 66414b6ac3..f6b3b2ae95 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -3846,6 +3846,11 @@ declare module "bun" { * @category HTTP & Networking */ interface Server extends Disposable { + /* + * Closes all connections connected to this server which are not sending a request or waiting for a response. Does not close the listen socket. + */ + closeIdleConnections(): void; + /** * Stop listening to prevent new connections from being accepted. * diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index 605bb6de11..6e2c3f3e18 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -153,7 +153,7 @@ void us_internal_socket_context_unlink_connecting_socket(int ssl, struct us_sock } /* We always add in the top, so we don't modify any s.next */ -void us_internal_socket_context_link_listen_socket(struct us_socket_context_t *context, struct us_listen_socket_t *ls) { +void us_internal_socket_context_link_listen_socket(int ssl, struct us_socket_context_t *context, struct us_listen_socket_t *ls) { struct us_socket_t* s = &ls->s; s->context = context; s->next = (struct us_socket_t *) context->head_listen_sockets; @@ -162,7 +162,7 @@ void us_internal_socket_context_link_listen_socket(struct us_socket_context_t *c context->head_listen_sockets->s.prev = s; } context->head_listen_sockets = ls; - us_socket_context_ref(0, context); + us_socket_context_ref(ssl, context); } void us_internal_socket_context_link_connecting_socket(int ssl, struct us_socket_context_t *context, struct us_connecting_socket_t *c) { @@ -179,7 +179,7 @@ void us_internal_socket_context_link_connecting_socket(int ssl, struct us_socket /* We always add in the top, so we don't modify any s.next */ -void us_internal_socket_context_link_socket(struct us_socket_context_t *context, struct us_socket_t *s) { +void us_internal_socket_context_link_socket(int ssl, struct us_socket_context_t *context, struct us_socket_t *s) { s->context = context; s->next = context->head_sockets; s->prev = 0; @@ -187,7 +187,7 @@ void us_internal_socket_context_link_socket(struct us_socket_context_t *context, context->head_sockets->prev = s; } context->head_sockets = s; - us_socket_context_ref(0, context); + us_socket_context_ref(ssl, context); us_internal_enable_sweep_timer(context->loop); } @@ -388,7 +388,7 @@ struct us_listen_socket_t *us_socket_context_listen(int ssl, struct us_socket_co s->flags.is_ipc = 0; s->next = 0; s->flags.allow_half_open = (options & LIBUS_SOCKET_ALLOW_HALF_OPEN); - us_internal_socket_context_link_listen_socket(context, ls); + us_internal_socket_context_link_listen_socket(ssl, context, ls); ls->socket_ext_size = socket_ext_size; @@ -423,7 +423,7 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock s->flags.is_paused = 0; s->flags.is_ipc = 0; s->next = 0; - us_internal_socket_context_link_listen_socket(context, ls); + us_internal_socket_context_link_listen_socket(ssl, context, ls); ls->socket_ext_size = socket_ext_size; @@ -456,7 +456,7 @@ struct us_socket_t* us_socket_context_connect_resolved_dns(struct us_socket_cont socket->connect_state = NULL; socket->connect_next = NULL; - us_internal_socket_context_link_socket(context, socket); + us_internal_socket_context_link_socket(0, context, socket); return socket; } @@ -584,7 +584,7 @@ int start_connections(struct us_connecting_socket_t *c, int count) { flags->is_paused = 0; flags->is_ipc = 0; /* Link it into context so that timeout fires properly */ - us_internal_socket_context_link_socket(context, s); + us_internal_socket_context_link_socket(0, context, s); // TODO check this, specifically how it interacts with the SSL code // does this work when we create multiple sockets at once? will we need multiple SSL contexts? @@ -762,7 +762,7 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_con connect_socket->flags.is_ipc = 0; connect_socket->connect_state = NULL; connect_socket->connect_next = NULL; - us_internal_socket_context_link_socket(context, connect_socket); + us_internal_socket_context_link_socket(ssl, context, connect_socket); return connect_socket; } @@ -804,12 +804,9 @@ struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_con } struct us_connecting_socket_t *c = s->connect_state; - struct us_socket_t *new_s = s; - if (ext_size != -1) { struct us_poll_t *pool_ref = &s->p; - new_s = (struct us_socket_t *) us_poll_resize(pool_ref, loop, sizeof(struct us_socket_t) + ext_size); if (c) { c->connecting_head = new_s; @@ -831,7 +828,7 @@ struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_con /* We manually ref/unref context to handle context life cycle with low-priority queue */ us_socket_context_ref(ssl, context); } else { - us_internal_socket_context_link_socket(context, new_s); + us_internal_socket_context_link_socket(ssl, context, new_s); } /* We can safely unref the old context here with can potentially be freed */ us_socket_context_unref(ssl, old_context); diff --git a/packages/bun-usockets/src/internal/internal.h b/packages/bun-usockets/src/internal/internal.h index 360a676954..7ee718e723 100644 --- a/packages/bun-usockets/src/internal/internal.h +++ b/packages/bun-usockets/src/internal/internal.h @@ -150,16 +150,12 @@ void us_internal_init_loop_ssl_data(us_loop_r loop); void us_internal_free_loop_ssl_data(us_loop_r loop); /* Socket context related */ -void us_internal_socket_context_link_socket(us_socket_context_r context, - us_socket_r s); -void us_internal_socket_context_unlink_socket(int ssl, - us_socket_context_r context, us_socket_r s); +void us_internal_socket_context_link_socket(int ssl, us_socket_context_r context, us_socket_r s); +void us_internal_socket_context_unlink_socket(int ssl, us_socket_context_r context, us_socket_r s); void us_internal_socket_after_resolve(struct us_connecting_socket_t *s); void us_internal_socket_after_open(us_socket_r s, int error); -struct us_internal_ssl_socket_t * -us_internal_ssl_socket_close(us_internal_ssl_socket_r s, int code, - void *reason); +struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(us_internal_ssl_socket_r s, int code, void *reason); int us_internal_handle_dns_results(us_loop_r loop); @@ -271,7 +267,7 @@ struct us_listen_socket_t { }; /* Listen sockets are keps in their own list */ -void us_internal_socket_context_link_listen_socket( +void us_internal_socket_context_link_listen_socket(int ssl, us_socket_context_r context, struct us_listen_socket_t *s); void us_internal_socket_context_unlink_listen_socket(int ssl, us_socket_context_r context, struct us_listen_socket_t *s); @@ -288,8 +284,7 @@ struct us_socket_context_t { struct us_socket_t *iterator; struct us_socket_context_t *prev, *next; - struct us_socket_t *(*on_open)(struct us_socket_t *, int is_client, char *ip, - int ip_length); + struct us_socket_t *(*on_open)(struct us_socket_t *, int is_client, char *ip, int ip_length); struct us_socket_t *(*on_data)(struct us_socket_t *, char *data, int length); struct us_socket_t *(*on_fd)(struct us_socket_t *, int fd); struct us_socket_t *(*on_writable)(struct us_socket_t *); @@ -301,7 +296,6 @@ struct us_socket_context_t { struct us_connecting_socket_t *(*on_connect_error)(struct us_connecting_socket_t *, int code); struct us_socket_t *(*on_socket_connect_error)(struct us_socket_t *, int code); int (*is_low_prio)(struct us_socket_t *); - }; /* Internal SSL interface */ diff --git a/packages/bun-usockets/src/loop.c b/packages/bun-usockets/src/loop.c index 2129561d02..b1605dcfab 100644 --- a/packages/bun-usockets/src/loop.c +++ b/packages/bun-usockets/src/loop.c @@ -40,7 +40,6 @@ void us_internal_enable_sweep_timer(struct us_loop_t *loop) { us_timer_set(loop->data.sweep_timer, (void (*)(struct us_timer_t *)) sweep_timer_cb, LIBUS_TIMEOUT_GRANULARITY * 1000, LIBUS_TIMEOUT_GRANULARITY * 1000); Bun__internal_ensureDateHeaderTimerIsEnabled(loop); } - } void us_internal_disable_sweep_timer(struct us_loop_t *loop) { @@ -183,7 +182,7 @@ void us_internal_handle_low_priority_sockets(struct us_loop_t *loop) { if (s->next) s->next->prev = 0; s->next = 0; - us_internal_socket_context_link_socket(s->context, s); + us_internal_socket_context_link_socket(0, s->context, s); us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) | LIBUS_SOCKET_READABLE); s->flags.low_prio_state = 2; @@ -340,7 +339,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in /* We always use nodelay */ bsd_socket_nodelay(client_fd, 1); - us_internal_socket_context_link_socket(listen_socket->s.context, s); + us_internal_socket_context_link_socket(0, listen_socket->s.context, s); listen_socket->s.context->on_open(s, 0, bsd_addr_get_ip(&addr), bsd_addr_get_ip_length(&addr)); @@ -364,7 +363,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in /* Note: if we failed a write as a socket of one loop then adopted * to another loop, this will be wrong. Absurd case though */ loop->data.last_write_failed = 0; - + s = s->context->on_writable(s); if (!s || us_socket_is_closed(0, s)) { diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index 8b3a8723e3..a4b02a7f42 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -329,7 +329,7 @@ struct us_socket_t *us_socket_from_fd(struct us_socket_context_t *ctx, int socke bsd_socket_nodelay(fd, 1); apple_no_sigpipe(fd); bsd_set_nonblocking(fd); - us_internal_socket_context_link_socket(ctx, s); + us_internal_socket_context_link_socket(0, ctx, s); return s; #endif diff --git a/packages/bun-uws/src/App.h b/packages/bun-uws/src/App.h index a68f306de4..d98389e787 100644 --- a/packages/bun-uws/src/App.h +++ b/packages/bun-uws/src/App.h @@ -298,6 +298,22 @@ public: return std::move(*this); } + /** Closes all connections connected to this server which are not sending a request or waiting for a response. Does not close the listen socket. */ + TemplatedApp &&closeIdle() { + auto context = (struct us_socket_context_t *)this->httpContext; + struct us_socket_t *s = context->head_sockets; + while (s) { + HttpResponseData *httpResponseData = HttpResponse::getHttpResponseDataS(s); + httpResponseData->shouldCloseOnceIdle = true; + struct us_socket_t *next = s->next; + if (httpResponseData->isIdle) { + us_socket_close(SSL, s, LIBUS_SOCKET_CLOSE_CODE_CLEAN_SHUTDOWN, 0); + } + s = next; + } + return std::move(*this); + } + template TemplatedApp &&ws(std::string_view pattern, WebSocketBehavior &&behavior) { /* Don't compile if alignment rules cannot be satisfied */ diff --git a/packages/bun-uws/src/AsyncSocket.h b/packages/bun-uws/src/AsyncSocket.h index e5bcf5cabb..540e7ee7f5 100644 --- a/packages/bun-uws/src/AsyncSocket.h +++ b/packages/bun-uws/src/AsyncSocket.h @@ -386,6 +386,9 @@ public: /* We do not need to care for buffering here, write does that */ return {0, true}; } + if (length == 0) { + return {written, failed}; + } } /* We should only return with new writes, not things written to cork already */ diff --git a/packages/bun-uws/src/HttpContext.h b/packages/bun-uws/src/HttpContext.h index 0fc7cf9f56..c0866ffdde 100644 --- a/packages/bun-uws/src/HttpContext.h +++ b/packages/bun-uws/src/HttpContext.h @@ -137,10 +137,6 @@ private: return (HttpContextData *) us_socket_context_ext(SSL, getSocketContext()); } - static HttpContextData *getSocketContextDataS(us_socket_t *s) { - return (HttpContextData *) us_socket_context_ext(SSL, getSocketContext(s)); - } - /* Init the HttpContext by registering libusockets event handlers */ HttpContext *init() { @@ -247,6 +243,7 @@ private: /* Mark that we are inside the parser now */ httpContextData->flags.isParsingHttp = true; + httpResponseData->isIdle = false; // clients need to know the cursor after http parse, not servers! // how far did we read then? we need to know to continue with websocket parsing data? or? @@ -398,6 +395,7 @@ private: /* Timeout on uncork failure */ auto [written, failed] = ((AsyncSocket *) returnedData)->uncork(); if (written > 0 || failed) { + httpResponseData->isIdle = true; /* All Http sockets timeout by this, and this behavior match the one in HttpResponse::cork */ ((HttpResponse *) s)->resetTimeout(); } @@ -642,6 +640,10 @@ public: }, priority); } + static HttpContextData *getSocketContextDataS(us_socket_t *s) { + return (HttpContextData *) us_socket_context_ext(SSL, getSocketContext(s)); + } + /* Listen to port using this HttpContext */ us_listen_socket_t *listen(const char *host, int port, int options) { int error = 0; diff --git a/packages/bun-uws/src/HttpContextData.h b/packages/bun-uws/src/HttpContextData.h index 49c094c64e..48ec202dd1 100644 --- a/packages/bun-uws/src/HttpContextData.h +++ b/packages/bun-uws/src/HttpContextData.h @@ -63,7 +63,6 @@ private: OnSocketClosedCallback onSocketClosed = nullptr; OnClientErrorCallback onClientError = nullptr; - HttpFlags flags; uint64_t maxHeaderSize = 0; // 0 means no limit // TODO: SNI @@ -73,10 +72,8 @@ private: filterHandlers.clear(); } - public: - bool isAuthorized() const { - return flags.isAuthorized; - } +public: + HttpFlags flags; }; } diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h index 8a92248960..03c82ca77d 100644 --- a/packages/bun-uws/src/HttpResponse.h +++ b/packages/bun-uws/src/HttpResponse.h @@ -50,6 +50,11 @@ public: HttpResponseData *getHttpResponseData() { return (HttpResponseData *) Super::getAsyncSocketData(); } + + static HttpResponseData *getHttpResponseDataS(us_socket_t *s) { + return (HttpResponseData *) us_socket_ext(SSL, s); + } + void setTimeout(uint8_t seconds) { auto* data = getHttpResponseData(); data->idleTimeout = seconds; @@ -132,7 +137,7 @@ public: /* Terminating 0 chunk */ Super::write("0\r\n\r\n", 5); - httpResponseData->markDone(); + httpResponseData->markDone(this); /* We need to check if we should close this socket here now */ if (!Super::isCorked()) { @@ -198,7 +203,7 @@ public: /* Remove onAborted function if we reach the end */ if (httpResponseData->offset == totalSize) { - httpResponseData->markDone(); + httpResponseData->markDone(this); /* We need to check if we should close this socket here now */ if (!Super::isCorked()) { diff --git a/packages/bun-uws/src/HttpResponseData.h b/packages/bun-uws/src/HttpResponseData.h index eda5a15b2c..26c3428049 100644 --- a/packages/bun-uws/src/HttpResponseData.h +++ b/packages/bun-uws/src/HttpResponseData.h @@ -22,11 +22,15 @@ #include "HttpParser.h" #include "AsyncSocketData.h" #include "ProxyParser.h" +#include "HttpContext.h" #include "MoveOnlyFunction.h" namespace uWS { +template +struct HttpContext; + template struct HttpResponseData : AsyncSocketData, HttpParser { template friend struct HttpResponse; @@ -38,7 +42,7 @@ struct HttpResponseData : AsyncSocketData, HttpParser { using OnDataCallback = void (*)(uWS::HttpResponse* response, const char* chunk, size_t chunk_length, bool, void*); /* When we are done with a response we mark it like so */ - void markDone() { + void markDone(uWS::HttpResponse *uwsRes) { onAborted = nullptr; /* Also remove onWritable so that we do not emit when draining behind the scenes. */ onWritable = nullptr; @@ -50,6 +54,9 @@ struct HttpResponseData : AsyncSocketData, HttpParser { /* We are done with this request */ this->state &= ~HttpResponseData::HTTP_RESPONSE_PENDING; + + HttpResponseData *httpResponseData = uwsRes->getHttpResponseData(); + httpResponseData->isIdle = true; } /* Caller of onWritable. It is possible onWritable calls markDone so we need to borrow it. */ @@ -101,6 +108,8 @@ struct HttpResponseData : AsyncSocketData, HttpParser { uint8_t state = 0; uint8_t idleTimeout = 10; // default HTTP_TIMEOUT 10 seconds bool fromAncientRequest = false; + bool isIdle = true; + bool shouldCloseOnceIdle = false; #ifdef UWS_WITH_PROXY diff --git a/src/bun.js/api/server.classes.ts b/src/bun.js/api/server.classes.ts index 44cb521d86..ccbd36e8fe 100644 --- a/src/bun.js/api/server.classes.ts +++ b/src/bun.js/api/server.classes.ts @@ -29,6 +29,10 @@ function generate(name) { fn: "dispose", length: 0, }, + closeIdleConnections: { + fn: "closeIdleConnections", + length: 0, + }, stop: { fn: "doStop", length: 1, diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 38bec4a432..f26f455258 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -741,12 +741,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d } } - pub fn onUpgrade( - this: *ThisServer, - globalThis: *jsc.JSGlobalObject, - object: jsc.JSValue, - optional: ?JSValue, - ) bun.JSError!JSValue { + pub fn onUpgrade(this: *ThisServer, globalThis: *jsc.JSGlobalObject, object: jsc.JSValue, optional: ?JSValue) bun.JSError!JSValue { if (this.config.websocket == null) { return globalThis.throwInvalidArguments("To enable websocket support, set the \"websocket\" object in Bun.serve({})", .{}); } @@ -1132,11 +1127,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d return this.js_value.get(); } - pub fn onFetch( - this: *ThisServer, - ctx: *jsc.JSGlobalObject, - callframe: *jsc.CallFrame, - ) bun.JSError!jsc.JSValue { + pub fn onFetch(this: *ThisServer, ctx: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { jsc.markBinding(@src()); if (this.config.onRequest == .zero) { @@ -1253,6 +1244,14 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d return jsc.JSPromise.resolvedPromiseValue(ctx, response_value); } + pub fn closeIdleConnections(this: *ThisServer, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + _ = globalObject; + _ = callframe; + if (this.app == null) return .js_undefined; + this.app.?.closeIdleConnections(); + return .js_undefined; + } + pub fn stopFromJS(this: *ThisServer, abruptly: ?JSValue) jsc.JSValue { const rc = this.getAllClosedPromise(this.globalThis); @@ -1280,10 +1279,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d return .js_undefined; } - pub fn getPort( - this: *ThisServer, - _: *jsc.JSGlobalObject, - ) jsc.JSValue { + pub fn getPort(this: *ThisServer, _: *jsc.JSGlobalObject) jsc.JSValue { switch (this.config.address) { .unix => return .js_undefined, else => {}, @@ -1412,10 +1408,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d return bun.String.static(if (ssl_enabled) "https" else "http").toJS(globalThis); } - pub fn getDevelopment( - _: *ThisServer, - _: *jsc.JSGlobalObject, - ) jsc.JSValue { + pub fn getDevelopment(_: *ThisServer, _: *jsc.JSGlobalObject) jsc.JSValue { return jsc.JSValue.jsBoolean(debug_mode); } @@ -1989,11 +1982,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d } } - pub fn onNodeHTTPRequest( - this: *ThisServer, - req: *uws.Request, - resp: *App.Response, - ) void { + pub fn onNodeHTTPRequest(this: *ThisServer, req: *uws.Request, resp: *App.Response) void { jsc.markBinding(@src()); onNodeHTTPRequestWithUpgradeCtx(this, req, resp, null); } @@ -2073,11 +2062,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d ctx.toAsync(req, prepared.request_object); } - pub fn onRequest( - this: *ThisServer, - req: *uws.Request, - resp: *App.Response, - ) void { + pub fn onRequest(this: *ThisServer, req: *uws.Request, resp: *App.Response) void { var should_deinit_context = false; const prepared = this.prepareJsRequestContext(req, resp, &should_deinit_context, true, null) orelse return; @@ -2094,14 +2079,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d this.handleRequest(&should_deinit_context, prepared, req, response_value); } - pub fn onRequestFromSaved( - this: *ThisServer, - req: SavedRequest.Union, - resp: *App.Response, - callback: JSValue, - comptime arg_count: comptime_int, - extra_args: [arg_count]JSValue, - ) void { + pub fn onRequestFromSaved(this: *ThisServer, req: SavedRequest.Union, resp: *App.Response, callback: JSValue, comptime arg_count: comptime_int, extra_args: [arg_count]JSValue) void { const prepared: PreparedRequest = switch (req) { .stack => |r| this.prepareJsRequestContext(r, resp, null, true, null) orelse return, .saved => |data| .{ @@ -2291,13 +2269,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d server.handleRequest(&should_deinit_context, prepared, req, response_value); } - pub fn onWebSocketUpgrade( - this: *ThisServer, - resp: *App.Response, - req: *uws.Request, - upgrade_ctx: *uws.SocketContext, - id: usize, - ) void { + pub fn onWebSocketUpgrade(this: *ThisServer, resp: *App.Response, req: *uws.Request, upgrade_ctx: *uws.SocketContext, id: usize) void { jsc.markBinding(@src()); if (id == 1) { // This is actually a UserRoute if id is 1 so it's safe to cast diff --git a/src/bun.js/bindings/NodeHTTP.cpp b/src/bun.js/bindings/NodeHTTP.cpp index 3a16e75f42..29eca691c0 100644 --- a/src/bun.js/bindings/NodeHTTP.cpp +++ b/src/bun.js/bindings/NodeHTTP.cpp @@ -150,7 +150,7 @@ public: if (!context) return false; auto* data = (uWS::HttpContextData*)us_socket_context_ext(is_ssl, context); if (!data) return false; - return data->isAuthorized(); + return data->flags.isAuthorized; } ~JSNodeHTTPServerSocket() { diff --git a/src/deps/libuwsockets.cpp b/src/deps/libuwsockets.cpp index 1efae06d80..07bcff0e42 100644 --- a/src/deps/libuwsockets.cpp +++ b/src/deps/libuwsockets.cpp @@ -377,6 +377,19 @@ extern "C" } } + void uws_app_close_idle(int ssl, uws_app_t *app) + { + if (ssl) + { + uWS::SSLApp *uwsApp = (uWS::SSLApp *)app; + uwsApp->closeIdle(); + } + else + { + uWS::App *uwsApp = (uWS::App *)app; + uwsApp->closeIdle(); + } + } void uws_app_set_on_clienterror(int ssl, uws_app_t *app, void (*handler)(void *user_data, int is_ssl, struct us_socket_t *rawSocket, uint8_t errorCode, char *rawPacket, int rawPacketLength), void *user_data) { @@ -1277,7 +1290,7 @@ extern "C" auto *data = uwsRes->getHttpResponseData(); data->offset = offset; data->state |= uWS::HttpResponseData::HTTP_END_CALLED; - data->markDone(); + data->markDone(uwsRes); uwsRes->resetTimeout(); } else @@ -1285,8 +1298,8 @@ extern "C" uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; auto *data = uwsRes->getHttpResponseData(); data->offset = offset; - data->state |= uWS::HttpResponseData::HTTP_END_CALLED; - data->markDone(); + data->state |= uWS::HttpResponseData::HTTP_END_CALLED; + data->markDone(uwsRes); uwsRes->resetTimeout(); } } @@ -1328,7 +1341,7 @@ extern "C" uwsRes->AsyncSocket::write("\r\n", 2); } data->state |= uWS::HttpResponseData::HTTP_END_CALLED; - data->markDone(); + data->markDone(uwsRes); uwsRes->resetTimeout(); } else @@ -1350,7 +1363,7 @@ extern "C" uwsRes->AsyncSocket::write("\r\n", 2); } data->state |= uWS::HttpResponseData::HTTP_END_CALLED; - data->markDone(); + data->markDone(uwsRes); uwsRes->resetTimeout(); } } @@ -1793,7 +1806,7 @@ __attribute__((callback (corker, ctx))) uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; uwsRes->flushHeaders(); } else { - uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; + uWS::HttpResponse *uwsRes = (uWS::HttpResponse *)res; uwsRes->flushHeaders(); } } diff --git a/src/deps/uws/App.zig b/src/deps/uws/App.zig index 7280854773..be6c1950f6 100644 --- a/src/deps/uws/App.zig +++ b/src/deps/uws/App.zig @@ -43,9 +43,14 @@ pub fn NewApp(comptime ssl: bool) type { return c.uws_app_close(ssl_flag, @as(*uws_app_s, @ptrCast(this))); } + pub fn closeIdleConnections(this: *ThisApp) void { + return c.uws_app_close_idle(ssl_flag, @as(*uws_app_s, @ptrCast(this))); + } + pub fn create(opts: BunSocketContextOptions) ?*ThisApp { return @ptrCast(c.uws_create_app(ssl_flag, opts)); } + pub fn destroy(app: *ThisApp) void { return c.uws_app_destroy(ssl_flag, @as(*uws_app_s, @ptrCast(app))); } @@ -393,6 +398,7 @@ pub const c = struct { pub const uws_missing_server_handler = ?*const fn ([*c]const u8, ?*anyopaque) callconv(.C) void; pub extern fn uws_app_close(ssl: i32, app: *uws_app_s) void; + pub extern fn uws_app_close_idle(ssl: i32, app: *uws_app_s) void; pub extern fn uws_app_set_on_clienterror(ssl: c_int, app: *uws_app_s, handler: *const fn (*anyopaque, c_int, *us_socket_t, u8, ?[*]u8, c_int) callconv(.C) void, user_data: *anyopaque) void; pub extern fn uws_create_app(ssl: i32, options: BunSocketContextOptions) ?*uws_app_t; pub extern fn uws_app_destroy(ssl: i32, app: *uws_app_t) void; diff --git a/src/js/node/_http_server.ts b/src/js/node/_http_server.ts index 231b5e4bab..aef1343eea 100644 --- a/src/js/node/_http_server.ts +++ b/src/js/node/_http_server.ts @@ -302,7 +302,8 @@ Server.prototype.closeAllConnections = function () { }; Server.prototype.closeIdleConnections = function () { - // not actually implemented + const server = this[serverSymbol]; + server.closeIdleConnections(); }; Server.prototype.close = function (optionalCallback?) { @@ -318,6 +319,7 @@ Server.prototype.close = function (optionalCallback?) { } if (typeof optionalCallback === "function") setCloseCallback(this, optionalCallback); this.listening = false; + server.closeIdleConnections(); server.stop(); }; diff --git a/test/js/node/test/parallel/test-http-client-keep-alive-release-before-finish.js b/test/js/node/test/parallel/test-http-client-keep-alive-release-before-finish.js deleted file mode 100644 index e6e0bac1bb..0000000000 --- a/test/js/node/test/parallel/test-http-client-keep-alive-release-before-finish.js +++ /dev/null @@ -1,39 +0,0 @@ -'use strict'; -const common = require('../common'); -const http = require('http'); - -const server = http.createServer((req, res) => { - res.end(); -}).listen(0, common.mustCall(() => { - const agent = new http.Agent({ - maxSockets: 1, - keepAlive: true - }); - - const port = server.address().port; - - const post = http.request({ - agent, - method: 'POST', - port, - }, common.mustCall((res) => { - res.resume(); - })); - - // What happens here is that the server `end`s the response before we send - // `something`, and the client thought that this is a green light for sending - // next GET request - post.write(Buffer.alloc(16 * 1024, 'X')); - setTimeout(() => { - post.end('something'); - }, 100); - - http.request({ - agent, - method: 'GET', - port, - }, common.mustCall((res) => { - server.close(); - res.connection.end(); - })).end(); -})); diff --git a/test/js/node/test/parallel/test-http-flush-response-headers.js b/test/js/node/test/parallel/test-http-flush-response-headers.js index 1745d42285..0f0a1387b5 100644 --- a/test/js/node/test/parallel/test-http-flush-response-headers.js +++ b/test/js/node/test/parallel/test-http-flush-response-headers.js @@ -22,6 +22,6 @@ server.listen(0, common.localhostIPv4, function() { function onResponse(res) { assert.strictEqual(res.headers.foo, 'bar'); res.destroy(); - server.closeAllConnections(); + server.close(); } }); diff --git a/test/js/node/test/parallel/test-http-response-close.js b/test/js/node/test/parallel/test-http-response-close.js index 2ec1c260e9..848d316d8a 100644 --- a/test/js/node/test/parallel/test-http-response-close.js +++ b/test/js/node/test/parallel/test-http-response-close.js @@ -43,7 +43,7 @@ const assert = require('assert'); assert.strictEqual(res.destroyed, false); res.on('close', common.mustCall(() => { assert.strictEqual(res.destroyed, true); - server.closeAllConnections(); + server.close(); })); }) ); From 1503715c0e65570af5ad23758d8b2de8fda369a9 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Wed, 3 Sep 2025 19:12:26 -0700 Subject: [PATCH 2/7] ok --- src/bun.js/webcore/fetch.zig | 46 ++++--- src/http.zig | 136 +++++++++++++-------- src/http/Signals.zig | 5 +- test/js/web/fetch/fetch.upgrade.test.ts | 63 ++++++++++ test/js/web/fetch/websocket.helpers.ts | 156 ++++++++++++++++++++++++ 5 files changed, 339 insertions(+), 67 deletions(-) create mode 100644 test/js/web/fetch/fetch.upgrade.test.ts create mode 100644 test/js/web/fetch/websocket.helpers.ts diff --git a/src/bun.js/webcore/fetch.zig b/src/bun.js/webcore/fetch.zig index a0c8d7dc90..c7d421b769 100644 --- a/src/bun.js/webcore/fetch.zig +++ b/src/bun.js/webcore/fetch.zig @@ -108,6 +108,7 @@ pub const FetchTasklet = struct { // custom checkServerIdentity check_server_identity: jsc.Strong.Optional = .empty, reject_unauthorized: bool = true, + is_websocket_upgrade: bool = false, // Custom Hostname hostname: ?[]u8 = null, is_waiting_body: bool = false, @@ -1069,6 +1070,7 @@ pub const FetchTasklet = struct { .memory_reporter = fetch_options.memory_reporter, .check_server_identity = fetch_options.check_server_identity, .reject_unauthorized = fetch_options.reject_unauthorized, + .is_websocket_upgrade = fetch_options.is_websocket_upgrade, }; fetch_tasklet.signals = fetch_tasklet.signal_store.to(); @@ -1201,19 +1203,23 @@ pub const FetchTasklet = struct { // dont have backpressure so we will schedule the data to be written // if we have backpressure the onWritable will drain the buffer needs_schedule = stream_buffer.isEmpty(); - //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n - var formated_size_buffer: [18]u8 = undefined; - const formated_size = std.fmt.bufPrint( - formated_size_buffer[0..], - "{x}\r\n", - .{data.len}, - ) catch |err| switch (err) { - error.NoSpaceLeft => unreachable, - }; - bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2)); - stream_buffer.writeAssumeCapacity(formated_size); - stream_buffer.writeAssumeCapacity(data); - stream_buffer.writeAssumeCapacity("\r\n"); + if (this.is_websocket_upgrade) { + bun.handleOom(stream_buffer.write(data)); + } else { + //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n + var formated_size_buffer: [18]u8 = undefined; + const formated_size = std.fmt.bufPrint( + formated_size_buffer[0..], + "{x}\r\n", + .{data.len}, + ) catch |err| switch (err) { + error.NoSpaceLeft => unreachable, + }; + bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2)); + stream_buffer.writeAssumeCapacity(formated_size); + stream_buffer.writeAssumeCapacity(data); + stream_buffer.writeAssumeCapacity("\r\n"); + } // pause the stream if we hit the high water mark return stream_buffer.size() >= highWaterMark; @@ -1271,6 +1277,7 @@ pub const FetchTasklet = struct { check_server_identity: jsc.Strong.Optional = .empty, unix_socket_path: ZigString.Slice, ssl_config: ?*SSLConfig = null, + is_websocket_upgrade: bool = false, }; pub fn queue( @@ -1494,6 +1501,7 @@ pub fn Bun__fetch_( var memory_reporter = bun.handleOom(bun.default_allocator.create(bun.MemoryReportingAllocator)); // used to clean up dynamically allocated memory on error (a poor man's errdefer) var is_error = false; + var is_websocket_upgrade = false; var allocator = memory_reporter.wrap(bun.default_allocator); errdefer bun.default_allocator.destroy(memory_reporter); defer { @@ -2198,6 +2206,15 @@ pub fn Bun__fetch_( } } + if (headers_.fastGet(bun.webcore.FetchHeaders.HTTPHeaderName.Upgrade)) |_upgrade| { + const upgrade = _upgrade.toSlice(bun.default_allocator); + defer upgrade.deinit(); + const slice = upgrade.slice(); + if (bun.strings.eqlComptime(slice, "websocket")) { + is_websocket_upgrade = true; + } + } + break :extract_headers Headers.from(headers_, allocator, .{ .body = body.getAnyBlob() }) catch |err| bun.handleOom(err); } @@ -2333,7 +2350,7 @@ pub fn Bun__fetch_( } } - if (!method.hasRequestBody() and body.hasBody()) { + if (!method.hasRequestBody() and body.hasBody() and !is_websocket_upgrade) { const err = globalThis.toTypeError(.INVALID_ARG_VALUE, fetch_error_unexpected_body, .{}); is_error = true; return JSPromise.dangerouslyCreateRejectedPromiseValueWithoutNotifyingVM(globalThis, err); @@ -2651,6 +2668,7 @@ pub fn Bun__fetch_( .ssl_config = ssl_config, .hostname = hostname, .memory_reporter = memory_reporter, + .is_websocket_upgrade = is_websocket_upgrade, .check_server_identity = if (check_server_identity.isEmptyOrUndefinedOrNull()) .empty else .create(check_server_identity, globalThis), .unix_socket_path = unix_socket_path, }, diff --git a/src/http.zig b/src/http.zig index f4119f2a49..1d8eef5037 100644 --- a/src/http.zig +++ b/src/http.zig @@ -405,7 +405,9 @@ pub const Flags = packed struct(u16) { is_preconnect_only: bool = false, is_streaming_request_body: bool = false, defer_fail_until_connecting_is_complete: bool = false, - _padding: u5 = 0, + is_websockets: bool = false, + websocket_upgraded: bool = false, + _padding: u3 = 0, }; // TODO: reduce the size of this struct @@ -592,6 +594,11 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { hashHeaderConst("Accept-Encoding") => { override_accept_encoding = true; }, + hashHeaderConst("Upgrade") => { + if (std.ascii.eqlIgnoreCase(this.headerStr(header_values[i]), "websocket")) { + this.flags.is_websockets = true; + } + }, hashHeaderConst(chunked_encoded_header.name) => { // We don't want to override chunked encoding header if it was set by the user add_transfer_encoding = false; @@ -1019,11 +1026,14 @@ fn writeToStreamUsingBuffer(this: *HTTPClient, comptime is_ssl: bool, socket: Ne // no data to send so we are done return false; } - pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, data: []const u8) void { log("flushStream", .{}); var stream = &this.state.original_request_body.stream; const stream_buffer = stream.buffer orelse return; + if (this.flags.is_websockets and !this.flags.websocket_upgraded) { + // cannot drain yet, websocket is waiting for upgrade + return; + } const buffer = stream_buffer.acquire(); const wasEmpty = buffer.isEmpty() and data.len == 0; if (wasEmpty and stream.ended) { @@ -1324,56 +1334,78 @@ pub fn handleOnDataHeaders( ) void { log("handleOnDataHeaders", .{}); var to_read = incoming_data; - var amount_read: usize = 0; - var needs_move = true; - if (this.state.response_message_buffer.list.items.len > 0) { - // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating - bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data)); - to_read = this.state.response_message_buffer.list.items; - needs_move = false; - } - // we reset the pending_response each time wich means that on parse error this will be always be empty - this.state.pending_response = picohttp.Response{}; - - // minimal http/1.1 request size is 16 bytes without headers and 26 with Host header - // if is less than 16 will always be a ShortRead - if (to_read.len < 16) { - log("handleShortRead", .{}); - this.handleShortRead(is_ssl, incoming_data, socket, needs_move); - return; - } - - var response = picohttp.Response.parseParts( - to_read, - &shared_response_headers_buf, - &amount_read, - ) catch |err| { - switch (err) { - error.ShortRead => { - this.handleShortRead(is_ssl, incoming_data, socket, needs_move); - }, - else => { - this.closeAndFail(err, is_ssl, socket); - }, + while (true) { + var amount_read: usize = 0; + var needs_move = true; + if (this.state.response_message_buffer.list.items.len > 0) { + // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating + bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data)); + to_read = this.state.response_message_buffer.list.items; + needs_move = false; } - return; - }; - // we save the successful parsed response - this.state.pending_response = response; + // we reset the pending_response each time wich means that on parse error this will be always be empty + this.state.pending_response = picohttp.Response{}; - const body_buf = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..]; - // handle the case where we have a 100 Continue - if (response.status_code >= 100 and response.status_code < 200) { - log("information headers", .{}); - // we still can have the 200 OK in the same buffer sometimes - if (body_buf.len > 0) { - log("information headers with body", .{}); - this.onData(is_ssl, body_buf, ctx, socket); + // minimal http/1.1 request size is 16 bytes without headers and 26 with Host header + // if is less than 16 will always be a ShortRead + if (to_read.len < 16) { + log("handleShortRead", .{}); + this.handleShortRead(is_ssl, incoming_data, socket, needs_move); + return; } - return; + + const response = picohttp.Response.parseParts( + to_read, + &shared_response_headers_buf, + &amount_read, + ) catch |err| { + switch (err) { + error.ShortRead => { + this.handleShortRead(is_ssl, incoming_data, socket, needs_move); + }, + else => { + this.closeAndFail(err, is_ssl, socket); + }, + } + return; + }; + + // we save the successful parsed response + this.state.pending_response = response; + + to_read = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..]; + + if (response.status_code == 101) { + if (!this.flags.is_websockets) { + // we cannot upgrade to websocket because the client did not request it! + this.closeAndFail(error.UnrequestedUpgrade, is_ssl, socket); + return; + } + // special case for websocket upgrade + this.flags.is_websockets = true; + this.flags.websocket_upgraded = true; + if (this.signals.upgraded) |upgraded| { + upgraded.store(true, .monotonic); + } + // start draining the request body + this.flushStream(is_ssl, socket); + break; + } + + // handle the case where we have a 100 Continue + if (response.status_code >= 100 and response.status_code < 200) { + log("information headers", .{}); + // we still can have the 200 OK in the same buffer sometimes + // 1XX responses MUST NOT include a message-body, therefore we need to continue parsing + + continue; + } + + break; } + var response = this.state.pending_response.?; const should_continue = this.handleResponseMetadata( &response, ) catch |err| { @@ -1409,14 +1441,14 @@ pub fn handleOnDataHeaders( if (this.flags.proxy_tunneling and this.proxy_tunnel == null) { // we are proxing we dont need to cloneMetadata yet - this.startProxyHandshake(is_ssl, socket, body_buf); + this.startProxyHandshake(is_ssl, socket, to_read); return; } // we have body data incoming so we clone metadata and keep going this.cloneMetadata(); - if (body_buf.len == 0) { + if (to_read.len == 0) { // no body data yet, but we can report the headers if (this.signals.get(.header_progress)) { this.progressUpdate(is_ssl, ctx, socket); @@ -1426,7 +1458,7 @@ pub fn handleOnDataHeaders( if (this.state.response_stage == .body) { { - const report_progress = this.handleResponseBody(body_buf, true) catch |err| { + const report_progress = this.handleResponseBody(to_read, true) catch |err| { this.closeAndFail(err, is_ssl, socket); return; }; @@ -1439,7 +1471,7 @@ pub fn handleOnDataHeaders( } else if (this.state.response_stage == .body_chunk) { this.setTimeout(socket, 5); { - const report_progress = this.handleResponseBodyChunkedEncoding(body_buf) catch |err| { + const report_progress = this.handleResponseBodyChunkedEncoding(to_read) catch |err| { this.closeAndFail(err, is_ssl, socket); return; }; @@ -2149,7 +2181,7 @@ pub fn handleResponseMetadata( // [...] cannot contain a message body or trailer section. // therefore in these cases set content-length to 0, so the response body is always ignored // and is not waited for (which could cause a timeout) - if ((response.status_code >= 100 and response.status_code < 200) or response.status_code == 204 or response.status_code == 304) { + if ((response.status_code >= 100 and response.status_code < 200 and response.status_code != 101) or response.status_code == 204 or response.status_code == 304) { this.state.content_length = 0; } @@ -2416,7 +2448,7 @@ pub fn handleResponseMetadata( log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding}); } - if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events)) { + if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.websocket_upgraded)) { return ShouldContinue.continue_streaming; } else { return ShouldContinue.finished; diff --git a/src/http/Signals.zig b/src/http/Signals.zig index 78531e7f41..bf8d1d8360 100644 --- a/src/http/Signals.zig +++ b/src/http/Signals.zig @@ -4,8 +4,9 @@ header_progress: ?*std.atomic.Value(bool) = null, body_streaming: ?*std.atomic.Value(bool) = null, aborted: ?*std.atomic.Value(bool) = null, cert_errors: ?*std.atomic.Value(bool) = null, +upgraded: ?*std.atomic.Value(bool) = null, pub fn isEmpty(this: *const Signals) bool { - return this.aborted == null and this.body_streaming == null and this.header_progress == null and this.cert_errors == null; + return this.aborted == null and this.body_streaming == null and this.header_progress == null and this.cert_errors == null and this.upgraded == null; } pub const Store = struct { @@ -13,12 +14,14 @@ pub const Store = struct { body_streaming: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), aborted: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), cert_errors: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), + upgraded: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), pub fn to(this: *Store) Signals { return .{ .header_progress = &this.header_progress, .body_streaming = &this.body_streaming, .aborted = &this.aborted, .cert_errors = &this.cert_errors, + .upgraded = &this.upgraded, }; } }; diff --git a/test/js/web/fetch/fetch.upgrade.test.ts b/test/js/web/fetch/fetch.upgrade.test.ts new file mode 100644 index 0000000000..243bea7762 --- /dev/null +++ b/test/js/web/fetch/fetch.upgrade.test.ts @@ -0,0 +1,63 @@ +import { describe, expect, test } from "bun:test"; +import { encodeTextFrame, encodeCloseFrame, decodeFrames, upgradeHeaders } from "./websocket.helpers"; + +describe("fetch upgrade", () => { + test("should upgrade to websocket", async () => { + const serverMessages: string[] = []; + using server = Bun.serve({ + port: 3000, + fetch(req) { + if (server.upgrade(req)) return; + return new Response("Hello World"); + }, + websocket: { + open(ws) { + ws.send("Hello World"); + }, + message(ws, message) { + serverMessages.push(message as string); + }, + close(ws) { + serverMessages.push("close"); + }, + }, + }); + const res = await fetch(server.url, { + method: "GET", + headers: upgradeHeaders(), + async *body() { + yield encodeTextFrame("hello"); + yield encodeTextFrame("world"); + yield encodeTextFrame("bye"); + yield encodeCloseFrame(); + }, + }); + expect(res.status).toBe(101); + expect(res.headers.get("upgrade")).toBe("websocket"); + expect(res.headers.get("sec-websocket-accept")).toBeString(); + expect(res.headers.get("connection")).toBe("Upgrade"); + + const clientMessages: string[] = []; + const { promise, resolve } = Promise.withResolvers(); + const reader = res.body!.getReader(); + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + for (const msg of decodeFrames(Buffer.from(value))) { + if (typeof msg === "string") { + clientMessages.push(msg); + } else { + clientMessages.push(msg.type); + } + + if (msg.type === "close") { + resolve(); + } + } + } + await promise; + expect(serverMessages).toEqual(["hello", "world", "bye", "close"]); + expect(clientMessages).toEqual(["Hello World", "close"]); + }); +}); diff --git a/test/js/web/fetch/websocket.helpers.ts b/test/js/web/fetch/websocket.helpers.ts new file mode 100644 index 0000000000..6425735039 --- /dev/null +++ b/test/js/web/fetch/websocket.helpers.ts @@ -0,0 +1,156 @@ +import { createHash, randomBytes } from "node:crypto"; + +// RFC 6455 magic GUID +const WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +function makeKey() { + return randomBytes(16).toString("base64"); +} + +function acceptFor(key) { + return createHash("sha1") + .update(key + WS_GUID) + .digest("base64"); +} + +export function encodeCloseFrame(code = 1000, reason = "") { + const reasonBuf = Buffer.from(reason, "utf8"); + const payloadLen = 2 + reasonBuf.length; // 2 bytes for code + reason + const header = []; + let headerLen = 2; + if (payloadLen < 126) { + // masked bit (0x80) + length + header.push(0x88, 0x80 | payloadLen); + } else if (payloadLen <= 0xffff) { + headerLen += 2; + header.push(0x88, 0x80 | 126, payloadLen >> 8, payloadLen & 0xff); + } else { + throw new Error("Close reason too long"); + } + + const mask = randomBytes(4); + const buf = Buffer.alloc(headerLen + 4 + payloadLen); + Buffer.from(header).copy(buf, 0); + mask.copy(buf, headerLen); + + // write code + reason + const unmasked = Buffer.alloc(payloadLen); + unmasked.writeUInt16BE(code, 0); + reasonBuf.copy(unmasked, 2); + + // apply mask + for (let i = 0; i < payloadLen; i++) { + buf[headerLen + 4 + i] = unmasked[i] ^ mask[i & 3]; + } + + return buf; +} +export function* decodeFrames(buffer) { + let i = 0; + while (i + 2 <= buffer.length) { + const b0 = buffer[i++]; + const b1 = buffer[i++]; + const fin = (b0 & 0x80) !== 0; + const opcode = b0 & 0x0f; + const masked = (b1 & 0x80) !== 0; + let len = b1 & 0x7f; + + if (len === 126) { + if (i + 2 > buffer.length) break; + len = buffer.readUInt16BE(i); + i += 2; + } else if (len === 127) { + if (i + 8 > buffer.length) break; + const big = buffer.readBigUInt64BE(i); + i += 8; + if (big > BigInt(Number.MAX_SAFE_INTEGER)) throw new Error("frame too large"); + len = Number(big); + } + + let mask; + if (masked) { + if (i + 4 > buffer.length) break; + mask = buffer.subarray(i, i + 4); + i += 4; + } + + if (i + len > buffer.length) break; + let payload = buffer.subarray(i, i + len); + i += len; + + if (masked && mask) { + const unmasked = Buffer.alloc(len); + for (let j = 0; j < len; j++) unmasked[j] = payload[j] ^ mask[j & 3]; + payload = unmasked; + } + + if (!fin) throw new Error("fragmentation not supported in this demo"); + if (opcode === 0x1) { + // text + yield payload.toString("utf8"); + } else if (opcode === 0x8) { + // CLOSE + yield { type: "close" }; + return; + } else if (opcode === 0x9) { + // PING -> respond with PONG if you implement writes here + yield { type: "ping", data: payload }; + } else if (opcode === 0xa) { + // PONG + yield { type: "pong", data: payload }; + } else { + // ignore other opcodes for brevity + } + } +} + +// Encode a single unfragmented TEXT frame (client -> server must be masked) +export function encodeTextFrame(str) { + const payload = Buffer.from(str, "utf8"); + const len = payload.length; + + let headerLen = 2; + if (len >= 126 && len <= 0xffff) headerLen += 2; + else if (len > 0xffff) headerLen += 8; + const maskKeyLen = 4; + + const buf = Buffer.alloc(headerLen + maskKeyLen + len); + // FIN=1, RSV=0, opcode=0x1 (text) + buf[0] = 0x80 | 0x1; + + // Set masked bit and length field(s) + let offset = 1; + if (len < 126) { + buf[offset++] = 0x80 | len; // mask bit + length + } else if (len <= 0xffff) { + buf[offset++] = 0x80 | 126; + buf.writeUInt16BE(len, offset); + offset += 2; + } else { + buf[offset++] = 0x80 | 127; + buf.writeBigUInt64BE(BigInt(len), offset); + offset += 8; + } + + // Mask key + const mask = randomBytes(4); + mask.copy(buf, offset); + offset += 4; + + // Mask the payload + for (let i = 0; i < len; i++) { + buf[offset + i] = payload[i] ^ mask[i & 3]; + } + + return buf; +} + +export function upgradeHeaders() { + const secWebSocketKey = makeKey(); + return { + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + "Sec-WebSocket-Key": secWebSocketKey, + }; +} From c7f6623878d06f849de72b935b0e96572fdf3022 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Thu, 4 Sep 2025 02:15:48 +0000 Subject: [PATCH 3/7] [autofix.ci] apply automated fixes --- test/js/web/fetch/fetch.upgrade.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/js/web/fetch/fetch.upgrade.test.ts b/test/js/web/fetch/fetch.upgrade.test.ts index 243bea7762..661b7260b8 100644 --- a/test/js/web/fetch/fetch.upgrade.test.ts +++ b/test/js/web/fetch/fetch.upgrade.test.ts @@ -1,5 +1,5 @@ import { describe, expect, test } from "bun:test"; -import { encodeTextFrame, encodeCloseFrame, decodeFrames, upgradeHeaders } from "./websocket.helpers"; +import { decodeFrames, encodeCloseFrame, encodeTextFrame, upgradeHeaders } from "./websocket.helpers"; describe("fetch upgrade", () => { test("should upgrade to websocket", async () => { From 87274168085d72e16a26e2556148341d997d411e Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Wed, 3 Sep 2025 19:34:31 -0700 Subject: [PATCH 4/7] more generic --- src/bun.js/webcore/fetch.zig | 18 +++++++++--------- src/http.zig | 24 ++++++++++++++---------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/bun.js/webcore/fetch.zig b/src/bun.js/webcore/fetch.zig index c7d421b769..46b06fea9b 100644 --- a/src/bun.js/webcore/fetch.zig +++ b/src/bun.js/webcore/fetch.zig @@ -108,7 +108,7 @@ pub const FetchTasklet = struct { // custom checkServerIdentity check_server_identity: jsc.Strong.Optional = .empty, reject_unauthorized: bool = true, - is_websocket_upgrade: bool = false, + upgraded_connection: bool = false, // Custom Hostname hostname: ?[]u8 = null, is_waiting_body: bool = false, @@ -1070,7 +1070,7 @@ pub const FetchTasklet = struct { .memory_reporter = fetch_options.memory_reporter, .check_server_identity = fetch_options.check_server_identity, .reject_unauthorized = fetch_options.reject_unauthorized, - .is_websocket_upgrade = fetch_options.is_websocket_upgrade, + .upgraded_connection = fetch_options.upgraded_connection, }; fetch_tasklet.signals = fetch_tasklet.signal_store.to(); @@ -1203,7 +1203,7 @@ pub const FetchTasklet = struct { // dont have backpressure so we will schedule the data to be written // if we have backpressure the onWritable will drain the buffer needs_schedule = stream_buffer.isEmpty(); - if (this.is_websocket_upgrade) { + if (this.upgraded_connection) { bun.handleOom(stream_buffer.write(data)); } else { //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n @@ -1277,7 +1277,7 @@ pub const FetchTasklet = struct { check_server_identity: jsc.Strong.Optional = .empty, unix_socket_path: ZigString.Slice, ssl_config: ?*SSLConfig = null, - is_websocket_upgrade: bool = false, + upgraded_connection: bool = false, }; pub fn queue( @@ -1501,7 +1501,7 @@ pub fn Bun__fetch_( var memory_reporter = bun.handleOom(bun.default_allocator.create(bun.MemoryReportingAllocator)); // used to clean up dynamically allocated memory on error (a poor man's errdefer) var is_error = false; - var is_websocket_upgrade = false; + var upgraded_connection = false; var allocator = memory_reporter.wrap(bun.default_allocator); errdefer bun.default_allocator.destroy(memory_reporter); defer { @@ -2210,8 +2210,8 @@ pub fn Bun__fetch_( const upgrade = _upgrade.toSlice(bun.default_allocator); defer upgrade.deinit(); const slice = upgrade.slice(); - if (bun.strings.eqlComptime(slice, "websocket")) { - is_websocket_upgrade = true; + if (!bun.strings.eqlComptime(slice, "h2") and !bun.strings.eqlComptime(slice, "h2c")) { + upgraded_connection = true; } } @@ -2350,7 +2350,7 @@ pub fn Bun__fetch_( } } - if (!method.hasRequestBody() and body.hasBody() and !is_websocket_upgrade) { + if (!method.hasRequestBody() and body.hasBody() and !upgraded_connection) { const err = globalThis.toTypeError(.INVALID_ARG_VALUE, fetch_error_unexpected_body, .{}); is_error = true; return JSPromise.dangerouslyCreateRejectedPromiseValueWithoutNotifyingVM(globalThis, err); @@ -2668,7 +2668,7 @@ pub fn Bun__fetch_( .ssl_config = ssl_config, .hostname = hostname, .memory_reporter = memory_reporter, - .is_websocket_upgrade = is_websocket_upgrade, + .upgraded_connection = upgraded_connection, .check_server_identity = if (check_server_identity.isEmptyOrUndefinedOrNull()) .empty else .create(check_server_identity, globalThis), .unix_socket_path = unix_socket_path, }, diff --git a/src/http.zig b/src/http.zig index 1d8eef5037..5f1e55f770 100644 --- a/src/http.zig +++ b/src/http.zig @@ -393,6 +393,11 @@ pub const HTTPVerboseLevel = enum { curl, }; +const HTTPUpgradeState = enum(u2) { + none = 0, + pending = 1, + upgraded = 2, +}; pub const Flags = packed struct(u16) { disable_timeout: bool = false, disable_keepalive: bool = false, @@ -405,8 +410,7 @@ pub const Flags = packed struct(u16) { is_preconnect_only: bool = false, is_streaming_request_body: bool = false, defer_fail_until_connecting_is_complete: bool = false, - is_websockets: bool = false, - websocket_upgraded: bool = false, + upgrade_state: HTTPUpgradeState = .none, _padding: u3 = 0, }; @@ -595,8 +599,9 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { override_accept_encoding = true; }, hashHeaderConst("Upgrade") => { - if (std.ascii.eqlIgnoreCase(this.headerStr(header_values[i]), "websocket")) { - this.flags.is_websockets = true; + const value = this.headerStr(header_values[i]); + if (!std.ascii.eqlIgnoreCase(value, "h2") and !std.ascii.eqlIgnoreCase(value, "h2c")) { + this.flags.upgrade_state = .pending; } }, hashHeaderConst(chunked_encoded_header.name) => { @@ -1030,8 +1035,8 @@ pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo log("flushStream", .{}); var stream = &this.state.original_request_body.stream; const stream_buffer = stream.buffer orelse return; - if (this.flags.is_websockets and !this.flags.websocket_upgraded) { - // cannot drain yet, websocket is waiting for upgrade + if (this.flags.upgrade_state == .pending) { + // cannot drain yet, upgrade is waiting for upgrade return; } const buffer = stream_buffer.acquire(); @@ -1378,14 +1383,13 @@ pub fn handleOnDataHeaders( to_read = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..]; if (response.status_code == 101) { - if (!this.flags.is_websockets) { + if (this.flags.upgrade_state == .none) { // we cannot upgrade to websocket because the client did not request it! this.closeAndFail(error.UnrequestedUpgrade, is_ssl, socket); return; } // special case for websocket upgrade - this.flags.is_websockets = true; - this.flags.websocket_upgraded = true; + this.flags.upgrade_state = .upgraded; if (this.signals.upgraded) |upgraded| { upgraded.store(true, .monotonic); } @@ -2448,7 +2452,7 @@ pub fn handleResponseMetadata( log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding}); } - if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.websocket_upgraded)) { + if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.upgrade_state == .upgraded)) { return ShouldContinue.continue_streaming; } else { return ShouldContinue.finished; From ed21db9414f8ca2c3e1e1c825e9960dca6290cb2 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Thu, 4 Sep 2025 12:24:53 -0700 Subject: [PATCH 5/7] opsie --- src/http.zig | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/http.zig b/src/http.zig index 5f1e55f770..d7c699b589 100644 --- a/src/http.zig +++ b/src/http.zig @@ -1033,6 +1033,9 @@ fn writeToStreamUsingBuffer(this: *HTTPClient, comptime is_ssl: bool, socket: Ne } pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, data: []const u8) void { log("flushStream", .{}); + if (this.state.original_request_body != .stream) { + return; + } var stream = &this.state.original_request_body.stream; const stream_buffer = stream.buffer orelse return; if (this.flags.upgrade_state == .pending) { @@ -1399,7 +1402,7 @@ pub fn handleOnDataHeaders( } // handle the case where we have a 100 Continue - if (response.status_code >= 100 and response.status_code < 200) { + if (response.status_code >= 100 and response.status_code < 200 and to_read.len > 0) { log("information headers", .{}); // we still can have the 200 OK in the same buffer sometimes // 1XX responses MUST NOT include a message-body, therefore we need to continue parsing From 0609fa5122d1fa1676b3e5021521fc575520fe87 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Thu, 4 Sep 2025 12:59:12 -0700 Subject: [PATCH 6/7] dont break stuff --- src/http.zig | 35 ++++++++++++++++--------- test/js/web/fetch/fetch.upgrade.test.ts | 2 +- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/http.zig b/src/http.zig index d7c699b589..25d133bd44 100644 --- a/src/http.zig +++ b/src/http.zig @@ -1340,18 +1340,18 @@ pub fn handleOnDataHeaders( ctx: *NewHTTPContext(is_ssl), socket: NewHTTPContext(is_ssl).HTTPSocket, ) void { - log("handleOnDataHeaders", .{}); + log("handleOnDataHeader data: {s}", .{incoming_data}); var to_read = incoming_data; + var needs_move = true; + if (this.state.response_message_buffer.list.items.len > 0) { + // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating + bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data)); + to_read = this.state.response_message_buffer.list.items; + needs_move = false; + } while (true) { var amount_read: usize = 0; - var needs_move = true; - if (this.state.response_message_buffer.list.items.len > 0) { - // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating - bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data)); - to_read = this.state.response_message_buffer.list.items; - needs_move = false; - } // we reset the pending_response each time wich means that on parse error this will be always be empty this.state.pending_response = picohttp.Response{}; @@ -1402,11 +1402,15 @@ pub fn handleOnDataHeaders( } // handle the case where we have a 100 Continue - if (response.status_code >= 100 and response.status_code < 200 and to_read.len > 0) { + if (response.status_code >= 100 and response.status_code < 200) { log("information headers", .{}); - // we still can have the 200 OK in the same buffer sometimes - // 1XX responses MUST NOT include a message-body, therefore we need to continue parsing + this.state.pending_response = null; + if (to_read.len == 0) { + // we only received 1XX responses, we wanna wait for the next status code + return; + } + // the buffer could still contain more 1XX responses or other status codes, so we continue parsing continue; } @@ -2188,7 +2192,7 @@ pub fn handleResponseMetadata( // [...] cannot contain a message body or trailer section. // therefore in these cases set content-length to 0, so the response body is always ignored // and is not waited for (which could cause a timeout) - if ((response.status_code >= 100 and response.status_code < 200 and response.status_code != 101) or response.status_code == 204 or response.status_code == 304) { + if ((response.status_code >= 100 and response.status_code < 200) or response.status_code == 204 or response.status_code == 304) { this.state.content_length = 0; } @@ -2454,8 +2458,13 @@ pub fn handleResponseMetadata( } else { log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding}); } + if (this.flags.upgrade_state == .upgraded) { + this.state.content_length = null; + this.state.flags.allow_keepalive = false; + return ShouldContinue.continue_streaming; + } - if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.upgrade_state == .upgraded)) { + if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events)) { return ShouldContinue.continue_streaming; } else { return ShouldContinue.finished; diff --git a/test/js/web/fetch/fetch.upgrade.test.ts b/test/js/web/fetch/fetch.upgrade.test.ts index 661b7260b8..58bc438f7f 100644 --- a/test/js/web/fetch/fetch.upgrade.test.ts +++ b/test/js/web/fetch/fetch.upgrade.test.ts @@ -5,7 +5,7 @@ describe("fetch upgrade", () => { test("should upgrade to websocket", async () => { const serverMessages: string[] = []; using server = Bun.serve({ - port: 3000, + port: 0, fetch(req) { if (server.upgrade(req)) return; return new Response("Hello World"); From 8df4827833445aceb7f4d2a6531e32c3fd8b28be Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Thu, 4 Sep 2025 15:21:00 -0700 Subject: [PATCH 7/7] we need to close at some point --- src/http.zig | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/http.zig b/src/http.zig index 25d133bd44..a2d8661008 100644 --- a/src/http.zig +++ b/src/http.zig @@ -1069,6 +1069,11 @@ pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo this.state.request_stage = .done; stream_buffer.release(); stream.detach(); + if (this.flags.upgrade_state == .upgraded) { + this.state.flags.received_last_chunk = true; + // upgraded connection will end when the body is done (no half-open connections) + this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + } } else { // only report drain if we send everything and previous we had something to send if (!wasEmpty) {