diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index 49e8bc3a16..b25a41881f 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -72,7 +72,7 @@ void us_socket_context_close(int ssl, struct us_socket_context_t *context) { while (ls) { struct us_listen_socket_t *nextLS = (struct us_listen_socket_t *) ls->s.next; us_listen_socket_close(ssl, ls); - + ls = nextLS; } @@ -310,7 +310,7 @@ struct us_bun_verify_error_t us_socket_verify_error(int ssl, struct us_socket_t } #endif - return (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL }; + return (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL }; } void us_internal_socket_context_free(int ssl, struct us_socket_context_t *context) { @@ -337,7 +337,7 @@ void us_socket_context_ref(int ssl, struct us_socket_context_t *context) { } void us_socket_context_unref(int ssl, struct us_socket_context_t *context) { uint32_t ref_count = context->ref_count; - context->ref_count--; + context->ref_count--; if (ref_count == 1) { us_internal_socket_context_free(ssl, context); } @@ -520,7 +520,7 @@ void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, co } struct us_connecting_socket_t *c = us_calloc(1, sizeof(struct us_connecting_socket_t) + socket_ext_size); - c->socket_ext_size = socket_ext_size; + c->socket_ext_size = socket_ext_size; c->options = options; c->ssl = ssl > 0; c->timeout = 255; @@ -641,9 +641,9 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) { /* Emit error, close without emitting on_close */ - /* There are two possible states here: - 1. It's a us_connecting_socket_t*. DNS resolution failed, or a connection failed. - 2. It's a us_socket_t* + /* There are two possible states here: + 1. It's a us_connecting_socket_t*. DNS resolution failed, or a connection failed. + 2. It's a us_socket_t* We differentiate between these two cases by checking if the connect_state is null. */ @@ -887,7 +887,7 @@ void us_socket_context_on_connect_error(int ssl, struct us_socket_context_t *con return; } #endif - + context->on_connect_error = on_connect_error; } @@ -898,7 +898,7 @@ void us_socket_context_on_socket_connect_error(int ssl, struct us_socket_context return; } #endif - + context->on_socket_connect_error = on_connect_error; } diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index 4ca98c88f0..a346101c9d 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -167,12 +167,13 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { if (!c->pending_resolve_callback) { us_connecting_socket_free(ssl, c); } -} +} struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason) { if(ssl) { return (struct us_socket_t *)us_internal_ssl_socket_close((struct us_internal_ssl_socket_t *) s, code, reason); } + if (!us_socket_is_closed(0, s)) { /* make sure the context is alive until the callback ends */ us_socket_context_ref(ssl, s->context); @@ -227,8 +228,8 @@ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, vo /* preserve the return value from on_close if its called */ return res; - } + return s; } @@ -445,18 +446,18 @@ int us_connecting_socket_get_error(int ssl, struct us_connecting_socket_t *c) { return c->error; } -/* +/* Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context context ext will not be copied to the new context, new context will contain us_wrapped_socket_context_t on ext */ struct us_socket_t *us_socket_wrap_with_tls(int ssl, struct us_socket_t *s, struct us_bun_socket_context_options_t options, struct us_socket_events_t events, int socket_ext_size) { // only accepts non-TLS sockets if (ssl) { - return NULL; + return NULL; } return(struct us_socket_t *) us_internal_ssl_socket_wrap_with_tls(s, options, events, socket_ext_size); -} +} // if a TLS socket calls this, it will start SSL call open event and TLS handshake if required // will have no effect if the socket is closed or is not TLS diff --git a/packages/bun-uws/src/App.h b/packages/bun-uws/src/App.h index 0429ce4dd8..ece41dcec6 100644 --- a/packages/bun-uws/src/App.h +++ b/packages/bun-uws/src/App.h @@ -412,14 +412,14 @@ public: webSocketContext->getExt()->messageHandler = std::move(behavior.message); webSocketContext->getExt()->drainHandler = std::move(behavior.drain); webSocketContext->getExt()->subscriptionHandler = std::move(behavior.subscription); - webSocketContext->getExt()->closeHandler = std::move([closeHandler = std::move(behavior.close)](WebSocket *ws, int code, std::string_view message) mutable { + webSocketContext->getExt()->closeHandler = [closeHandler = std::move(behavior.close)](WebSocket *ws, int code, std::string_view message) mutable { if (closeHandler) { closeHandler(ws, code, message); } /* Destruct user data after returning from close handler */ ((UserData *) ws->getUserData())->~UserData(); - }); + }; webSocketContext->getExt()->pingHandler = std::move(behavior.ping); webSocketContext->getExt()->pongHandler = std::move(behavior.pong); @@ -432,8 +432,8 @@ public: webSocketContext->getExt()->maxLifetime = behavior.maxLifetime; webSocketContext->getExt()->compression = behavior.compression; - /* Calculate idleTimeoutCompnents */ - webSocketContext->getExt()->calculateIdleTimeoutCompnents(behavior.idleTimeout); + /* Calculate idleTimeoutComponents */ + webSocketContext->getExt()->calculateIdleTimeoutComponents(behavior.idleTimeout); httpContext->onHttp("GET", pattern, [webSocketContext, behavior = std::move(behavior)](auto *res, auto *req) mutable { @@ -619,6 +619,11 @@ public: return std::move(*this); } + TemplatedApp &&setUsingCustomExpectHandler(bool value) { + httpContext->getSocketContextData()->usingCustomExpectHandler = value; + return std::move(*this); + } + }; typedef TemplatedApp App; diff --git a/packages/bun-uws/src/AsyncSocketData.h b/packages/bun-uws/src/AsyncSocketData.h index 2dd4ed88b0..ad0d13ca5e 100644 --- a/packages/bun-uws/src/AsyncSocketData.h +++ b/packages/bun-uws/src/AsyncSocketData.h @@ -81,7 +81,7 @@ struct AsyncSocketData { } - /* Or emppty */ + /* Or empty */ AsyncSocketData() = default; }; diff --git a/packages/bun-uws/src/HttpContext.h b/packages/bun-uws/src/HttpContext.h index fad558f23b..89888317b8 100644 --- a/packages/bun-uws/src/HttpContext.h +++ b/packages/bun-uws/src/HttpContext.h @@ -43,10 +43,10 @@ private: HttpContext() = delete; /* Maximum delay allowed until an HTTP connection is terminated due to outstanding request or rejected data (slow loris protection) */ - static const int HTTP_IDLE_TIMEOUT_S = 10; + static constexpr int HTTP_IDLE_TIMEOUT_S = 10; /* Minimum allowed receive throughput per second (clients uploading less than 16kB/sec get dropped) */ - static const int HTTP_RECEIVE_THROUGHPUT_BYTES = 16 * 1024; + static constexpr int HTTP_RECEIVE_THROUGHPUT_BYTES = 16 * 1024; us_socket_context_t *getSocketContext() { return (us_socket_context_t *) this; @@ -199,6 +199,8 @@ private: httpResponseData->state |= HttpResponseData::HTTP_CONNECTION_CLOSE; } + httpResponseData->fromAncientRequest = httpRequest->isAncient(); + /* Select the router based on SNI (only possible for SSL) */ auto *selectedRouter = &httpContextData->router; if constexpr (SSL) { @@ -360,9 +362,8 @@ private: /* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app need to handle spurious calls) */ us_socket_context_on_writable(SSL, getSocketContext(), [](us_socket_t *s) { - - AsyncSocket *asyncSocket = (AsyncSocket *) s; - HttpResponseData *httpResponseData = (HttpResponseData *) asyncSocket->getAsyncSocketData(); + auto *asyncSocket = reinterpret_cast *>(s); + auto *httpResponseData = reinterpret_cast *>(asyncSocket->getAsyncSocketData()); /* Ask the developer to write data and return success (true) or failure (false), OR skip sending anything and return success (true). */ if (httpResponseData->onWritable) { @@ -371,7 +372,7 @@ private: /* We expect the developer to return whether or not write was successful (true). * If write was never called, the developer should still return true so that we may drain. */ - bool success = httpResponseData->callOnWritable((HttpResponse *)asyncSocket, httpResponseData->offset); + bool success = httpResponseData->callOnWritable(reinterpret_cast *>(asyncSocket), httpResponseData->offset); /* The developer indicated that their onWritable failed. */ if (!success) { @@ -398,28 +399,26 @@ private: } /* Expect another writable event, or another request within the timeout */ - ((HttpResponse *) s)->resetTimeout(); + reinterpret_cast *>(s)->resetTimeout(); return s; }); /* Handle FIN, HTTP does not support half-closed sockets, so simply close */ us_socket_context_on_end(SSL, getSocketContext(), [](us_socket_t *s) { - ((AsyncSocket *)s)->uncorkWithoutSending(); - + auto *asyncSocket = reinterpret_cast *>(s); + asyncSocket->uncorkWithoutSending(); /* We do not care for half closed sockets */ - AsyncSocket *asyncSocket = (AsyncSocket *) s; return asyncSocket->close(); - }); /* Handle socket timeouts, simply close them so to not confuse client with FIN */ us_socket_context_on_timeout(SSL, getSocketContext(), [](us_socket_t *s) { /* Force close rather than gracefully shutdown and risk confusing the client with a complete download */ - AsyncSocket *asyncSocket = (AsyncSocket *) s; + AsyncSocket *asyncSocket = reinterpret_cast *>(s); // Node.js by default closes the connection but they emit the timeout event before that - HttpResponseData *httpResponseData = (HttpResponseData *) asyncSocket->getAsyncSocketData(); + HttpResponseData *httpResponseData = reinterpret_cast *>(asyncSocket->getAsyncSocketData()); if (httpResponseData->onTimeout) { httpResponseData->onTimeout((HttpResponse *)s, httpResponseData->userData); @@ -495,16 +494,20 @@ public: } } - httpContextData->currentRouter->add(methods, pattern, [handler = std::move(handler), parameterOffsets = std::move(parameterOffsets)](auto *r) mutable { + const bool &customContinue = httpContextData->usingCustomExpectHandler; + + httpContextData->currentRouter->add(methods, pattern, [handler = std::move(handler), parameterOffsets = std::move(parameterOffsets), &customContinue](auto *r) mutable { auto user = r->getUserData(); user.httpRequest->setYield(false); user.httpRequest->setParameters(r->getParameters()); user.httpRequest->setParameterOffsets(¶meterOffsets); - /* Middleware? Automatically respond to expectations */ - std::string_view expect = user.httpRequest->getHeader("expect"); - if (expect.length() && expect == "100-continue") { - user.httpResponse->writeContinue(); + if (!customContinue) { + /* Middleware? Automatically respond to expectations */ + std::string_view expect = user.httpRequest->getHeader("expect"); + if (expect.length() && expect == "100-continue") { + user.httpResponse->writeContinue(); + } } handler(user.httpResponse, user.httpRequest); diff --git a/packages/bun-uws/src/HttpContextData.h b/packages/bun-uws/src/HttpContextData.h index 53f1b91065..c71f3098d2 100644 --- a/packages/bun-uws/src/HttpContextData.h +++ b/packages/bun-uws/src/HttpContextData.h @@ -27,7 +27,6 @@ namespace uWS { template struct HttpResponse; struct HttpRequest; - template struct alignas(16) HttpContextData { template friend struct HttpContext; @@ -52,6 +51,7 @@ private: void *upgradedWebSocket = nullptr; bool isParsingHttp = false; bool rejectUnauthorized = false; + bool usingCustomExpectHandler = false; /* Used to simulate Node.js socket events. */ OnSocketClosedCallback onSocketClosed = nullptr; diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h index 5bd9816c5d..0fcc2cc208 100644 --- a/packages/bun-uws/src/HttpResponse.h +++ b/packages/bun-uws/src/HttpResponse.h @@ -118,7 +118,7 @@ public: } /* if write was called and there was previously no Content-Length header set */ - if (httpResponseData->state & HttpResponseData::HTTP_WRITE_CALLED && !(httpResponseData->state & HttpResponseData::HTTP_WROTE_CONTENT_LENGTH_HEADER)) { + if (httpResponseData->state & HttpResponseData::HTTP_WRITE_CALLED && !(httpResponseData->state & HttpResponseData::HTTP_WROTE_CONTENT_LENGTH_HEADER) && !httpResponseData->fromAncientRequest) { /* We do not have tryWrite-like functionalities, so ignore optional in this path */ @@ -482,7 +482,7 @@ public: HttpResponseData *httpResponseData = getHttpResponseData(); - if (!(httpResponseData->state & HttpResponseData::HTTP_WROTE_CONTENT_LENGTH_HEADER)) { + if (!(httpResponseData->state & HttpResponseData::HTTP_WROTE_CONTENT_LENGTH_HEADER) && !httpResponseData->fromAncientRequest) { if (!(httpResponseData->state & HttpResponseData::HTTP_WRITE_CALLED)) { /* Write mark on first call to write */ writeMark(); diff --git a/packages/bun-uws/src/HttpResponseData.h b/packages/bun-uws/src/HttpResponseData.h index 5c3467fc93..13b3abf695 100644 --- a/packages/bun-uws/src/HttpResponseData.h +++ b/packages/bun-uws/src/HttpResponseData.h @@ -36,7 +36,7 @@ struct HttpResponseData : AsyncSocketData, HttpParser { using OnAbortedCallback = void (*)(uWS::HttpResponse*, void*); using OnTimeoutCallback = void (*)(uWS::HttpResponse*, void*); using OnDataCallback = void (*)(uWS::HttpResponse* response, const char* chunk, size_t chunk_length, bool, void*); - + /* When we are done with a response we mark it like so */ void markDone() { onAborted = nullptr; @@ -53,7 +53,7 @@ struct HttpResponseData : AsyncSocketData, HttpParser { } /* Caller of onWritable. It is possible onWritable calls markDone so we need to borrow it. */ - bool callOnWritable( uWS::HttpResponse* response, uint64_t offset) { + bool callOnWritable(uWS::HttpResponse* response, uint64_t offset) { /* Borrow real onWritable */ auto* borrowedOnWritable = std::move(onWritable); @@ -100,6 +100,7 @@ struct HttpResponseData : AsyncSocketData, HttpParser { /* Current state (content-length sent, status sent, write called, etc */ uint8_t state = 0; uint8_t idleTimeout = 10; // default HTTP_TIMEOUT 10 seconds + bool fromAncientRequest = false; #ifdef UWS_WITH_PROXY ProxyParser proxyParser; diff --git a/packages/bun-uws/src/WebSocket.h b/packages/bun-uws/src/WebSocket.h index 3f48b91271..6b5efc81f7 100644 --- a/packages/bun-uws/src/WebSocket.h +++ b/packages/bun-uws/src/WebSocket.h @@ -115,7 +115,7 @@ public: char header[10]; int header_length = (int) protocol::formatMessage(header, "", 0, opCode, message.length(), compress, fin); int written = us_socket_write2(0, (struct us_socket_t *)this, header, header_length, message.data(), (int) message.length()); - + if (written != header_length + (int) message.length()) { /* Buffer up backpressure */ if (written > header_length) { @@ -289,7 +289,7 @@ public: ); WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) this); - + if (!webSocketData->subscriber) { return false; } /* Cannot return numSubscribers as this is only for this particular websocket context */ diff --git a/packages/bun-uws/src/WebSocketContextData.h b/packages/bun-uws/src/WebSocketContextData.h index e5ae49aeae..c016be49c4 100644 --- a/packages/bun-uws/src/WebSocketContextData.h +++ b/packages/bun-uws/src/WebSocketContextData.h @@ -82,7 +82,7 @@ public: std::pair idleTimeoutComponents; /* This is run once on start-up */ - void calculateIdleTimeoutCompnents(unsigned short idleTimeout) { + void calculateIdleTimeoutComponents(unsigned short idleTimeout) { unsigned short margin = 4; /* 4, 8 or 16 seconds margin based on idleTimeout */ while ((int) idleTimeout - margin * 2 >= margin * 2 && margin < 16) { diff --git a/src/bun.js/api/server.classes.ts b/src/bun.js/api/server.classes.ts index a198cc4f2d..5055c413eb 100644 --- a/src/bun.js/api/server.classes.ts +++ b/src/bun.js/api/server.classes.ts @@ -100,9 +100,12 @@ export default [ fn: "writeHead", length: 3, }, + writeContinue: { + fn: "writeContinue", + }, write: { fn: "write", - length: 2, + length: 3, }, end: { fn: "end", @@ -161,6 +164,13 @@ export default [ getter: "getOnAbort", setter: "setOnAbort", }, + hasCustomOnData: { + getter: "getHasCustomOnData", + setter: "setHasCustomOnData", + }, + upgraded: { + getter: "getUpgraded", + }, // ontimeout: { // getter: "getOnTimeout", // setter: "setOnTimeout", @@ -173,6 +183,7 @@ export default [ klass: {}, finalize: true, noConstructor: true, + values: ["onAborted", "onWritable", "onData"], }), define({ @@ -289,6 +300,7 @@ export default [ finalize: true, construct: true, klass: {}, + values: ["socket"], }), define({ diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 1a8d1da21f..c7d09c586f 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -5175,6 +5175,10 @@ pub const ServerWebSocket = struct { const signal = this.signal; this.signal = null; + if (ServerWebSocket.socketGetCached(this.getThisValue())) |socket| { + Bun__callNodeHTTPServerSocketOnClose(socket); + } + defer { if (signal) |sig| { sig.pendingActivityUnref(); @@ -6154,6 +6158,7 @@ pub const NodeHTTPResponse = struct { finished: bool = false, ended: bool = false, upgraded: bool = false, + hasCustomOnData: bool = false, is_request_pending: bool = true, body_read_state: BodyReadState = .none, body_read_ref: JSC.Ref = .{}, @@ -6223,18 +6228,16 @@ pub const NodeHTTPResponse = struct { done = 2, }; - extern "C" fn Bun__getNodeHTTPResponseThisValue(c_int, *anyopaque) JSC.JSValue; + extern "C" fn Bun__getNodeHTTPResponseThisValue(bool, *anyopaque) JSC.JSValue; fn getThisValue(this: *NodeHTTPResponse) JSC.JSValue { - return Bun__getNodeHTTPResponseThisValue(@intFromBool(this.response == .SSL), this.response.socket()); + return Bun__getNodeHTTPResponseThisValue(this.response == .SSL, this.response.socket()); } - extern "C" fn Bun__getNodeHTTPServerSocketThisValue(c_int, *anyopaque) JSC.JSValue; + extern "C" fn Bun__getNodeHTTPServerSocketThisValue(bool, *anyopaque) JSC.JSValue; fn getServerSocketValue(this: *NodeHTTPResponse) JSC.JSValue { - return Bun__getNodeHTTPServerSocketThisValue(@intFromBool(this.response == .SSL), this.response.socket()); + return Bun__getNodeHTTPServerSocketThisValue(this.response == .SSL, this.response.socket()); } - extern "C" fn Bun__setNodeHTTPServerSocketUsSocketValue(JSC.JSValue, *anyopaque) void; - pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_protocol: ZigString, sec_websocket_extensions: ZigString) bool { const upgrade_ctx = this.upgrade_context.context orelse return false; const ws_handler = this.server.webSocketHandler() orelse return false; @@ -6255,6 +6258,7 @@ pub const NodeHTTPResponse = struct { defer if (new_socket) |socket| { this.upgraded = true; Bun__setNodeHTTPServerSocketUsSocketValue(socketValue, socket); + ServerWebSocket.socketSetCached(ws.getThisValue(), ws_handler.globalObject, socketValue); defer this.js_ref.unref(JSC.VirtualMachine.get()); switch (this.response) { .SSL => this.response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))), @@ -6335,7 +6339,7 @@ pub const NodeHTTPResponse = struct { pub fn maybeStopReadingBody(this: *NodeHTTPResponse, vm: *JSC.VirtualMachine) void { this.upgrade_context.deinit(); // we can discard the upgrade context now - if ((this.aborted or this.ended) and (this.body_read_ref.has or this.body_read_state == .pending) and !this.onDataCallback.has()) { + if ((this.aborted or this.ended) and (this.body_read_ref.has or this.body_read_state == .pending) and (!this.hasCustomOnData or !this.onDataCallback.has())) { const had_ref = this.body_read_ref.has; this.response.clearOnData(); this.body_read_ref.unref(vm); @@ -6398,7 +6402,7 @@ pub const NodeHTTPResponse = struct { pub fn create( any_server_tag: u64, globalObject: *JSC.JSGlobalObject, - has_body: *i32, + has_body: *bool, request: *uws.Request, is_ssl: i32, response_ptr: *anyopaque, @@ -6415,7 +6419,7 @@ pub const NodeHTTPResponse = struct { break :brk 0; }; - has_body.* = @intFromBool(req_len > 0 or request.header("transfer-encoding") != null); + has_body.* = req_len > 0 or request.header("transfer-encoding") != null; } const response = NodeHTTPResponse.new(.{ @@ -6428,14 +6432,14 @@ pub const NodeHTTPResponse = struct { true => uws.AnyResponse{ .SSL = @ptrCast(response_ptr) }, false => uws.AnyResponse{ .TCP = @ptrCast(response_ptr) }, }, - .body_read_state = if (has_body.* != 0) .pending else .none, + .body_read_state = if (has_body.*) .pending else .none, // 1 - the HTTP response // 1 - the JS object // 1 - the Server handler. - // 1 - the onData callback (request bod) - .ref_count = if (has_body.* != 0) 4 else 3, + // 1 - the onData callback (request body) + .ref_count = if (has_body.*) 4 else 3, }); - if (has_body.* != 0) { + if (has_body.*) { response.body_read_ref.ref(vm); } response.js_ref.ref(vm); @@ -6598,7 +6602,7 @@ pub const NodeHTTPResponse = struct { } } - pub fn writeContinue(this: *NodeHTTPResponse, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) JSC.JSValue { + pub fn writeContinue(this: *NodeHTTPResponse, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { const arguments = callframe.arguments_old(1).slice(); _ = arguments; // autofix if (this.isDone()) { @@ -6964,7 +6968,7 @@ pub const NodeHTTPResponse = struct { if (is_end) { // Discard the body read ref if it's pending and no onData callback is set at this point. // This is the equivalent of req._dump(). - if (this.body_read_ref.has and this.body_read_state == .pending and !this.onDataCallback.has()) { + if (this.body_read_ref.has and this.body_read_state == .pending and (!this.hasCustomOnData or !this.onDataCallback.has())) { this.body_read_ref.unref(JSC.VirtualMachine.get()); this.deref(); this.body_read_state = .none; @@ -7021,10 +7025,10 @@ pub const NodeHTTPResponse = struct { pub fn setOnAbort(this: *NodeHTTPResponse, globalObject: *JSC.JSGlobalObject, value: JSValue) bool { if (this.isDone() or value == .undefined) { this.onAbortedCallback.clearWithoutDeallocation(); - return true; + } else { + this.onAbortedCallback.set(globalObject, value.withAsyncContextIfNeeded(globalObject)); } - this.onAbortedCallback.set(globalObject, value.withAsyncContextIfNeeded(globalObject)); return true; } @@ -7032,6 +7036,19 @@ pub const NodeHTTPResponse = struct { return this.onDataCallback.get() orelse .undefined; } + pub fn getHasCustomOnData(this: *NodeHTTPResponse, _: *JSC.JSGlobalObject) JSC.JSValue { + return JSC.jsBoolean(this.hasCustomOnData); + } + + pub fn getUpgraded(this: *NodeHTTPResponse, _: *JSC.JSGlobalObject) JSC.JSValue { + return JSC.jsBoolean(this.upgraded); + } + + pub fn setHasCustomOnData(this: *NodeHTTPResponse, _: *JSC.JSGlobalObject, value: JSValue) bool { + this.hasCustomOnData = value.toBoolean(); + return true; + } + fn clearOnDataCallback(this: *NodeHTTPResponse) void { if (this.body_read_state != .none) { this.onDataCallback.deinit(); @@ -7068,6 +7085,7 @@ pub const NodeHTTPResponse = struct { } this.onDataCallback.set(globalObject, value.withAsyncContextIfNeeded(globalObject)); + this.hasCustomOnData = true; this.response.onData(*NodeHTTPResponse, onData, this); this.is_data_buffered_during_pause = false; @@ -8816,6 +8834,10 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp else NodeHTTPServer__onRequest_http; + pub fn setUsingCustomExpectHandler(this: *ThisServer, value: bool) void { + NodeHTTP_setUsingCustomExpectHandler(ssl_enabled, this.app.?, value); + } + var did_send_idletimeout_warning_once = false; fn onTimeoutForIdleWarn(_: *anyopaque, _: *App.Response) void { if (debug_mode and !did_send_idletimeout_warning_once) { @@ -9330,7 +9352,7 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp } if (this.config.onNodeHTTPRequest != .zero) { app.any("/*", *ThisServer, this, onNodeHTTPRequest); - NodeHTTP_assignOnCloseFunction(@intFromBool(ssl_enabled), app); + NodeHTTP_assignOnCloseFunction(ssl_enabled, app); } else if (this.config.onRequest != .zero and !@"has /*") { app.any("/*", *ThisServer, this, onRequest); } @@ -9501,6 +9523,10 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp route_list_value = this.setRoutes(); } + if (this.config.onNodeHTTPRequest != .zero) { + this.setUsingCustomExpectHandler(true); + } + switch (this.config.address) { .tcp => |tcp| { var host: ?[*:0]const u8 = null; @@ -9873,7 +9899,9 @@ extern fn NodeHTTPServer__onRequest_https( node_response_ptr: *?*NodeHTTPResponse, ) JSC.JSValue; -extern fn NodeHTTP_assignOnCloseFunction(c_int, *anyopaque) void; +extern fn NodeHTTP_assignOnCloseFunction(bool, *anyopaque) void; + +extern fn NodeHTTP_setUsingCustomExpectHandler(bool, *anyopaque, bool) void; fn throwSSLErrorIfNecessary(globalThis: *JSC.JSGlobalObject) bool { const err_code = BoringSSL.ERR_get_error(); @@ -9902,3 +9930,6 @@ extern "c" fn Bun__ServerRouteList__create( paths: [*]ZigString, pathsLength: usize, ) JSC.JSValue; + +extern "C" fn Bun__setNodeHTTPServerSocketUsSocketValue(JSC.JSValue, ?*anyopaque) void; +extern "C" fn Bun__callNodeHTTPServerSocketOnClose(JSC.JSValue) void; diff --git a/src/bun.js/bindings/ErrorCode.cpp b/src/bun.js/bindings/ErrorCode.cpp index 0da65dacd0..0a833206f6 100644 --- a/src/bun.js/bindings/ErrorCode.cpp +++ b/src/bun.js/bindings/ErrorCode.cpp @@ -2034,6 +2034,10 @@ JSC_DEFINE_HOST_FUNCTION(Bun::jsFunctionMakeErrorWithCode, (JSC::JSGlobalObject return JSC::JSValue::encode(createError(globalObject, ErrorCode::ERR_HTTP2_OUT_OF_STREAMS, "No stream ID is available because maximum stream ID has been reached"_s)); case ErrorCode::ERR_HTTP_BODY_NOT_ALLOWED: return JSC::JSValue::encode(createError(globalObject, ErrorCode::ERR_HTTP_BODY_NOT_ALLOWED, "Adding content for this request method or response status is not allowed."_s)); + case ErrorCode::ERR_HTTP_SOCKET_ASSIGNED: + return JSC::JSValue::encode(createError(globalObject, ErrorCode::ERR_HTTP_SOCKET_ASSIGNED, "Socket already assigned"_s)); + case ErrorCode::ERR_STREAM_RELEASE_LOCK: + return JSC::JSValue::encode(createError(globalObject, ErrorCode::ERR_STREAM_RELEASE_LOCK, "Stream reader cancelled via releaseLock()"_s)); default: { break; diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index c4320ce178..f8635d1ecd 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -80,6 +80,7 @@ const errors: ErrorCodeMapping = [ ["ERR_HTTP_HEADERS_SENT", Error], ["ERR_HTTP_INVALID_HEADER_VALUE", TypeError], ["ERR_HTTP_INVALID_STATUS_CODE", RangeError], + ["ERR_HTTP_SOCKET_ASSIGNED", Error], ["ERR_HTTP2_ALTSVC_INVALID_ORIGIN", TypeError], ["ERR_HTTP2_ALTSVC_LENGTH", TypeError], ["ERR_HTTP2_ERROR", Error], diff --git a/src/bun.js/bindings/JSGlobalObject.zig b/src/bun.js/bindings/JSGlobalObject.zig index b35d576033..b2021116e1 100644 --- a/src/bun.js/bindings/JSGlobalObject.zig +++ b/src/bun.js/bindings/JSGlobalObject.zig @@ -491,12 +491,34 @@ pub const JSGlobalObject = opaque { // you most likely need to run // make clean-jsc-bindings // make bindings -j10 - const assertion = this.bunVMUnsafe() == @as(*anyopaque, @ptrCast(JSC.VirtualMachine.get())); - bun.assert(assertion); + if (JSC.VirtualMachine.VMHolder.vm) |vm_| { + bun.assert(this.bunVMUnsafe() == @as(*anyopaque, @ptrCast(vm_))); + } else { + @panic("This thread lacks a Bun VM"); + } } return @as(*JSC.VirtualMachine, @ptrCast(@alignCast(this.bunVMUnsafe()))); } + pub const ThreadKind = enum { + main, + other, + }; + + pub fn tryBunVM(this: *JSGlobalObject) struct { *JSC.VirtualMachine, ThreadKind } { + const vmPtr = @as(*JSC.VirtualMachine, @ptrCast(@alignCast(this.bunVMUnsafe()))); + + if (JSC.VirtualMachine.VMHolder.vm) |vm_| { + if (comptime bun.Environment.allow_assert) { + bun.assert(this.bunVMUnsafe() == @as(*anyopaque, @ptrCast(vm_))); + } + } else { + return .{ vmPtr, .other }; + } + + return .{ vmPtr, .main }; + } + /// We can't do the threadlocal check when queued from another thread pub fn bunVMConcurrently(this: *JSGlobalObject) *JSC.VirtualMachine { return @as(*JSC.VirtualMachine, @ptrCast(@alignCast(this.bunVMUnsafe()))); diff --git a/src/bun.js/bindings/NodeHTTP.cpp b/src/bun.js/bindings/NodeHTTP.cpp index 02ef9344cd..7b5e3b9ce6 100644 --- a/src/bun.js/bindings/NodeHTTP.cpp +++ b/src/bun.js/bindings/NodeHTTP.cpp @@ -125,7 +125,6 @@ public: void close() { - auto* socket = this->socket; if (socket) { us_socket_close(is_ssl, socket, 0, nullptr); } @@ -405,7 +404,7 @@ static void* getNodeHTTPResponsePtr(us_socket_t* socket) return responseObject->wrapped(); } -extern "C" EncodedJSValue Bun__getNodeHTTPResponseThisValue(int is_ssl, us_socket_t* socket) +extern "C" EncodedJSValue Bun__getNodeHTTPResponseThisValue(bool is_ssl, us_socket_t* socket) { if (is_ssl) { return JSValue::encode(getNodeHTTPResponse(socket)); @@ -413,7 +412,7 @@ extern "C" EncodedJSValue Bun__getNodeHTTPResponseThisValue(int is_ssl, us_socke return JSValue::encode(getNodeHTTPResponse(socket)); } -extern "C" EncodedJSValue Bun__getNodeHTTPServerSocketThisValue(int is_ssl, us_socket_t* socket) +extern "C" EncodedJSValue Bun__getNodeHTTPServerSocketThisValue(bool is_ssl, us_socket_t* socket) { if (is_ssl) { return JSValue::encode(getNodeHTTPServerSocket(socket)); @@ -427,6 +426,12 @@ extern "C" void Bun__setNodeHTTPServerSocketUsSocketValue(EncodedJSValue thisVal response->socket = socket; } +extern "C" void Bun__callNodeHTTPServerSocketOnClose(EncodedJSValue thisValue) +{ + auto* response = jsCast(JSValue::decode(thisValue)); + response->onClose(); +} + BUN_DECLARE_HOST_FUNCTION(jsFunctionRequestOrResponseHasBodyValue); BUN_DECLARE_HOST_FUNCTION(jsFunctionGetCompleteRequestOrResponseBodyValueAsArrayBuffer); extern "C" uWS::HttpRequest* Request__getUWSRequest(void*); @@ -838,7 +843,7 @@ static void assignOnCloseFunction(uWS::TemplatedApp* app) }); } -extern "C" void NodeHTTP_assignOnCloseFunction(int is_ssl, void* uws_app) +extern "C" void NodeHTTP_assignOnCloseFunction(bool is_ssl, void* uws_app) { if (is_ssl) { assignOnCloseFunction(reinterpret_cast*>(uws_app)); @@ -846,7 +851,17 @@ extern "C" void NodeHTTP_assignOnCloseFunction(int is_ssl, void* uws_app) assignOnCloseFunction(reinterpret_cast*>(uws_app)); } } -extern "C" EncodedJSValue NodeHTTPResponse__createForJS(size_t any_server, JSC::JSGlobalObject* globalObject, int* hasBody, uWS::HttpRequest* request, int isSSL, void* response_ptr, void* upgrade_ctx, void** nodeHttpResponsePtr); + +extern "C" void NodeHTTP_setUsingCustomExpectHandler(bool is_ssl, void* uws_app, bool value) +{ + if (is_ssl) { + reinterpret_cast*>(uws_app)->setUsingCustomExpectHandler(value); + } else { + reinterpret_cast*>(uws_app)->setUsingCustomExpectHandler(value); + } +} + +extern "C" EncodedJSValue NodeHTTPResponse__createForJS(size_t any_server, JSC::JSGlobalObject* globalObject, bool* hasBody, uWS::HttpRequest* request, int isSSL, void* response_ptr, void* upgrade_ctx, void** nodeHttpResponsePtr); template static EncodedJSValue NodeHTTPServer__onRequest( @@ -874,7 +889,7 @@ static EncodedJSValue NodeHTTPServer__onRequest( return JSValue::encode(exception); } - int hasBody = 0; + bool hasBody = false; WebCore::JSNodeHTTPResponse* nodeHTTPResponseObject = jsCast(JSValue::decode(NodeHTTPResponse__createForJS(any_server, globalObject, &hasBody, request, isSSL, response, upgrade_ctx, nodeHttpResponsePtr))); JSC::CallData callData = getCallData(callbackObject); diff --git a/src/js/builtins.d.ts b/src/js/builtins.d.ts index b12420e278..d6fc776333 100644 --- a/src/js/builtins.d.ts +++ b/src/js/builtins.d.ts @@ -699,6 +699,7 @@ declare function $ERR_HTTP2_ALTSVC_LENGTH(): TypeError; declare function $ERR_HTTP2_PING_LENGTH(): RangeError; declare function $ERR_HTTP2_OUT_OF_STREAMS(): Error; declare function $ERR_HTTP_BODY_NOT_ALLOWED(): Error; +declare function $ERR_HTTP_SOCKET_ASSIGNED(): Error; declare function $ERR_DIR_CLOSED(): Error; /** diff --git a/src/js/node/http.ts b/src/js/node/http.ts index 38a77bc137..c0a631c6e7 100644 --- a/src/js/node/http.ts +++ b/src/js/node/http.ts @@ -145,10 +145,6 @@ const validateHeaderValue = (name, value) => { } }; -function ERR_HTTP_SOCKET_ASSIGNED() { - return new Error(`ServerResponse has an already assigned socket`); -} - // TODO: add primordial for URL // Importing from node:url is unnecessary const { URL, WebSocket, CloseEvent, MessageEvent } = globalThis; @@ -352,8 +348,8 @@ const NodeHTTPServerSocket = class Socket extends Duplex { this[kHandle] = null; const message = this._httpMessage; const req = message?.req; - if (req && !req.complete) { - // at this point the socket is already destroyed, lets avoid UAF + if (req && !req.complete && !req[kHandle]?.upgraded) { + // At this point the socket is already destroyed; let's avoid UAF req[kHandle] = undefined; req.destroy(new ConnResetException("aborted")); } @@ -972,19 +968,24 @@ const ServerPrototype = { didFinish = true; resolveFunction && resolveFunction(); } + http_res.once("close", onClose); if (reachedRequestsLimit) { server.emit("dropRequest", http_req, socket); http_res.writeHead(503); http_res.end(); socket.destroy(); - } else { - const upgrade = http_req.headers.upgrade; - if (upgrade) { - server.emit("upgrade", http_req, socket, kEmptyBuffer); + } else if (http_req.headers.upgrade) { + server.emit("upgrade", http_req, socket, kEmptyBuffer); + } else if (http_req.headers.expect === "100-continue") { + if (server.listenerCount("checkContinue") > 0) { + server.emit("checkContinue", http_req, http_res); } else { + http_res.writeContinue(); server.emit("request", http_req, http_res); } + } else { + server.emit("request", http_req, http_res); } socket.cork(); @@ -992,18 +993,14 @@ const ServerPrototype = { if (capturedError) { handle = undefined; http_res.removeListener("close", onClose); - if (socket._httpMessage === http_res) { - socket._httpMessage = null; - } + http_res.detachSocket(socket); throw capturedError; } if (handle.finished || didFinish) { handle = undefined; http_res.removeListener("close", onClose); - if (socket._httpMessage === http_res) { - socket._httpMessage = null; - } + http_res.detachSocket(socket); return; } @@ -1373,6 +1370,7 @@ const IncomingMessagePrototype = { if (!internalRequest.ondata) { internalRequest.ondata = onDataIncomingMessage.bind(this); + internalRequest.hasCustomOnData = false; } return true; @@ -1814,7 +1812,7 @@ function emitContinueAndSocketNT(self) { self.emit("socket", self.socket); } - //Emit continue event for the client (internally we auto handle it) + // Emit continue event for the client (internally we auto handle it) if (!self._closed && self.getHeader("expect") === "100-continue") { self.emit("continue"); } @@ -1961,14 +1959,18 @@ const ServerResponsePrototype = { this._writeRaw("HTTP/1.1 102 Processing\r\n\r\n", "ascii", cb); }, writeContinue(cb) { - this._writeRaw("HTTP/1.1 100 Continue\r\n\r\n", "ascii", cb); + this.socket[kHandle]?.response?.writeContinue(); + cb?.(); }, // This end method is actually on the OutgoingMessage prototype in Node.js // But we don't want it for the fetch() response version. end(chunk, encoding, callback) { const handle = this[kHandle]; - const isFinished = this.finished || handle?.finished; + if (handle?.aborted) { + return this; + } + if ($isCallable(chunk)) { callback = chunk; chunk = undefined; @@ -1983,6 +1985,7 @@ const ServerResponsePrototype = { if (hasServerResponseFinished(this, chunk, callback)) { return this; } + if (chunk && !this._hasBody) { if (this.req?.method === "HEAD") { chunk = undefined; @@ -1991,62 +1994,67 @@ const ServerResponsePrototype = { } } - if (handle) { - const headerState = this[headerStateSymbol]; - callWriteHeadIfObservable(this, headerState); - - if (headerState !== NodeHTTPHeaderState.sent) { - handle.cork(() => { - handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); - - // If handle.writeHead throws, we don't want headersSent to be set to true. - // So we set it here. - this[headerStateSymbol] = NodeHTTPHeaderState.sent; - - // https://github.com/nodejs/node/blob/2eff28fb7a93d3f672f80b582f664a7c701569fb/lib/_http_outgoing.js#L987 - this._contentLength = handle.end(chunk, encoding); - }); - } else { - // If there's no data but you already called end, then you're done. - // We can ignore it in that case. - if (!(!chunk && handle.ended) && !handle.aborted) { - handle.end(chunk, encoding); - } + if (!handle) { + if (typeof callback === "function") { + process.nextTick(callback); } - this._header = " "; - const req = this.req; - const socket = req.socket; - if (!req._consuming && !req?._readableState?.resumeScheduled) { - req._dump(); + return this; + } + + const headerState = this[headerStateSymbol]; + callWriteHeadIfObservable(this, headerState); + + if (headerState !== NodeHTTPHeaderState.sent) { + handle.cork(() => { + handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); + + // If handle.writeHead throws, we don't want headersSent to be set to true. + // So we set it here. + this[headerStateSymbol] = NodeHTTPHeaderState.sent; + + // https://github.com/nodejs/node/blob/2eff28fb7a93d3f672f80b582f664a7c701569fb/lib/_http_outgoing.js#L987 + this._contentLength = handle.end(chunk, encoding); + }); + } else { + // If there's no data but you already called end, then you're done. + // We can ignore it in that case. + if (!(!chunk && handle.ended) && !handle.aborted) { + handle.end(chunk, encoding); } - this.detachSocket(socket); - this.finished = true; - this.emit("prefinish"); - this._callPendingCallbacks(); + } + this._header = " "; + const req = this.req; + const socket = req.socket; + if (!req._consuming && !req?._readableState?.resumeScheduled) { + req._dump(); + } + this.detachSocket(socket); + this.finished = true; + this.emit("prefinish"); + this._callPendingCallbacks(); - if (callback) { - process.nextTick( - function (callback, self) { - // In Node.js, the "finish" event triggers the "close" event. - // So it shouldn't become closed === true until after "finish" is emitted and the callback is called. - self.emit("finish"); - try { - callback(); - } catch (err) { - self.emit("error", err); - } - - process.nextTick(emitCloseNT, self); - }, - callback, - this, - ); - } else { - process.nextTick(function (self) { + if (callback) { + process.nextTick( + function (callback, self) { + // In Node.js, the "finish" event triggers the "close" event. + // So it shouldn't become closed === true until after "finish" is emitted and the callback is called. self.emit("finish"); + try { + callback(); + } catch (err) { + self.emit("error", err); + } + process.nextTick(emitCloseNT, self); - }, this); - } + }, + callback, + this, + ); + } else { + process.nextTick(function (self) { + self.emit("finish"); + process.nextTick(emitCloseNT, self); + }, this); } return this; @@ -2081,6 +2089,14 @@ const ServerResponsePrototype = { const headerState = this[headerStateSymbol]; callWriteHeadIfObservable(this, headerState); + if (!handle) { + if (this.socket) { + return this.socket.write(chunk, encoding, callback); + } else { + return OutgoingMessagePrototype.write.$call(this, chunk, encoding, callback); + } + } + if (this[headerStateSymbol] !== NodeHTTPHeaderState.sent) { handle.cork(() => { handle.writeHead(this.statusCode, this.statusMessage, this[headersSymbol]); @@ -2198,7 +2214,7 @@ const ServerResponsePrototype = { assignSocket(socket) { if (socket._httpMessage) { - throw ERR_HTTP_SOCKET_ASSIGNED(); + throw $ERR_HTTP_SOCKET_ASSIGNED("Socket already assigned"); } socket._httpMessage = this; socket.once("close", onServerResponseClose); @@ -2391,7 +2407,7 @@ function ServerResponse_finalDeprecated(chunk, encoding, callback) { req.complete = true; process.nextTick(emitRequestCloseNT, req); } - callback && callback(); + callback?.(); return; } @@ -2473,14 +2489,14 @@ function ClientRequest(input, options, cb) { }; let writeCount = 0; - let resolveNextChunk = () => {}; + let resolveNextChunk: ((end: boolean) => void) | undefined = end => {}; const pushChunk = chunk => { this[kBodyChunks].push(chunk); if (writeCount > 1) { startFetch(); } - resolveNextChunk?.(); + resolveNextChunk?.(false); }; const write_ = (chunk, encoding, callback) => { @@ -2511,7 +2527,7 @@ function ClientRequest(input, options, cb) { for (let chunk of this[kBodyChunks]) { bodySize += chunk.length; - if (bodySize > MAX_FAKE_BACKPRESSURE_SIZE) { + if (bodySize >= MAX_FAKE_BACKPRESSURE_SIZE) { break; } } @@ -2635,171 +2651,196 @@ function ClientRequest(input, options, cb) { keepalive = agentKeepalive; } - let url: string; - let proxy: string | undefined; const protocol = this[kProtocol]; const path = this[kPath]; let host = this[kHost]; - if (isIPv6(host)) { - host = `[${host}]`; - } - if (path.startsWith("http://") || path.startsWith("https://")) { - url = path; - proxy = `${protocol}//${host}${this[kUseDefaultPort] ? "" : ":" + this[kPort]}`; - } else { - url = `${protocol}//${host}${this[kUseDefaultPort] ? "" : ":" + this[kPort]}${path}`; - // support agent proxy url/string for http/https - try { - // getters can throw - const agentProxy = this[kAgent]?.proxy; - // this should work for URL like objects and strings - proxy = agentProxy?.href || agentProxy; - } catch {} - } - const tls = protocol === "https:" && this[kTls] ? { ...this[kTls], serverName: this[kTls].servername } : undefined; + const getURL = host => { + if (isIPv6(host)) { + host = `[${host}]`; + } - const fetchOptions: any = { - method, - headers: this.getHeaders(), - redirect: "manual", - signal: this[kAbortController]?.signal, - // Timeouts are handled via this.setTimeout. - timeout: false, - // Disable auto gzip/deflate - decompress: false, - keepalive, - }; - let keepOpen = false; - - if (customBody === undefined) { - fetchOptions.duplex = "half"; - keepOpen = true; - } - - if (method !== "GET" && method !== "HEAD" && method !== "OPTIONS") { - const self = this; - if (customBody !== undefined) { - fetchOptions.body = customBody; + if (path.startsWith("http://") || path.startsWith("https://")) { + return [path`${protocol}//${host}${this[kUseDefaultPort] ? "" : ":" + this[kPort]}`]; } else { - fetchOptions.body = async function* () { - while (self[kBodyChunks]?.length > 0) { - yield self[kBodyChunks].shift(); - } + let proxy: string | undefined; + const url = `${protocol}//${host}${this[kUseDefaultPort] ? "" : ":" + this[kPort]}${path}`; + // support agent proxy url/string for http/https + try { + // getters can throw + const agentProxy = this[kAgent]?.proxy; + // this should work for URL like objects and strings + proxy = agentProxy?.href || agentProxy; + } catch {} + return [url, proxy]; + } + }; - if (self[kBodyChunks]?.length === 0) { - self.emit("drain"); - } + let [url, proxy] = getURL(host); - while (!self.finished) { - yield await new Promise(resolve => { - resolveNextChunk = end => { - resolveNextChunk = undefined; - if (end) { - resolve(undefined); - } else { - resolve(self[kBodyChunks].shift()); - } - }; - }); + const go = url => { + const tls = + protocol === "https:" && this[kTls] ? { ...this[kTls], serverName: this[kTls].servername } : undefined; + + const fetchOptions: any = { + method, + headers: this.getHeaders(), + redirect: "manual", + signal: this[kAbortController]?.signal, + // Timeouts are handled via this.setTimeout. + timeout: false, + // Disable auto gzip/deflate + decompress: false, + keepalive, + }; + let keepOpen = false; + + if (customBody === undefined) { + fetchOptions.duplex = "half"; + keepOpen = true; + } + + if (method !== "GET" && method !== "HEAD" && method !== "OPTIONS") { + const self = this; + if (customBody !== undefined) { + fetchOptions.body = customBody; + } else { + fetchOptions.body = async function* () { + while (self[kBodyChunks]?.length > 0) { + yield self[kBodyChunks].shift(); + } if (self[kBodyChunks]?.length === 0) { self.emit("drain"); } - } - handleResponse?.(); - }; - } - } + while (!self.finished) { + yield await new Promise(resolve => { + resolveNextChunk = end => { + resolveNextChunk = undefined; + if (end) { + resolve(undefined); + } else { + resolve(self[kBodyChunks].shift()); + } + }; + }); - if (tls) { - fetchOptions.tls = tls; - } - - if (!!$debug) { - fetchOptions.verbose = true; - } - - if (proxy) { - fetchOptions.proxy = proxy; - } - - const socketPath = this[kSocketPath]; - - if (socketPath) { - fetchOptions.unix = socketPath; - } - - //@ts-ignore - this[kFetchRequest] = fetch(url, fetchOptions) - .then(response => { - if (this.aborted) { - maybeEmitClose(); - return; - } - - handleResponse = () => { - this[kFetchRequest] = null; - this[kClearTimeout](); - handleResponse = undefined; - const prevIsHTTPS = isNextIncomingMessageHTTPS; - isNextIncomingMessageHTTPS = response.url.startsWith("https:"); - var res = (this.res = new IncomingMessage(response, { - [typeSymbol]: NodeHTTPIncomingRequestType.FetchResponse, - [reqSymbol]: this, - })); - isNextIncomingMessageHTTPS = prevIsHTTPS; - res.req = this; - process.nextTick( - (self, res) => { - // If the user did not listen for the 'response' event, then they - // can't possibly read the data, so we ._dump() it into the void - // so that the socket doesn't hang there in a paused state. - if (self.aborted || !self.emit("response", res)) { - res._dump(); + if (self[kBodyChunks]?.length === 0) { + self.emit("drain"); } - }, - this, - res, - ); - maybeEmitClose(); - if (res.statusCode === 304) { - res.complete = true; + } + + handleResponse?.(); + }; + } + } + + if (tls) { + fetchOptions.tls = tls; + } + + if (!!$debug) { + fetchOptions.verbose = true; + } + + if (proxy) { + fetchOptions.proxy = proxy; + } + + const socketPath = this[kSocketPath]; + + if (socketPath) { + fetchOptions.unix = socketPath; + } + + //@ts-ignore + this[kFetchRequest] = fetch(url, fetchOptions) + .then(response => { + if (this.aborted) { maybeEmitClose(); return; } - }; - if (!keepOpen) { - handleResponse(); - } + handleResponse = () => { + this[kFetchRequest] = null; + this[kClearTimeout](); + handleResponse = undefined; + const prevIsHTTPS = isNextIncomingMessageHTTPS; + isNextIncomingMessageHTTPS = response.url.startsWith("https:"); + var res = (this.res = new IncomingMessage(response, { + [typeSymbol]: NodeHTTPIncomingRequestType.FetchResponse, + [reqSymbol]: this, + })); + isNextIncomingMessageHTTPS = prevIsHTTPS; + res.req = this; + process.nextTick( + (self, res) => { + // If the user did not listen for the 'response' event, then they + // can't possibly read the data, so we ._dump() it into the void + // so that the socket doesn't hang there in a paused state. + if (self.aborted || !self.emit("response", res)) { + res._dump(); + } + }, + this, + res, + ); + maybeEmitClose(); + if (res.statusCode === 304) { + res.complete = true; + maybeEmitClose(); + return; + } + }; - onEnd(); - }) - .catch(err => { - // Node treats AbortError separately. - // The "abort" listener on the abort controller should have called this - if (isAbortError(err)) { - return; - } + if (!keepOpen) { + handleResponse(); + } - if (!!$debug) globalReportError(err); + onEnd(); + }) + .catch(err => { + // Node treats AbortError separately. + // The "abort" listener on the abort controller should have called this + if (isAbortError(err)) { + return; + } - this.emit("error", err); - }) - .finally(() => { - if (!keepOpen) { - this[kFetchRequest] = null; - this[kClearTimeout](); + if (!!$debug) globalReportError(err); + + this.emit("error", err); + }) + .finally(() => { + if (!keepOpen) { + this[kFetchRequest] = null; + this[kClearTimeout](); + } + }); + }; + + if (options.lookup) { + options.lookup(options.hostname, (err, address, family) => { + if (err) { + if (!!$debug) globalReportError(err); + this.emit("error", err); + } else { + [url, proxy] = getURL(address); + if (!this.hasHeader("Host")) { + this.setHeader("Host", options.hostname); + } + go(url); } }); + } else { + go(url); + } return true; }; let onEnd = () => {}; - let handleResponse = () => {}; + let handleResponse: (() => void) | undefined = () => {}; const send = () => { this.finished = true; diff --git a/src/napi/napi.zig b/src/napi/napi.zig index b3ec1091a6..d80f5ca889 100644 --- a/src/napi/napi.zig +++ b/src/napi/napi.zig @@ -2137,12 +2137,21 @@ pub const NapiFinalizerTask = struct { } pub fn schedule(this: *NapiFinalizerTask) void { - const vm = this.finalizer.env.?.toJS().bunVM(); + const globalThis = this.finalizer.env.?.toJS(); + + const vm, const thread_kind = globalThis.tryBunVM(); + + if (thread_kind != .main) { + // TODO(@heimskr): do we need to handle the case where the vm is shutting down? + vm.eventLoop().enqueueTaskConcurrent(JSC.ConcurrentTask.create(JSC.Task.init(this))); + return; + } + if (vm.isShuttingDown()) { // Immediate tasks won't run, so we run this as a cleanup hook instead vm.rareData().pushCleanupHook(vm.global, this, runAsCleanupHook); } else { - this.finalizer.env.?.toJS().bunVM().event_loop.enqueueImmediateTask(JSC.Task.init(this)); + globalThis.bunVM().event_loop.enqueueImmediateTask(JSC.Task.init(this)); } } diff --git a/test/js/node/http/node-http.test.ts b/test/js/node/http/node-http.test.ts index 5114c4243e..2775312493 100644 --- a/test/js/node/http/node-http.test.ts +++ b/test/js/node/http/node-http.test.ts @@ -1297,7 +1297,7 @@ describe("server.address should be valid IP", () => { expect(res.socket).toBe(socket); expect(socket._httpMessage).toBe(res); - expect(() => res.assignSocket(socket)).toThrow("ServerResponse has an already assigned socket"); + expect(() => res.assignSocket(socket)).toThrow("Socket already assigned"); socket.emit("close"); doneSocket(); } catch (err) { diff --git a/test/js/node/test/parallel/test-http-expect-continue.js b/test/js/node/test/parallel/test-http-expect-continue.js new file mode 100644 index 0000000000..9d24a9278d --- /dev/null +++ b/test/js/node/test/parallel/test-http-expect-continue.js @@ -0,0 +1,81 @@ +// Copyright Joyent, Inc. and other Node contributors. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit +// persons to whom the Software is furnished to do so, subject to the +// following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN +// NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +// USE OR OTHER DEALINGS IN THE SOFTWARE. + +'use strict'; +const common = require('../common'); +const assert = require('assert'); +const http = require('http'); + +const test_req_body = 'some stuff...\n'; +const test_res_body = 'other stuff!\n'; +let sent_continue = false; +let got_continue = false; + +const handler = common.mustCall((req, res) => { + assert.ok(sent_continue, 'Full response sent before 100 Continue'); + console.error('Server sending full response...'); + res.writeHead(200, { + 'Content-Type': 'text/plain', + 'ABCD': '1' + }); + res.end(test_res_body); +}); + +const server = http.createServer(common.mustNotCall()); +server.on('checkContinue', common.mustCall((req, res) => { + console.error('Server got Expect: 100-continue...'); + res.writeContinue(); + sent_continue = true; + setTimeout(function() { + handler(req, res); + }, 100); +})); +server.listen(0); + + +server.on('listening', common.mustCall(() => { + const req = http.request({ + port: server.address().port, + method: 'POST', + path: '/world', + headers: { 'Expect': '100-continue' } + }); + console.error('Client sending request...'); + let body = ''; + req.on('continue', common.mustCall(() => { + console.error('Client got 100 Continue...'); + got_continue = true; + req.end(test_req_body); + })); + req.on('response', common.mustCall((res) => { + assert.ok(got_continue, 'Full response received before 100 Continue'); + assert.strictEqual(res.statusCode, 200, + `Final status code was ${res.statusCode}, not 200.`); + res.setEncoding('utf8'); + res.on('data', function(chunk) { body += chunk; }); + res.on('end', common.mustCall(() => { + console.error('Got full response.'); + assert.strictEqual(body, test_res_body); + assert.ok('abcd' in res.headers, 'Response headers missing.'); + server.close(); + })); + })); +})); diff --git a/test/js/node/test/parallel/test-http-full-response.js b/test/js/node/test/parallel/test-http-full-response.js new file mode 100644 index 0000000000..0332f91c03 --- /dev/null +++ b/test/js/node/test/parallel/test-http-full-response.js @@ -0,0 +1,85 @@ +// Copyright Joyent, Inc. and other Node contributors. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit +// persons to whom the Software is furnished to do so, subject to the +// following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN +// NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +// USE OR OTHER DEALINGS IN THE SOFTWARE. + +'use strict'; +const common = require('../common'); +const assert = require('assert'); +// This test requires the program 'ab' +const http = require('http'); +const exec = require('child_process').exec; + +const bodyLength = 12345; + +const body = 'c'.repeat(bodyLength); + +if (typeof Bun !== "undefined" && !Bun.which("ab")) { + common.skip("ab not found"); +} + +const server = http.createServer(function(req, res) { + res.writeHead(200, { + 'Content-Length': bodyLength, + 'Content-Type': 'text/plain' + }); + res.end(body); +}); + +function runAb(opts, callback) { + const command = `ab ${opts} http://127.0.0.1:${server.address().port}/`; + exec(command, function(err, stdout, stderr) { + if (err) { + if (/ab|apr/i.test(stderr)) { + common.printSkipMessage(`problem spawning \`ab\`.\n${stderr}`); + process.reallyExit(0); + } + throw err; + } + + let m = /Document Length:\s*(\d+) bytes/i.exec(stdout); + const documentLength = parseInt(m[1]); + + m = /Complete requests:\s*(\d+)/i.exec(stdout); + const completeRequests = parseInt(m[1]); + + m = /HTML transferred:\s*(\d+) bytes/i.exec(stdout); + const htmlTransferred = parseInt(m[1]); + + assert.strictEqual(bodyLength, documentLength); + assert.strictEqual(completeRequests * documentLength, htmlTransferred); + + if (callback) callback(); + }); +} + +server.listen(0, common.mustCall(function() { + runAb('-c 1 -n 10', common.mustCall(function() { + console.log('-c 1 -n 10 okay'); + + runAb('-c 1 -n 100', common.mustCall(function() { + console.log('-c 1 -n 100 okay'); + + runAb('-c 1 -n 1000', common.mustCall(function() { + console.log('-c 1 -n 1000 okay'); + server.close(); + })); + })); + })); +})); diff --git a/test/js/node/test/parallel/test-http-outgoing-message-write-callback.js b/test/js/node/test/parallel/test-http-outgoing-message-write-callback.js new file mode 100644 index 0000000000..3a32285faa --- /dev/null +++ b/test/js/node/test/parallel/test-http-outgoing-message-write-callback.js @@ -0,0 +1,39 @@ +'use strict'; + +const common = require('../common'); + +// This test ensures that the callback of `OutgoingMessage.prototype.write()` is +// called also when writing empty chunks or when the message has no body. + +const assert = require('assert'); +const http = require('http'); +const stream = require('stream'); + +for (const method of ['GET, HEAD']) { + const expected = ['a', 'b', '', Buffer.alloc(0), 'c']; + const results = []; + + const writable = new stream.Writable({ + write(chunk, encoding, callback) { + callback(); + } + }); + + const res = new http.ServerResponse({ + method: method, + httpVersionMajor: 1, + httpVersionMinor: 1 + }); + + res.assignSocket(writable); + + for (const chunk of expected) { + res.write(chunk, () => { + results.push(chunk); + }); + } + + res.end(common.mustCall(() => { + assert.deepStrictEqual(results, expected); + })); +} diff --git a/test/js/node/test/parallel/test-http-wget.js b/test/js/node/test/parallel/test-http-wget.js new file mode 100644 index 0000000000..0abe850d3f --- /dev/null +++ b/test/js/node/test/parallel/test-http-wget.js @@ -0,0 +1,79 @@ +// Copyright Joyent, Inc. and other Node contributors. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit +// persons to whom the Software is furnished to do so, subject to the +// following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN +// NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +// USE OR OTHER DEALINGS IN THE SOFTWARE. + +'use strict'; +const common = require('../common'); +const assert = require('assert'); +const net = require('net'); +const http = require('http'); + +// `wget` sends an HTTP/1.0 request with Connection: Keep-Alive +// +// Sending back a chunked response to an HTTP/1.0 client would be wrong, +// so what has to happen in this case is that the connection is closed +// by the server after the entity body if the Content-Length was not +// sent. +// +// If the Content-Length was sent, we can probably safely honor the +// keep-alive request, even though HTTP 1.0 doesn't say that the +// connection can be kept open. Presumably any client sending this +// header knows that it is extending HTTP/1.0 and can handle the +// response. We don't test that here however, just that if the +// content-length is not provided, that the connection is in fact +// closed. + +const server = http.createServer((req, res) => { + res.writeHead(200, { 'Content-Type': 'text/plain' }); + res.write('hello '); + res.write('world\n'); + res.end(); +}); +server.listen(0); + +server.on('listening', common.mustCall(() => { + const c = net.createConnection(server.address().port); + let server_response = ''; + + c.setEncoding('utf8'); + + c.on('connect', () => { + c.write('GET / HTTP/1.0\r\n' + + 'Host: localhost\r\n' + + 'Connection: Keep-Alive\r\n\r\n'); + }); + + c.on('data', (chunk) => { + console.log(chunk); + server_response += chunk; + }); + + c.on('end', common.mustCall(() => { + const m = server_response.split('\r\n\r\n'); + assert.strictEqual(m[1], 'hello world\n'); + console.log('got end'); + c.end(); + })); + + c.on('close', common.mustCall(() => { + console.log('got close'); + server.close(); + })); +})); diff --git a/test/js/node/test/parallel/test-http-write-callbacks.js b/test/js/node/test/parallel/test-http-write-callbacks.js new file mode 100644 index 0000000000..390fddf1dc --- /dev/null +++ b/test/js/node/test/parallel/test-http-write-callbacks.js @@ -0,0 +1,96 @@ +// Copyright Joyent, Inc. and other Node contributors. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to permit +// persons to whom the Software is furnished to do so, subject to the +// following conditions: +// +// The above copyright notice and this permission notice shall be included +// in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN +// NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +// OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +// USE OR OTHER DEALINGS IN THE SOFTWARE. + +'use strict'; +const common = require('../common'); +const assert = require('assert'); + +const http = require('http'); + +let serverEndCb = false; +let serverIncoming = ''; +const serverIncomingExpect = 'bazquuxblerg'; + +let clientEndCb = false; +let clientIncoming = ''; +const clientIncomingExpect = 'asdffoobar'; + +process.on('exit', () => { + assert(serverEndCb); + assert.strictEqual(serverIncoming, serverIncomingExpect); + assert(clientEndCb); + assert.strictEqual(clientIncoming, clientIncomingExpect); + console.log('ok'); +}); + +// Verify that we get a callback when we do res.write(..., cb) +const server = http.createServer((req, res) => { + res.statusCode = 400; + res.end('Bad Request.\nMust send Expect:100-continue\n'); +}); + +server.on('checkContinue', (req, res) => { + server.close(); + assert.strictEqual(req.method, 'PUT'); + res.writeContinue(() => { + // Continue has been written + req.on('end', () => { + res.write('asdf', common.mustSucceed(() => { + res.write('foo', 'ascii', common.mustSucceed(() => { + res.end(Buffer.from('bar'), 'buffer', common.mustSucceed(() => { + serverEndCb = true; + })); + })); + })); + }); + }); + + req.setEncoding('ascii'); + req.on('data', (c) => { + serverIncoming += c; + }); +}); + +server.listen(0, function() { + const req = http.request({ + port: this.address().port, + method: 'PUT', + headers: { 'expect': '100-continue' } + }); + req.on('continue', () => { + // ok, good to go. + req.write('YmF6', 'base64', common.mustSucceed(() => { + req.write(Buffer.from('quux'), common.mustSucceed(() => { + req.end('626c657267', 'hex', common.mustSucceed(() => { + clientEndCb = true; + })); + })); + })); + }); + req.on('response', (res) => { + // This should not come until after the end is flushed out + assert(clientEndCb); + res.setEncoding('ascii'); + res.on('data', (c) => { + clientIncoming += c; + }); + }); +});