From 1c648063facdbef7cfa8379b4cb807838b0eea10 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Wed, 31 Jul 2024 05:41:54 +0000 Subject: [PATCH] fix(tls/socket/fetch) shutdown fix + ref counted context (#12925) Co-authored-by: Jarred Sumner --- packages/bun-usockets/src/bsd.c | 4 + packages/bun-usockets/src/context.c | 43 ++- packages/bun-usockets/src/crypto/openssl.c | 192 +++++-------- packages/bun-usockets/src/internal/internal.h | 13 +- .../bun-usockets/src/internal/loop_data.h | 1 + .../src/internal/networking/bsd.h | 5 +- packages/bun-usockets/src/libusockets.h | 8 +- packages/bun-usockets/src/loop.c | 18 +- packages/bun-usockets/src/socket.c | 9 +- src/bun.js/api/bun/socket.zig | 263 ++++++++---------- src/bun.js/api/sockets.classes.ts | 4 +- src/bun.js/web_worker.zig | 15 +- src/cli/create_command.zig | 8 +- src/cli/upgrade_command.zig | 4 +- src/compile_target.zig | 2 +- src/deps/uws.zig | 7 +- src/http.zig | 205 +++++++------- src/install/install.zig | 6 +- src/js/node/http2.ts | 4 +- test/cli/hot/watch.test.ts | 13 +- .../test/dev-server-ssr-100.test.ts | 2 +- ...node-http-response-write-encode-fixture.js | 82 ++++++ test/js/node/http/node-http.test.ts | 79 +----- .../worker_threads/worker_threads.test.ts | 2 +- test/js/third_party/grpc-js/common.ts | 1 + .../js/web/fetch/fetch-leak-test-fixture-4.js | 6 +- test/js/web/fetch/fetch-leak.test.js | 2 +- test/js/web/fetch/fetch.test.ts | 2 +- test/js/web/websocket/websocket.test.js | 4 +- test/js/web/workers/worker.test.ts | 6 +- test/regression/issue/07500/07500.test.ts | 2 +- 31 files changed, 519 insertions(+), 493 deletions(-) create mode 100644 test/js/node/http/node-http-response-write-encode-fixture.js diff --git a/packages/bun-usockets/src/bsd.c b/packages/bun-usockets/src/bsd.c index c04b27fb0e..7e010b22f4 100644 --- a/packages/bun-usockets/src/bsd.c +++ b/packages/bun-usockets/src/bsd.c @@ -26,6 +26,10 @@ #include #ifndef _WIN32 +// Necessary for the stdint include +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif #include #include #include diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index aad1fb3f63..585b039966 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -15,18 +15,18 @@ * limitations under the License. */ -#include "libusockets.h" #include "internal/internal.h" +#include "libusockets.h" +#include #include #include #include - #ifndef _WIN32 #include #endif #define CONCURRENT_CONNECTIONS 2 - +// clang-format off int default_is_low_prio_handler(struct us_socket_t *s) { return 0; } @@ -44,7 +44,7 @@ int us_raw_root_certs(struct us_cert_string_t**out){ void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls) { /* us_listen_socket_t extends us_socket_t so we close in similar ways */ if (!us_socket_is_closed(0, &ls->s)) { - us_internal_socket_context_unlink_listen_socket(ls->s.context, ls); + us_internal_socket_context_unlink_listen_socket(ssl, ls->s.context, ls); us_poll_stop((struct us_poll_t *) &ls->s, ls->s.context->loop); bsd_close_socket(us_poll_fd((struct us_poll_t *) &ls->s)); @@ -72,12 +72,12 @@ void us_socket_context_close(int ssl, struct us_socket_context_t *context) { struct us_socket_t *s = context->head_sockets; while (s) { struct us_socket_t *nextS = s->next; - us_socket_close(ssl, s, 0, 0); + us_socket_close(ssl, s, LIBUS_SOCKET_CLOSE_CODE_CLEAN_SHUTDOWN, 0); s = nextS; } } -void us_internal_socket_context_unlink_listen_socket(struct us_socket_context_t *context, struct us_listen_socket_t *ls) { +void us_internal_socket_context_unlink_listen_socket(int ssl, struct us_socket_context_t *context, struct us_listen_socket_t *ls) { /* We have to properly update the iterator used to sweep sockets for timeouts */ if (ls == (struct us_listen_socket_t *) context->iterator) { context->iterator = ls->s.next; @@ -95,9 +95,10 @@ void us_internal_socket_context_unlink_listen_socket(struct us_socket_context_t ls->s.next->prev = ls->s.prev; } } + us_socket_context_unref(ssl, context); } -void us_internal_socket_context_unlink_socket(struct us_socket_context_t *context, struct us_socket_t *s) { +void us_internal_socket_context_unlink_socket(int ssl, struct us_socket_context_t *context, struct us_socket_t *s) { /* We have to properly update the iterator used to sweep sockets for timeouts */ if (s == context->iterator) { context->iterator = s->next; @@ -115,6 +116,7 @@ void us_internal_socket_context_unlink_socket(struct us_socket_context_t *contex s->next->prev = s->prev; } } + us_socket_context_unref(ssl, context); } /* We always add in the top, so we don't modify any s.next */ @@ -126,6 +128,7 @@ void us_internal_socket_context_link_listen_socket(struct us_socket_context_t *c context->head_listen_sockets->s.prev = &ls->s; } context->head_listen_sockets = ls; + context->ref_count++; } /* We always add in the top, so we don't modify any s.next */ @@ -137,6 +140,7 @@ void us_internal_socket_context_link_socket(struct us_socket_context_t *context, context->head_sockets->prev = s; } context->head_sockets = s; + context->ref_count++; } struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context) { @@ -231,6 +235,7 @@ struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t * struct us_socket_context_t *context = us_calloc(1, sizeof(struct us_socket_context_t) + context_ext_size); context->loop = loop; context->is_low_prio = default_is_low_prio_handler; + context->ref_count = 1; us_internal_loop_link(loop, context); @@ -252,6 +257,7 @@ struct us_socket_context_t *us_create_bun_socket_context(int ssl, struct us_loop struct us_socket_context_t *context = us_calloc(1, sizeof(struct us_socket_context_t) + context_ext_size); context->loop = loop; context->is_low_prio = default_is_low_prio_handler; + context->ref_count = 1; us_internal_loop_link(loop, context); @@ -272,7 +278,8 @@ struct us_bun_verify_error_t us_socket_verify_error(int ssl, struct us_socket_t } -void us_socket_context_free(int ssl, struct us_socket_context_t *context) { + +void us_internal_socket_context_free(int ssl, struct us_socket_context_t *context) { #ifndef LIBUS_NO_SSL if (ssl) { /* This function will call us again with SSL=false */ @@ -285,7 +292,23 @@ void us_socket_context_free(int ssl, struct us_socket_context_t *context) { * This is the opposite order compared to when creating the context - SSL code is cleaning up before non-SSL */ us_internal_loop_unlink(context->loop, context); - us_free(context); + /* Link this context to the close-list and let it be deleted after this iteration */ + context->next = context->loop->data.closed_context_head; + context->loop->data.closed_context_head = context; +} + +void us_socket_context_ref(int ssl, struct us_socket_context_t *context) { + context->ref_count++; +} + +void us_socket_context_unref(int ssl, struct us_socket_context_t *context) { + if (--context->ref_count == 0) { + us_internal_socket_context_free(ssl, context); + } +} + +void us_socket_context_free(int ssl, struct us_socket_context_t *context) { + us_socket_context_unref(ssl, context); } struct us_listen_socket_t *us_socket_context_listen(int ssl, struct us_socket_context_t *context, const char *host, int port, int options, int socket_ext_size) { @@ -709,7 +732,7 @@ struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_con if (s->low_prio_state != 1) { /* This properly updates the iterator if in on_timeout */ - us_internal_socket_context_unlink_socket(s->context, s); + us_internal_socket_context_unlink_socket(ssl, s->context, s); } diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index ea3b03f331..9ad345cd05 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -14,8 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +// clang-format off #if (defined(LIBUS_USE_OPENSSL) || defined(LIBUS_USE_WOLFSSL)) + + +#include "internal/internal.h" +#include "libusockets.h" +#include + /* These are in sni_tree.cpp */ void *sni_new(); void sni_free(void *sni, void (*cb)(void *)); @@ -23,10 +29,6 @@ int sni_add(void *sni, const char *hostname, void *user); void *sni_remove(void *sni, const char *hostname); void *sni_find(void *sni, const char *hostname); -#include "internal/internal.h" -#include "libusockets.h" -#include - /* This module contains the entire OpenSSL implementation * of the SSL socket and socket context interfaces. */ #ifdef LIBUS_USE_OPENSSL @@ -71,10 +73,6 @@ struct us_internal_ssl_socket_context_t { // socket context SSL_CTX *ssl_context; int is_parent; -#if ALLOW_SERVER_RENEGOTIATION - unsigned int client_renegotiation_limit; - unsigned int client_renegotiation_window; -#endif /* These decorate the base implementation */ struct us_internal_ssl_socket_t *(*on_open)(struct us_internal_ssl_socket_t *, int is_client, char *ip, @@ -108,15 +106,9 @@ enum { struct us_internal_ssl_socket_t { struct us_socket_t s; SSL *ssl; // this _must_ be the first member after s -#if ALLOW_SERVER_RENEGOTIATION - unsigned int client_pending_renegotiations; - uint64_t last_ssl_renegotiation; - unsigned int is_client : 1; -#endif unsigned int ssl_write_wants_read : 1; // we use this for now unsigned int ssl_read_wants_write : 1; unsigned int handshake_state : 2; - unsigned int received_ssl_shutdown : 1; }; int passphrase_cb(char *buf, int size, int rwflag, void *u) { @@ -194,16 +186,9 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, (struct loop_ssl_data *)loop->data.ssl_data; s->ssl = SSL_new(context->ssl_context); -#if ALLOW_SERVER_RENEGOTIATION - s->client_pending_renegotiations = context->client_renegotiation_limit; - s->last_ssl_renegotiation = 0; - s->is_client = is_client ? 1 : 0; - -#endif s->ssl_write_wants_read = 0; s->ssl_read_wants_write = 0; s->handshake_state = HANDSHAKE_PENDING; - s->received_ssl_shutdown = 0; SSL_set_bio(s->ssl, loop_ssl_data->shared_rbio, loop_ssl_data->shared_wbio); // if we allow renegotiation, we need to set the mode here @@ -213,24 +198,18 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, // this can be a DoS vector for servers, so we enable it using a limit // we do not use ssl_renegotiate_freely, since ssl_renegotiate_explicit is // more performant when using BoringSSL -#if ALLOW_SERVER_RENEGOTIATION - if (context->client_renegotiation_limit) { - SSL_set_renegotiate_mode(s->ssl, ssl_renegotiate_explicit); - } else { - SSL_set_renegotiate_mode(s->ssl, ssl_renegotiate_never); - } -#endif + BIO_up_ref(loop_ssl_data->shared_rbio); BIO_up_ref(loop_ssl_data->shared_wbio); if (is_client) { -#if ALLOW_SERVER_RENEGOTIATION == 0 SSL_set_renegotiate_mode(s->ssl, ssl_renegotiate_explicit); -#endif SSL_set_connect_state(s->ssl); } else { SSL_set_accept_state(s->ssl); + // we do not allow renegotiation on the server side (should be the default for BoringSSL, but we set to make openssl compatible) + SSL_set_renegotiate_mode(s->ssl, ssl_renegotiate_never); } struct us_internal_ssl_socket_t *result = @@ -246,6 +225,36 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, return result; } +/// @brief Complete the shutdown or do a fast shutdown when needed, this should only be called before closing the socket +/// @param s +void us_internal_handle_shutdown(struct us_internal_ssl_socket_t *s, int force_fast_shutdown) { + // if we are already shutdown or in the middle of a handshake we dont need to do anything + if(!s->ssl || us_socket_is_shut_down(0, &s->s)) return; + + // we are closing the socket but did not sent a shutdown yet + int state = SSL_get_shutdown(s->ssl); + int sent_shutdown = state & SSL_SENT_SHUTDOWN; + int received_shutdown = state & SSL_RECEIVED_SHUTDOWN; + // if we are missing a shutdown call, we need to do a fast shutdown here + if(!sent_shutdown || !received_shutdown) { + // Zero means that we should wait for the peer to close the connection + // but we are already closing the connection so we do a fast shutdown here + int ret = SSL_shutdown(s->ssl); + if(ret == 0 && force_fast_shutdown) { + // do a fast shutdown (dont wait for peer) + ret = SSL_shutdown(s->ssl); + } + if(ret < 0) { + // we got some error here, but we dont care about it, we are closing the socket + int err = SSL_get_error(s->ssl, ret); + if (err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) { + // clear + ERR_clear_error(); + } + } + } +} + void us_internal_on_ssl_handshake( struct us_internal_ssl_socket_context_t *context, void (*on_handshake)(struct us_internal_ssl_socket_t *, int success, @@ -259,6 +268,8 @@ void us_internal_on_ssl_handshake( struct us_internal_ssl_socket_t * us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { + + if (s->handshake_state != HANDSHAKE_COMPLETED) { // if we have some pending handshake we cancel it and try to check the // latest handshake error this way we will always call on_handshake with the @@ -269,8 +280,14 @@ us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, us_internal_trigger_handshake_callback(s, 0); } - return (struct us_internal_ssl_socket_t *)us_socket_close( - 0, (struct us_socket_t *)s, code, reason); + // if we are in the middle of a close_notify we need to finish it (code != 0 forces a fast shutdown) + us_internal_handle_shutdown(s, code != 0); + + // only close the socket if we are not in the middle of a handshake + if(!s->ssl || SSL_get_shutdown(s->ssl) & SSL_RECEIVED_SHUTDOWN) { + return (struct us_internal_ssl_socket_t *)us_socket_close(0, (struct us_socket_t *)s, code, reason); + } + return s; } void us_internal_trigger_handshake_callback(struct us_internal_ssl_socket_t *s, @@ -292,26 +309,7 @@ int us_internal_ssl_renegotiate(struct us_internal_ssl_socket_t *s) { // if is a server and we have no pending renegotiation we can check // the limits s->handshake_state = HANDSHAKE_RENEGOTIATION_PENDING; -#if ALLOW_SERVER_RENEGOTIATION - if (!s->is_client && !SSL_renegotiate_pending(s->ssl)) { - uint64_t now = time(NULL); - struct us_internal_ssl_socket_context_t *context = - (struct us_internal_ssl_socket_context_t *)us_socket_context(0, &s->s); - // if is not the first time we negotiate and we are outside the time - // window, reset the limits - if (s->last_ssl_renegotiation && (now - s->last_ssl_renegotiation) >= - context->client_renegotiation_window) { - // reset the limits - s->client_pending_renegotiations = context->client_renegotiation_limit; - } - // if we have no more renegotiations, we should close the connection - if (s->client_pending_renegotiations == 0) { - return 0; - } - s->last_ssl_renegotiation = now; - s->client_pending_renegotiations--; - } -#endif + if (!SSL_renegotiate(s->ssl)) { // we failed to renegotiate us_internal_trigger_handshake_callback(s, 0); @@ -347,7 +345,6 @@ void us_internal_update_handshake(struct us_internal_ssl_socket_t *s) { int result = SSL_do_handshake(s->ssl); if (SSL_get_shutdown(s->ssl) & SSL_RECEIVED_SHUTDOWN) { - s->received_ssl_shutdown = 1; us_internal_ssl_socket_close(s, 0, NULL); return; } @@ -387,16 +384,15 @@ ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *)us_socket_context(0, &s->s); - SSL_free(s->ssl); - - return context->on_close(s, code, reason); + struct us_internal_ssl_socket_t * ret = context->on_close(s, code, reason); + SSL_free(s->ssl); // free SSL after on_close + s->ssl = NULL; // set to NULL + return ret; } struct us_internal_ssl_socket_t * ssl_on_end(struct us_internal_ssl_socket_t *s) { // whatever state we are in, a TCP FIN is always an answered shutdown - - /* Todo: this should report CLEANLY SHUTDOWN as reason */ return us_internal_ssl_socket_close(s, 0, NULL); } @@ -420,31 +416,13 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, loop_ssl_data->ssl_socket = &s->s; loop_ssl_data->msg_more = 0; - if (us_socket_is_closed(0, &s->s) || s->received_ssl_shutdown) { + if (us_socket_is_closed(0, &s->s)) { return NULL; } if (us_internal_ssl_socket_is_shut_down(s)) { - - int ret = 0; - if ((ret = SSL_shutdown(s->ssl)) == 1) { - // two phase shutdown is complete here - - /* Todo: this should also report some kind of clean shutdown */ - return us_internal_ssl_socket_close(s, 0, NULL); - } else if (ret < 0) { - - int err = SSL_get_error(s->ssl, ret); - - if (err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) { - // we need to clear the error queue in case these added to the thread - // local queue - ERR_clear_error(); - } - } - - // no further processing of data when in shutdown state - return s; + us_internal_ssl_socket_close(s, 0, NULL); + return NULL; } // bug checking: this loop needs a lot of attention and clean-ups and @@ -452,17 +430,12 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, int read = 0; restart: // read until shutdown - while (!s->received_ssl_shutdown) { + while (1) { int just_read = SSL_read(s->ssl, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING + read, LIBUS_RECV_BUFFER_LENGTH - read); - // we need to check if we received a shutdown here - if (SSL_get_shutdown(s->ssl) & SSL_RECEIVED_SHUTDOWN) { - s->received_ssl_shutdown = 1; - // we will only close after we handle the data and errors - } - + if (just_read <= 0) { int err = SSL_get_error(s->ssl, just_read); // as far as I know these are the only errors we want to handle @@ -477,8 +450,9 @@ restart: // clean and close renegotiation failed err = SSL_ERROR_SSL; } else if (err == SSL_ERROR_ZERO_RETURN) { - // zero return can be EOF/FIN, if we have data just signal on_data and - // close + // Remotely-Initiated Shutdown + // See: https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html + if (read) { context = (struct us_internal_ssl_socket_context_t *)us_socket_context( @@ -488,11 +462,12 @@ restart: s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read); if (!s || us_socket_is_closed(0, &s->s)) { - return s; + return NULL; // stop processing data } } // terminate connection here - return us_internal_ssl_socket_close(s, 0, NULL); + us_internal_ssl_socket_close(s, 0, NULL); + return NULL; // stop processing data } if (err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) { @@ -501,7 +476,8 @@ restart: } // terminate connection here - return us_internal_ssl_socket_close(s, 0, NULL); + us_internal_ssl_socket_close(s, 0, NULL); + return NULL; // stop processing data } else { // emit the data we have and exit @@ -527,7 +503,7 @@ restart: s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read); if (!s || us_socket_is_closed(0, &s->s)) { - return s; + return NULL; // stop processing data } break; @@ -550,19 +526,13 @@ restart: s = context->on_data( s, loop_ssl_data->ssl_read_output + LIBUS_RECV_BUFFER_PADDING, read); if (!s || us_socket_is_closed(0, &s->s)) { - return s; + return NULL; } read = 0; goto restart; } } - - // we received the shutdown after reading so we close - if (s->received_ssl_shutdown) { - us_internal_ssl_socket_close(s, 0, NULL); - return NULL; - } // trigger writable if we failed last write with want read if (s->ssl_write_wants_read) { s->ssl_write_wants_read = 0; @@ -576,7 +546,7 @@ restart: &s->s); // cast here! // if we are closed here, then exit if (!s || us_socket_is_closed(0, &s->s)) { - return s; + return NULL; } } @@ -1032,7 +1002,7 @@ long us_internal_verify_peer_certificate( // NOLINT(runtime/int) struct us_bun_verify_error_t us_internal_verify_error(struct us_internal_ssl_socket_t *s) { - if (us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) { + if (!s->ssl || us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) { return (struct us_bun_verify_error_t){ .error = 0, .code = NULL, .reason = NULL}; } @@ -1317,10 +1287,6 @@ void us_bun_internal_ssl_socket_context_add_server_name( /* We do not want to hold any nullptr's in our SNI tree */ if (ssl_context) { -#if ALLOW_SERVER_RENEGOTIATION - context->client_renegotiation_limit = options.client_renegotiation_limit; - context->client_renegotiation_window = options.client_renegotiation_window; -#endif if (sni_add(context->sni, hostname_pattern, ssl_context)) { /* If we already had that name, ignore */ free_ssl_context(ssl_context); @@ -1469,10 +1435,6 @@ us_internal_bun_create_ssl_socket_context( context->on_handshake = NULL; context->handshake_data = NULL; -#if ALLOW_SERVER_RENEGOTIATION - context->client_renegotiation_limit = options.client_renegotiation_limit; - context->client_renegotiation_window = options.client_renegotiation_window; -#endif /* We, as parent context, may ignore data */ context->sc.is_low_prio = (int (*)(struct us_socket_t *))ssl_is_low_prio; @@ -1503,7 +1465,7 @@ void us_internal_ssl_socket_context_free( sni_free(context->sni, sni_hostname_destructor); } - us_socket_context_free(0, &context->sc); + us_internal_socket_context_free(0, &context->sc); } struct us_listen_socket_t *us_internal_ssl_socket_context_listen( @@ -1714,7 +1676,7 @@ void *us_internal_connecting_ssl_socket_ext(struct us_connecting_socket_t *s) { } int us_internal_ssl_socket_is_shut_down(struct us_internal_ssl_socket_t *s) { - return us_socket_is_shut_down(0, &s->s) || + return !s->ssl || us_socket_is_shut_down(0, &s->s) || SSL_get_shutdown(s->ssl) & SSL_SENT_SHUTDOWN; } @@ -1740,11 +1702,8 @@ void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s) { loop_ssl_data->ssl_socket = &s->s; loop_ssl_data->msg_more = 0; - // sets SSL_SENT_SHUTDOWN no matter what (not actually true if error!) + // sets SSL_SENT_SHUTDOWN and waits for the other side to do the same int ret = SSL_shutdown(s->ssl); - if (ret == 0) { - ret = SSL_shutdown(s->ssl); - } if (SSL_in_init(s->ssl) || SSL_get_quiet_shutdown(s->ssl)) { // when SSL_in_init or quiet shutdown in BoringSSL, we call shutdown @@ -2049,7 +2008,6 @@ us_socket_context_on_socket_connect_error( socket->ssl_write_wants_read = 0; socket->ssl_read_wants_write = 0; socket->handshake_state = HANDSHAKE_PENDING; - socket->received_ssl_shutdown = 0; return socket; } diff --git a/packages/bun-usockets/src/internal/internal.h b/packages/bun-usockets/src/internal/internal.h index f0a34a823d..0f9b199824 100644 --- a/packages/bun-usockets/src/internal/internal.h +++ b/packages/bun-usockets/src/internal/internal.h @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +// clang-format off #pragma once #ifndef INTERNAL_H #define INTERNAL_H @@ -144,11 +145,15 @@ void us_internal_free_loop_ssl_data(struct us_loop_t *loop); /* Socket context related */ void us_internal_socket_context_link_socket(struct us_socket_context_t *context, struct us_socket_t *s); -void us_internal_socket_context_unlink_socket( +void us_internal_socket_context_unlink_socket(int ssl, struct us_socket_context_t *context, struct us_socket_t *s); void us_internal_socket_after_resolve(struct us_connecting_socket_t *s); void us_internal_socket_after_open(struct us_socket_t *s, int error); +struct us_internal_ssl_socket_t * +us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, + void *reason); + int us_internal_handle_dns_results(struct us_loop_t *loop); /* Sockets are polls */ @@ -244,12 +249,13 @@ struct us_listen_socket_t { /* Listen sockets are keps in their own list */ void us_internal_socket_context_link_listen_socket( struct us_socket_context_t *context, struct us_listen_socket_t *s); -void us_internal_socket_context_unlink_listen_socket( +void us_internal_socket_context_unlink_listen_socket(int ssl, struct us_socket_context_t *context, struct us_listen_socket_t *s); struct us_socket_context_t { alignas(LIBUS_EXT_ALIGNMENT) struct us_loop_t *loop; uint32_t global_tick; + uint32_t ref_count; unsigned char timestamp; unsigned char long_timestamp; struct us_socket_t *head_sockets; @@ -280,7 +286,8 @@ struct us_internal_ssl_socket_t; typedef void (*us_internal_on_handshake_t)( struct us_internal_ssl_socket_t *, int success, struct us_bun_verify_error_t verify_error, void *custom_data); - + +void us_internal_socket_context_free(int ssl, struct us_socket_context_t *context); /* SNI functions */ void us_internal_ssl_socket_context_add_server_name( struct us_internal_ssl_socket_context_t *context, diff --git a/packages/bun-usockets/src/internal/loop_data.h b/packages/bun-usockets/src/internal/loop_data.h index ed6ec96ad3..1f0a3adb76 100644 --- a/packages/bun-usockets/src/internal/loop_data.h +++ b/packages/bun-usockets/src/internal/loop_data.h @@ -27,6 +27,7 @@ struct us_internal_loop_data_t { int last_write_failed; struct us_socket_context_t *head; struct us_socket_context_t *iterator; + struct us_socket_context_t *closed_context_head; char *recv_buf; char *send_buf; void *ssl_data; diff --git a/packages/bun-usockets/src/internal/networking/bsd.h b/packages/bun-usockets/src/internal/networking/bsd.h index 9e9b421011..08704d9e5b 100644 --- a/packages/bun-usockets/src/internal/networking/bsd.h +++ b/packages/bun-usockets/src/internal/networking/bsd.h @@ -17,6 +17,7 @@ #ifndef BSD_H #define BSD_H +#pragma once // top-most wrapper of bsd-like syscalls @@ -25,7 +26,7 @@ #include "libusockets.h" -#ifdef _WIN32 +#ifdef _WIN32 #ifndef NOMINMAX #define NOMINMAX #endif @@ -34,7 +35,7 @@ #pragma comment(lib, "ws2_32.lib") #define SETSOCKOPT_PTR_TYPE const char * #define LIBUS_SOCKET_ERROR INVALID_SOCKET -#else +#else /* POSIX */ #ifndef _GNU_SOURCE #define _GNU_SOURCE #endif diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index 248fb16f3b..f6aad3c647 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -15,7 +15,7 @@ * limitations under the License. */ // clang-format off - +#pragma once #ifndef us_calloc #define us_calloc calloc #endif @@ -49,6 +49,7 @@ #define LIBUS_EXT_ALIGNMENT 16 #define ALLOW_SERVER_RENEGOTIATION 0 +#define LIBUS_SOCKET_CLOSE_CODE_CLEAN_SHUTDOWN 0 #define LIBUS_SOCKET_CLOSE_CODE_CONNECTION_RESET 1 /* Define what a socket descriptor is based on platform */ @@ -229,8 +230,11 @@ struct us_socket_context_t *us_create_socket_context(int ssl, struct us_loop_t * struct us_socket_context_t *us_create_bun_socket_context(int ssl, struct us_loop_t *loop, int ext_size, struct us_bun_socket_context_options_t options); -/* Delete resources allocated at creation time. */ +/* Delete resources allocated at creation time (will call unref now and only free when ref count == 0). */ void us_socket_context_free(int ssl, struct us_socket_context_t *context); +void us_socket_context_ref(int ssl, struct us_socket_context_t *context); +void us_socket_context_unref(int ssl, struct us_socket_context_t *context); + struct us_bun_verify_error_t us_socket_verify_error(int ssl, struct us_socket_t *context); /* Setters of various async callbacks */ void us_socket_context_on_open(int ssl, struct us_socket_context_t *context, diff --git a/packages/bun-usockets/src/loop.c b/packages/bun-usockets/src/loop.c index 0d1128446e..cfd414117b 100644 --- a/packages/bun-usockets/src/loop.c +++ b/packages/bun-usockets/src/loop.c @@ -47,6 +47,8 @@ void us_internal_loop_data_init(struct us_loop_t *loop, void (*wakeup_cb)(struct loop->data.parent_ptr = 0; loop->data.parent_tag = 0; + loop->data.closed_context_head = 0; + loop->data.wakeup_async = us_internal_create_async(loop, 1, 0); us_internal_async_set(loop->data.wakeup_async, (void (*)(struct us_internal_async *)) wakeup_cb); } @@ -234,6 +236,15 @@ void us_internal_free_closed_sockets(struct us_loop_t *loop) { loop->data.closed_connecting_head = 0; } +void us_internal_free_closed_contexts(struct us_loop_t *loop) { + for (struct us_socket_context_t *ctx = loop->data.closed_context_head; ctx; ) { + struct us_socket_context_t *next = ctx->next; + us_free(ctx); + ctx = next; + } + loop->data.closed_context_head = 0; +} + void sweep_timer_cb(struct us_internal_callback_t *cb) { us_internal_timer_sweep(cb->loop); } @@ -253,6 +264,7 @@ void us_internal_loop_pre(struct us_loop_t *loop) { void us_internal_loop_post(struct us_loop_t *loop) { us_internal_handle_dns_results(loop); us_internal_free_closed_sockets(loop); + us_internal_free_closed_contexts(loop); loop->data.post_cb(loop); } @@ -356,7 +368,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) s->context->loop->data.low_prio_budget--; /* Still having budget for this iteration - do normal processing */ } else { us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) & LIBUS_SOCKET_WRITABLE); - us_internal_socket_context_unlink_socket(s->context, s); + us_internal_socket_context_unlink_socket(0, s->context, s); /* Link this socket to the low-priority queue - we use a LIFO queue, to prioritize newer clients that are * maybe not already timeouted - sounds unfair, but works better in real-life with smaller client-timeouts @@ -411,7 +423,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) if (us_socket_is_shut_down(0, s)) { /* We got FIN back after sending it */ /* Todo: We should give "CLEAN SHUTDOWN" as reason here */ - s = us_socket_close(0, s, 0, NULL); + s = us_socket_close(0, s, LIBUS_SOCKET_CLOSE_CODE_CLEAN_SHUTDOWN, NULL); } else { /* We got FIN, so stop polling for readable */ us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) & LIBUS_SOCKET_WRITABLE); @@ -419,7 +431,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) } } else if (length == LIBUS_SOCKET_ERROR && !bsd_would_block()) { /* Todo: decide also here what kind of reason we should give */ - s = us_socket_close(0, s, 0, NULL); + s = us_socket_close(0, s, LIBUS_SOCKET_CLOSE_CODE_CLEAN_SHUTDOWN, NULL); return; } diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index 58aaab6d39..4fe79141da 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -137,7 +137,7 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { c->closed = 1; for (struct us_socket_t *s = c->connecting_head; s; s = s->connect_next) { - us_internal_socket_context_unlink_socket(s->context, s); + us_internal_socket_context_unlink_socket(ssl, s->context, s); us_poll_stop((struct us_poll_t *) s, s->context->loop); bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); @@ -157,6 +157,9 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { } struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, void *reason) { + if(ssl) { + return (struct us_socket_t *)us_internal_ssl_socket_close((struct us_internal_ssl_socket_t *) s, code, reason); + } if (!us_socket_is_closed(0, s)) { if (s->low_prio_state == 1) { /* Unlink this socket from the low-priority queue */ @@ -169,7 +172,7 @@ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, vo s->next = 0; s->low_prio_state = 0; } else { - us_internal_socket_context_unlink_socket(s->context, s); + us_internal_socket_context_unlink_socket(ssl, s->context, s); } #ifdef LIBUS_USE_KQUEUE // kqueue automatically removes the fd from the set on close @@ -219,7 +222,7 @@ struct us_socket_t *us_socket_detach(int ssl, struct us_socket_t *s) { s->next = 0; s->low_prio_state = 0; } else { - us_internal_socket_context_unlink_socket(s->context, s); + us_internal_socket_context_unlink_socket(ssl, s->context, s); } us_poll_stop((struct us_poll_t *) s, s->context->loop); diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index d3ba902e8e..76ae82137e 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -19,46 +19,6 @@ const BoringSSL = bun.BoringSSL; const X509 = @import("./x509.zig"); const Async = bun.Async; -// const Corker = struct { -// ptr: ?*[16384]u8 = null, -// holder: ?*anyopaque = null, -// list: bun.ByteList = .{}, - -// pub fn write(this: *Corker, owner: *anyopaque, bytes: []const u8) usize { -// if (this.holder != null and this.holder.? != owner) { -// return 0; -// } - -// this.holder = owner; -// if (this.ptr == null) { -// this.ptr = bun.default_allocator.alloc(u8, 16384) catch @panic("Out of memory allocating corker"); -// bun.assert(this.list.cap == 0); -// bun.assert(this.list.len == 0); -// this.list.cap = 16384; -// this.list.ptr = this.ptr.?; -// this.list.len = 0; -// } -// } - -// pub fn flushIfNecessary(this: *Corker, comptime ssl: bool, socket: uws.NewSocketHandler(ssl), owner: *anyopaque) void { -// if (this.holder == null or this.holder.? != owner) { -// return; -// } - -// if (this.ptr == null) { -// return; -// } - -// if (this.list.len == 0) { -// return; -// } - -// const bytes = ths.list.slice(); - -// this.list.len = 0; -// } -// }; - noinline fn getSSLException(globalThis: *JSC.JSGlobalObject, defaultMessage: []const u8) JSValue { var zig_str: ZigString = ZigString.init(""); var output_buf: [4096]u8 = undefined; @@ -822,14 +782,13 @@ pub const Listener = struct { const Socket = NewSocket(ssl); bun.assert(ssl == listener.ssl); - var this_socket = listener.handlers.vm.allocator.create(Socket) catch bun.outOfMemory(); - this_socket.* = Socket{ + var this_socket = Socket.new(.{ .handlers = &listener.handlers, .this_value = .zero, .socket = socket, .protos = listener.protos, - .owned_protos = false, - }; + .flags = .{ .owned_protos = false }, + }); if (listener.strong_data.get()) |default_data| { const globalObject = listener.handlers.globalObject; Socket.dataSetCached(this_socket.getThisValue(globalObject), globalObject, default_data); @@ -1101,42 +1060,38 @@ pub const Listener = struct { handlers_ptr.promise.set(globalObject, promise_value); if (ssl_enabled) { - var tls = handlers.vm.allocator.create(TLSSocket) catch bun.outOfMemory(); - - tls.* = .{ + var tls = TLSSocket.new(.{ .handlers = handlers_ptr, .this_value = .zero, .socket = undefined, .connection = connection, .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch unreachable) else null, .server_name = server_name, - }; + }); TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data); tls.doConnect(connection, socket_context) catch { - tls.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); + tls.handleConnectError(socket_context, @intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); return promise_value; }; tls.poll_ref.ref(handlers.vm); return promise_value; } else { - var tcp = handlers.vm.allocator.create(TCPSocket) catch bun.outOfMemory(); - - tcp.* = .{ + var tcp = TCPSocket.new(.{ .handlers = handlers_ptr, .this_value = .zero, .socket = undefined, .connection = null, .protos = null, .server_name = null, - }; + }); TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data); tcp.doConnect(connection, socket_context) catch { - tcp.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); + tcp.handleConnectError(socket_context, @intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED)); return promise_value; }; tcp.poll_ref.ref(handlers.vm); @@ -1187,27 +1142,24 @@ fn NewSocket(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); socket: Socket, - detached: bool = false, - /// Prevent onClose from calling into JavaScript while we are finalizing - finalizing: bool = false, + flags: Flags = .{}, + ref_count: u32 = 1, wrapped: WrappedType = .none, handlers: *Handlers, this_value: JSC.JSValue = .zero, poll_ref: Async.KeepAlive = Async.KeepAlive.init(), - is_active: bool = false, last_4: [4]u8 = .{ 0, 0, 0, 0 }, - authorized: bool = false, connection: ?Listener.UnixOrHost = null, protos: ?[]const u8, - owned_protos: bool = true, server_name: ?[]const u8 = null, // TODO: switch to something that uses `visitAggregate` and have the // `Listener` keep a list of all the sockets JSValue in there // This is wasteful because it means we are keeping a JSC::Weak for every single open socket has_pending_activity: std.atomic.Value(bool) = std.atomic.Value(bool).init(true), + pub usingnamespace bun.NewRefCounted(@This(), @This().deinit); const This = @This(); const log = Output.scoped(.Socket, false); @@ -1218,7 +1170,14 @@ fn NewSocket(comptime ssl: bool) type { total: usize = 0, }, }; - + const Flags = packed struct { + is_active: bool = false, + /// Prevent onClose from calling into JavaScript while we are finalizing + finalizing: bool = false, + detached: bool = true, + authorized: bool = false, + owned_protos: bool = true, + }; pub usingnamespace if (!ssl) JSC.Codegen.JSTCPSocket else @@ -1269,7 +1228,7 @@ fn NewSocket(comptime ssl: bool) type { ) void { JSC.markBinding(@src()); log("onWritable", .{}); - if (this.detached) return; + if (this.flags.detached) return; const handlers = this.handlers; const callback = handlers.onWritable; if (callback == .zero) return; @@ -1297,11 +1256,11 @@ fn NewSocket(comptime ssl: bool) type { ) void { JSC.markBinding(@src()); log("onTimeout", .{}); - if (this.detached) return; + if (this.flags.detached) return; const handlers = this.handlers; const callback = handlers.onTimeout; - if (callback == .zero or this.finalizing) return; + if (callback == .zero or this.flags.finalizing) return; if (handlers.vm.isShuttingDown()) { return; } @@ -1321,11 +1280,12 @@ fn NewSocket(comptime ssl: bool) type { _ = handlers.callErrorHandler(this_value, &[_]JSC.JSValue{ this_value, err_value }); } } - fn handleConnectError(this: *This, errno: c_int) void { + fn handleConnectError(this: *This, socket_ctx: ?*uws.SocketContext, errno: c_int) void { log("onConnectError({d})", .{errno}); - if (this.detached) return; - this.detached = true; - defer this.markInactive(); + this.flags.detached = true; + defer this.deref(); + + defer this.markInactive(socket_ctx); const handlers = this.handlers; const vm = handlers.vm; @@ -1380,35 +1340,34 @@ fn NewSocket(comptime ssl: bool) type { this.has_pending_activity.store(false, .release); } } - pub fn onConnectError(this: *This, _: Socket, errno: c_int) void { + pub fn onConnectError(this: *This, socket: Socket, errno: c_int) void { JSC.markBinding(@src()); - this.handleConnectError(errno); + this.handleConnectError(socket.context(), errno); } pub fn markActive(this: *This) void { - if (!this.is_active) { + if (!this.flags.is_active) { this.handlers.markActive(); - this.is_active = true; + this.flags.is_active = true; this.has_pending_activity.store(true, .release); } } - pub fn markInactive(this: *This) void { - if (this.is_active) { - if (!this.detached) { + pub fn markInactive(this: *This, socket_ctx: ?*uws.SocketContext) void { + if (this.flags.is_active) { + if (!this.flags.detached) { // we have to close the socket before the socket context is closed // otherwise we will get a segfault // uSockets will defer freeing the TCP socket until the next tick if (!this.socket.isClosed()) { - this.detached = true; this.socket.close(.normal); // onClose will call markInactive again return; } } - this.is_active = false; + this.flags.is_active = false; const vm = this.handlers.vm; - this.handlers.markInactive(ssl, this.socket.context(), this.wrapped); + this.handlers.markInactive(ssl, socket_ctx, this.wrapped); this.poll_ref.unref(vm); this.has_pending_activity.store(false, .release); } @@ -1452,8 +1411,9 @@ fn NewSocket(comptime ssl: bool) type { } } - this.detached = false; + this.flags.detached = false; this.socket = socket; + this.ref(); if (this.wrapped == .none) { socket.ext(**anyopaque).* = bun.cast(**anyopaque, this); @@ -1485,8 +1445,7 @@ fn NewSocket(comptime ssl: bool) type { }); if (result.toError()) |err| { - this.detached = true; - defer this.markInactive(); + defer this.markInactive(socket.context()); if (!this.socket.isClosed()) { log("Closing due to error", .{}); } else { @@ -1512,7 +1471,7 @@ fn NewSocket(comptime ssl: bool) type { pub fn onEnd(this: *This, socket: Socket) void { JSC.markBinding(@src()); log("onEnd", .{}); - if (this.detached) return; + if (this.flags.detached) return; const handlers = this.handlers; @@ -1521,7 +1480,7 @@ fn NewSocket(comptime ssl: bool) type { this.poll_ref.unref(handlers.vm); // If you don't handle TCP fin, we assume you're done. - this.markInactive(); + this.markInactive(this.socket.context()); return; } @@ -1544,10 +1503,10 @@ fn NewSocket(comptime ssl: bool) type { pub fn onHandshake(this: *This, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { log("onHandshake({d})", .{success}); JSC.markBinding(@src()); - if (this.detached) return; + if (this.flags.detached) return; const authorized = if (success == 1) true else false; - this.authorized = authorized; + this.flags.authorized = authorized; const handlers = this.handlers; var callback = handlers.onHandshake; @@ -1620,10 +1579,11 @@ fn NewSocket(comptime ssl: bool) type { pub fn onClose(this: *This, socket: Socket, err: c_int, _: ?*anyopaque) void { JSC.markBinding(@src()); log("onClose", .{}); - this.detached = true; - defer this.markInactive(); + defer this.deref(); + this.flags.detached = true; + defer this.markInactive(socket.context()); - if (this.finalizing) { + if (this.flags.finalizing) { return; } @@ -1660,11 +1620,11 @@ fn NewSocket(comptime ssl: bool) type { pub fn onData(this: *This, socket: Socket, data: []const u8) void { JSC.markBinding(@src()); log("onData({d})", .{data.len}); - if (this.detached) return; + if (this.flags.detached) return; const handlers = this.handlers; const callback = handlers.onData; - if (callback == .zero or this.finalizing) return; + if (callback == .zero or this.flags.finalizing) return; if (handlers.vm.isShuttingDown()) { return; } @@ -1711,7 +1671,7 @@ fn NewSocket(comptime ssl: bool) type { this: *This, _: *JSC.JSGlobalObject, ) JSValue { - if (!this.handlers.is_server or this.detached) { + if (!this.handlers.is_server or this.flags.detached) { return JSValue.jsUndefined(); } @@ -1725,7 +1685,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { log("getReadyState()", .{}); - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNumber(@as(i32, -1)); } else if (this.socket.isClosed()) { return JSValue.jsNumber(@as(i32, 0)); @@ -1743,7 +1703,7 @@ fn NewSocket(comptime ssl: bool) type { _: *JSC.JSGlobalObject, ) JSValue { log("getAuthorized()", .{}); - return JSValue.jsBoolean(this.authorized); + return JSValue.jsBoolean(this.flags.authorized); } pub fn timeout( this: *This, @@ -1752,7 +1712,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); const args = callframe.arguments(1); - if (this.detached) return JSValue.jsUndefined(); + if (this.flags.detached) return JSValue.jsUndefined(); if (args.len == 0) { globalObject.throw("Expected 1 argument, got 0", .{}); return .zero; @@ -1775,7 +1735,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNull(); } @@ -1805,7 +1765,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNumber(@as(i32, -1)); } @@ -1826,7 +1786,7 @@ fn NewSocket(comptime ssl: bool) type { this: *This, _: *JSC.JSGlobalObject, ) JSValue { - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -1837,7 +1797,7 @@ fn NewSocket(comptime ssl: bool) type { this: *This, globalThis: *JSC.JSGlobalObject, ) JSValue { - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -1858,7 +1818,7 @@ fn NewSocket(comptime ssl: bool) type { } fn writeMaybeCorked(this: *This, buffer: []const u8, is_end: bool) i32 { - if (this.detached or this.socket.isShutdown() or this.socket.isClosed()) { + if (this.flags.detached or this.socket.isShutdown() or this.socket.isClosed()) { return -1; } // we don't cork yet but we might later @@ -2028,7 +1988,7 @@ fn NewSocket(comptime ssl: bool) type { _: *JSC.CallFrame, ) JSValue { JSC.markBinding(@src()); - if (!this.detached) + if (!this.flags.detached) this.socket.flush(); return JSValue.jsUndefined(); @@ -2040,7 +2000,7 @@ fn NewSocket(comptime ssl: bool) type { _: *JSC.CallFrame, ) JSValue { JSC.markBinding(@src()); - if (!this.detached) { + if (!this.flags.detached) { this.socket.close(.failure); } @@ -2054,7 +2014,7 @@ fn NewSocket(comptime ssl: bool) type { ) JSValue { JSC.markBinding(@src()); const args = callframe.arguments(1); - if (!this.detached) { + if (!this.flags.detached) { if (args.len > 0 and args.ptr[0].toBoolean()) { this.socket.shutdownRead(); } else { @@ -2076,7 +2036,7 @@ fn NewSocket(comptime ssl: bool) type { log("end({d} args)", .{args.len}); - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNumber(@as(i32, -1)); } @@ -2086,41 +2046,32 @@ fn NewSocket(comptime ssl: bool) type { if (result.wrote == result.total) { this.socket.flush(); // markInactive does .detached = true - this.markInactive(); + this.markInactive(this.socket.context()); } break :brk JSValue.jsNumber(result.wrote); }, }; } - pub fn ref(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) JSValue { + pub fn jsRef(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) JSValue { JSC.markBinding(@src()); - if (this.detached) return JSValue.jsUndefined(); + if (this.flags.detached) return JSValue.jsUndefined(); this.poll_ref.ref(globalObject.bunVM()); return JSValue.jsUndefined(); } - pub fn unref(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) JSValue { + pub fn jsUnref(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) JSValue { JSC.markBinding(@src()); this.poll_ref.unref(globalObject.bunVM()); return JSValue.jsUndefined(); } - pub fn finalize(this: *This) void { - log("finalize() {d}", .{@intFromPtr(this)}); - this.finalizing = true; - if (!this.detached) { - this.detached = true; - if (!this.socket.isClosed()) { - this.socket.close(.failure); - } - } - - this.markInactive(); + pub fn deinit(this: *This) void { + this.markInactive(null); this.poll_ref.unref(JSC.VirtualMachine.get()); // need to deinit event without being attached - if (this.owned_protos) { + if (this.flags.owned_protos) { if (this.protos) |protos| { this.protos = null; default_allocator.free(protos); @@ -2136,6 +2087,19 @@ fn NewSocket(comptime ssl: bool) type { this.connection = null; connection.deinit(); } + this.destroy(); + } + + pub fn finalize(this: *This) void { + log("finalize() {d}", .{@intFromPtr(this)}); + this.flags.finalizing = true; + if (!this.flags.detached) { + if (!this.socket.isClosed()) { + this.socket.close(.failure); + } + } + + this.deref(); } pub fn reload(this: *This, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) JSValue { @@ -2146,7 +2110,7 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2184,7 +2148,7 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2201,7 +2165,7 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2242,7 +2206,7 @@ fn NewSocket(comptime ssl: bool) type { if (comptime ssl == false) { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2263,7 +2227,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2290,7 +2254,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2336,7 +2300,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2364,7 +2328,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsBoolean(false); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsBoolean(false); } @@ -2396,7 +2360,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2484,7 +2448,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsNull(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNull(); } @@ -2553,7 +2517,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } var result = JSValue.createEmptyObject(globalObject, 3); @@ -2600,7 +2564,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2632,7 +2596,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2665,7 +2629,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsNull(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNull(); } const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(this.socket.getNativeHandle())); @@ -2756,7 +2720,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsNull(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsNull(); } @@ -2779,7 +2743,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsBoolean(false); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsBoolean(false); } @@ -2818,7 +2782,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2872,7 +2836,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -2935,7 +2899,7 @@ fn NewSocket(comptime ssl: bool) type { this.server_name = slice; } - if (this.detached) { + if (this.flags.detached) { // will be attached onOpen return JSValue.jsUndefined(); } @@ -2970,7 +2934,7 @@ fn NewSocket(comptime ssl: bool) type { return JSValue.jsUndefined(); } - if (this.detached) { + if (this.flags.detached) { return JSValue.jsUndefined(); } @@ -3036,13 +3000,12 @@ fn NewSocket(comptime ssl: bool) type { const ext_size = @sizeOf(WrappedSocket); const is_server = this.handlers.is_server; - var tls = handlers.vm.allocator.create(TLSSocket) catch bun.outOfMemory(); + var handlers_ptr = handlers.vm.allocator.create(Handlers) catch bun.outOfMemory(); handlers_ptr.* = handlers; handlers_ptr.is_server = is_server; handlers_ptr.protect(); - - tls.* = .{ + var tls = TLSSocket.new(.{ .handlers = handlers_ptr, .this_value = .zero, .socket = undefined, @@ -3050,7 +3013,7 @@ fn NewSocket(comptime ssl: bool) type { .wrapped = .tls, .protos = if (protos) |p| (bun.default_allocator.dupe(u8, p[0..protos_len]) catch unreachable) else null, .server_name = if (socket_config.server_name) |server_name| (bun.default_allocator.dupe(u8, server_name[0..bun.len(server_name)]) catch unreachable) else null, - }; + }); const tls_js_value = tls.getThisValue(globalObject); TLSSocket.dataSetCached(tls_js_value, globalObject, default_data); @@ -3076,7 +3039,6 @@ fn NewSocket(comptime ssl: bool) type { tls.socket = new_socket; - var raw = handlers.vm.allocator.create(TLSSocket) catch bun.outOfMemory(); var raw_handlers_ptr = handlers.vm.allocator.create(Handlers) catch bun.outOfMemory(); raw_handlers_ptr.* = .{ .vm = globalObject.bunVM(), @@ -3094,15 +3056,16 @@ fn NewSocket(comptime ssl: bool) type { .is_server = is_server, }; - raw.* = .{ + raw_handlers_ptr.protect(); + + var raw = TLSSocket.new(.{ .handlers = raw_handlers_ptr, .this_value = .zero, .socket = new_socket, .connection = if (this.connection) |c| c.clone() else null, .wrapped = .tcp, .protos = null, - }; - raw_handlers_ptr.protect(); + }); const raw_js_value = raw.getThisValue(globalObject); if (JSSocketType(ssl).dataGetCached(this.getThisValue(globalObject))) |raw_default_data| { @@ -3125,10 +3088,10 @@ fn NewSocket(comptime ssl: bool) type { new_socket.startTLS(!this.handlers.is_server); //detach and invalidate the old instance - this.detached = true; - if (this.is_active) { + this.flags.detached = true; + if (this.flags.is_active) { const vm = this.handlers.vm; - this.is_active = false; + this.flags.is_active = false; // will free handlers and the old_context when hits 0 active connections // the connection can be upgraded inside a handler call so we need to garantee that it will be still alive this.handlers.markInactive(ssl, old_context, this.wrapped); diff --git a/src/bun.js/api/sockets.classes.ts b/src/bun.js/api/sockets.classes.ts index d36ca4629c..dc2f4b39c8 100644 --- a/src/bun.js/api/sockets.classes.ts +++ b/src/bun.js/api/sockets.classes.ts @@ -126,11 +126,11 @@ function generate(ssl) { }, ref: { - fn: "ref", + fn: "jsRef", length: 0, }, unref: { - fn: "unref", + fn: "jsUnref", length: 0, }, diff --git a/src/bun.js/web_worker.zig b/src/bun.js/web_worker.zig index ae0a6ac0f0..f07dd0edbd 100644 --- a/src/bun.js/web_worker.zig +++ b/src/bun.js/web_worker.zig @@ -23,7 +23,7 @@ pub const WebWorker = struct { /// Already resolved. specifier: []const u8 = "", store_fd: bool = false, - arena: bun.MimallocArena = undefined, + arena: ?bun.MimallocArena = null, name: [:0]const u8 = "Worker", cpp_worker: *anyopaque, mini: bool = false, @@ -177,7 +177,7 @@ pub const WebWorker = struct { } if (this.hasRequestedTerminate()) { - this.deinit(); + this.exitAndDeinit(); return; } @@ -186,13 +186,13 @@ pub const WebWorker = struct { this.arena = try bun.MimallocArena.init(); var vm = try JSC.VirtualMachine.initWorker(this, .{ - .allocator = this.arena.allocator(), + .allocator = this.arena.?.allocator(), .args = this.parent.bundler.options.transform_options, .store_fd = this.store_fd, .graph = this.parent.standalone_module_graph, }); - vm.allocator = this.arena.allocator(); - vm.arena = &this.arena; + vm.allocator = this.arena.?.allocator(); + vm.arena = &this.arena.?; var b = &vm.bundler; @@ -427,8 +427,9 @@ pub const WebWorker = struct { if (vm_to_deinit) |vm| { vm.deinit(); // NOTE: deinit here isn't implemented, so freeing workers will leak the vm. } - - arena.deinit(); + if (arena) |*arena_| { + arena_.deinit(); + } bun.exitThread(); } diff --git a/src/cli/create_command.zig b/src/cli/create_command.zig index 870360f506..7b7bf50c7c 100644 --- a/src/cli/create_command.zig +++ b/src/cli/create_command.zig @@ -1981,7 +1981,7 @@ pub const Example = struct { HTTP.FetchRedirect.follow, ); async_http.client.progress_node = progress; - async_http.client.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); const response = try async_http.sendSync(true); @@ -2058,7 +2058,7 @@ pub const Example = struct { HTTP.FetchRedirect.follow, ); async_http.client.progress_node = progress; - async_http.client.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); var response = try async_http.sendSync(true); @@ -2147,7 +2147,7 @@ pub const Example = struct { HTTP.FetchRedirect.follow, ); async_http.client.progress_node = progress; - async_http.client.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); refresher.maybeRefresh(); @@ -2188,7 +2188,7 @@ pub const Example = struct { null, HTTP.FetchRedirect.follow, ); - async_http.client.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); if (Output.enable_ansi_colors) { async_http.client.progress_node = progress_node; diff --git a/src/cli/upgrade_command.zig b/src/cli/upgrade_command.zig index 0c5ba5c28e..c89b011c19 100644 --- a/src/cli/upgrade_command.zig +++ b/src/cli/upgrade_command.zig @@ -248,7 +248,7 @@ pub const UpgradeCommand = struct { null, HTTP.FetchRedirect.follow, ); - async_http.client.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); if (!silent) async_http.client.progress_node = progress.?; const response = try async_http.sendSync(true); @@ -531,7 +531,7 @@ pub const UpgradeCommand = struct { HTTP.FetchRedirect.follow, ); async_http.client.progress_node = progress; - async_http.client.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env_loader.getTLSRejectUnauthorized(); const response = try async_http.sendSync(true); diff --git a/src/compile_target.zig b/src/compile_target.zig index e3330614e7..67d0c0aab1 100644 --- a/src/compile_target.zig +++ b/src/compile_target.zig @@ -168,7 +168,7 @@ pub fn downloadToPath(this: *const CompileTarget, env: *bun.DotEnv.Loader, alloc HTTP.FetchRedirect.follow, ); async_http.client.progress_node = progress; - async_http.client.reject_unauthorized = env.getTLSRejectUnauthorized(); + async_http.client.flags.reject_unauthorized = env.getTLSRejectUnauthorized(); const response = try async_http.sendSync(true); diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 4878d8dfad..b6e80dd4b3 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -35,6 +35,7 @@ pub const InternalLoopData = extern struct { last_write_failed: i32, head: ?*SocketContext, iterator: ?*SocketContext, + closed_context_head: ?*SocketContext, recv_buf: [*]u8, send_buf: [*]u8, ssl_data: ?*anyopaque, @@ -964,8 +965,8 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { // We close immediately in this case // uSockets doesn't know if this is a TLS socket or not. - // So we have to do that logic in here. - ThisSocket.from(socket).close(.failure); + // So we need to close it like a TCP socket. + NewSocketHandler(false).from(socket).close(.failure); Fields.onConnectError( val, @@ -1406,6 +1407,8 @@ pub extern fn us_create_socket_context(ssl: i32, loop: ?*Loop, ext_size: i32, op pub extern fn us_create_bun_socket_context(ssl: i32, loop: ?*Loop, ext_size: i32, options: us_bun_socket_context_options_t) ?*SocketContext; pub extern fn us_bun_socket_context_add_server_name(ssl: i32, context: ?*SocketContext, hostname_pattern: [*c]const u8, options: us_bun_socket_context_options_t, ?*anyopaque) void; pub extern fn us_socket_context_free(ssl: i32, context: ?*SocketContext) void; +pub extern fn us_socket_context_ref(ssl: i32, context: ?*SocketContext) void; +pub extern fn us_socket_context_unref(ssl: i32, context: ?*SocketContext) void; extern fn us_socket_context_on_open(ssl: i32, context: ?*SocketContext, on_open: *const fn (*Socket, i32, [*c]u8, i32) callconv(.C) ?*Socket) void; extern fn us_socket_context_on_close(ssl: i32, context: ?*SocketContext, on_close: *const fn (*Socket, i32, ?*anyopaque) callconv(.C) ?*Socket) void; extern fn us_socket_context_on_data(ssl: i32, context: ?*SocketContext, on_data: *const fn (*Socket, [*c]u8, i32) callconv(.C) ?*Socket) void; diff --git a/src/http.zig b/src/http.zig index 5e461392eb..e1e6a35d9e 100644 --- a/src/http.zig +++ b/src/http.zig @@ -515,17 +515,17 @@ fn NewHTTPContext(comptime ssl: bool) type { const active = getTagged(ptr); if (active.get(HTTPClient)) |client| { - if (handshake_error.error_no != 0 and (client.reject_unauthorized or !authorized)) { + if (handshake_error.error_no != 0 and (client.flags.reject_unauthorized or !authorized)) { client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket); return; } // no handshake_error at this point if (authorized) { - client.did_have_handshaking_error = handshake_error.error_no != 0; + client.flags.did_have_handshaking_error = handshake_error.error_no != 0; // if checkServerIdentity returns false, we dont call open this means that the connection was rejected if (!client.checkServerIdentity(comptime ssl, socket, handshake_error)) { - client.did_have_handshaking_error = true; + client.flags.did_have_handshaking_error = true; if (!socket.isClosed()) terminateSocket(socket); return; @@ -737,7 +737,7 @@ fn NewHTTPContext(comptime ssl: bool) type { client.connected_url.hostname = hostname; if (client.isKeepAlivePossible()) { - if (this.existingSocket(client.reject_unauthorized, hostname, port)) |sock| { + if (this.existingSocket(client.flags.reject_unauthorized, hostname, port)) |sock| { sock.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr()); client.allow_retry = true; client.onOpen(comptime ssl, sock); @@ -879,7 +879,7 @@ pub const HTTPThread = struct { try custom_ssl_context_map.put(requested_config, custom_context); // We might deinit the socket context, so we disable keepalive to make sure we don't // free it while in use. - client.disable_keepalive = true; + client.flags.disable_keepalive = true; return try custom_context.connect(client, client.url.hostname, client.url.getPortAuto()); } } @@ -901,12 +901,11 @@ pub const HTTPThread = struct { if (socket_async_http_abort_tracker.fetchSwapRemove(http.async_http_id)) |socket_ptr| { if (http.is_tls) { const socket = uws.SocketTLS.fromAny(socket_ptr.value); - socket.shutdown(); - socket.shutdownRead(); + // do a fast shutdown here since we are aborting and we dont want to wait for the close_notify from the other side + socket.close(.failure); } else { const socket = uws.SocketTCP.fromAny(socket_ptr.value); - socket.shutdown(); - socket.shutdownRead(); + socket.close(.failure); } } } @@ -1018,7 +1017,7 @@ pub fn checkServerIdentity( if (comptime is_ssl == false) { @panic("checkServerIdentity called on non-ssl socket"); } - if (client.reject_unauthorized) { + if (client.flags.reject_unauthorized) { const ssl_ptr = @as(*BoringSSL.SSL, @ptrCast(socket.getNativeHandle())); if (BoringSSL.SSL_get_peer_cert_chain(ssl_ptr)) |cert_chain| { if (BoringSSL.sk_X509_value(cert_chain, 0)) |x509| { @@ -1133,7 +1132,7 @@ pub fn firstCall( socket: NewHTTPContext(is_ssl).HTTPSocket, ) void { if (comptime FeatureFlags.is_fetch_preconnect_supported) { - if (client.is_preconnect_only) { + if (client.flags.is_preconnect_only) { client.onPreconnect(is_ssl, socket); return; } @@ -1160,14 +1159,14 @@ pub fn onClose( if (picohttp.phr_decode_chunked_is_in_data(&client.state.chunked_decoder) == 0) { const buf = client.state.getBodyBuffer(); if (buf.list.items.len > 0) { - client.state.received_last_chunk = true; + client.state.flags.received_last_chunk = true; client.progressUpdate(comptime is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); return; } } } else if (client.state.content_length == null and client.state.response_stage == .body) { // no content length informed so we are done here - client.state.received_last_chunk = true; + client.state.flags.received_last_chunk = true; client.progressUpdate(comptime is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); return; } @@ -1188,7 +1187,7 @@ pub fn onTimeout( comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, ) void { - if (client.disable_timeout) return; + if (client.flags.disable_timeout) return; log("Timeout {s}\n", .{client.url.href}); defer NewHTTPContext(is_ssl).terminateSocket(socket); @@ -1447,13 +1446,8 @@ pub const InternalState = struct { /// this can happen after await fetch(...) and the body can continue streaming when this is already null /// the user will receive only chunks of the body stored in body_out_str cloned_metadata: ?HTTPResponseMetadata = null, + flags: InternalStateFlags = InternalStateFlags{}, - allow_keepalive: bool = true, - received_last_chunk: bool = false, - did_set_content_encoding: bool = false, - is_redirect_pending: bool = false, - is_libdeflate_fast_path_disabled: bool = false, - resend_request_body_on_redirect: bool = false, transfer_encoding: Encoding = Encoding.identity, encoding: Encoding = Encoding.identity, content_encoding_i: u8 = std.math.maxInt(u8), @@ -1473,6 +1467,15 @@ pub const InternalState = struct { response_stage: HTTPStage = .pending, certificate_info: ?CertificateInfo = null, + pub const InternalStateFlags = packed struct { + allow_keepalive: bool = true, + received_last_chunk: bool = false, + did_set_content_encoding: bool = false, + is_redirect_pending: bool = false, + is_libdeflate_fast_path_disabled: bool = false, + resend_request_body_on_redirect: bool = false, + }; + pub fn init(body: HTTPRequestBody, body_out_str: *MutableString) InternalState { return .{ .original_request_body = body, @@ -1516,7 +1519,7 @@ pub const InternalState = struct { .original_request_body = .{ .bytes = "" }, .request_body = "", .certificate_info = null, - .is_redirect_pending = false, + .flags = .{}, }; } @@ -1530,7 +1533,7 @@ pub const InternalState = struct { fn isDone(this: *InternalState) bool { if (this.isChunkedEncoding()) { - return this.received_last_chunk; + return this.flags.received_last_chunk; } if (this.content_length) |content_length| { @@ -1538,7 +1541,7 @@ pub const InternalState = struct { } // Content-Type: text/event-stream we should be done only when Close/End/Timeout connection - return this.received_last_chunk; + return this.flags.received_last_chunk; } fn decompressBytes(this: *InternalState, buffer: []const u8, body_out_str: *MutableString, is_final_chunk: bool) !void { @@ -1552,8 +1555,8 @@ pub const InternalState = struct { if (FeatureFlags.isLibdeflateEnabled()) { // Fast-path: use libdeflate - if (is_final_chunk and !this.is_libdeflate_fast_path_disabled and this.encoding.canUseLibDeflate() and this.isDone()) libdeflate: { - this.is_libdeflate_fast_path_disabled = true; + if (is_final_chunk and !this.flags.is_libdeflate_fast_path_disabled and this.encoding.canUseLibDeflate() and this.isDone()) libdeflate: { + this.flags.is_libdeflate_fast_path_disabled = true; log("Decompressing {d} bytes with libdeflate\n", .{buffer.len}); var deflater = http_thread.deflater(); @@ -1621,7 +1624,7 @@ pub const InternalState = struct { } pub fn processBodyBuffer(this: *InternalState, buffer: MutableString, is_final_chunk: bool) !bool { - if (this.is_redirect_pending) return false; + if (this.flags.is_redirect_pending) return false; var body_out_str = this.body_out_str.?; @@ -1652,6 +1655,18 @@ pub const HTTPVerboseLevel = enum { curl, }; +pub const Flags = packed struct { + disable_timeout: bool = false, + disable_keepalive: bool = false, + disable_decompression: bool = false, + did_have_handshaking_error: bool = false, + force_last_modified: bool = false, + redirected: bool = false, + proxy_tunneling: bool = false, + reject_unauthorized: bool = true, + is_preconnect_only: bool = false, +}; + // TODO: reduce the size of this struct // Many of these fields can be moved to a packed struct and use less space method: Method, @@ -1666,31 +1681,25 @@ allow_retry: bool = false, redirect_type: FetchRedirect = FetchRedirect.follow, redirect: []u8 = &.{}, progress_node: ?*Progress.Node = null, -disable_timeout: bool = false, -disable_keepalive: bool = false, -disable_decompression: bool = false, -state: InternalState = .{}, -did_have_handshaking_error: bool = false, +flags: Flags = Flags{}, + +state: InternalState = .{}, tls_props: ?*SSLConfig = null, result_callback: HTTPClientResult.Callback = undefined, /// Some HTTP servers (such as npm) report Last-Modified times but ignore If-Modified-Since. /// This is a workaround for that. -force_last_modified: bool = false, if_modified_since: string = "", request_content_len_buf: ["-4294967295".len]u8 = undefined, http_proxy: ?URL = null, proxy_authorization: ?[]u8 = null, -proxy_tunneling: bool = false, proxy_tunnel: ?ProxyTunnel = null, signals: Signals = .{}, async_http_id: u32 = 0, hostname: ?[]u8 = null, -reject_unauthorized: bool = true, unix_socket_path: JSC.ZigString.Slice = JSC.ZigString.Slice.empty, -is_preconnect_only: bool = false, pub fn deinit(this: *HTTPClient) void { if (this.redirect.len > 0) { @@ -1719,7 +1728,7 @@ pub fn isKeepAlivePossible(this: *HTTPClient) bool { } //check state - if (this.state.allow_keepalive and !this.disable_keepalive) return true; + if (this.state.flags.allow_keepalive and !this.flags.disable_keepalive) return true; } return false; } @@ -1976,7 +1985,7 @@ pub const AsyncHTTP = struct { }); this.async_http = AsyncHTTP.init(bun.default_allocator, .GET, url, .{}, "", &this.response_buffer, "", HTTPClientResult.Callback.New(*Preconnect, Preconnect.onResult).init(this), .manual, .{}); - this.async_http.client.is_preconnect_only = true; + this.async_http.client.flags.is_preconnect_only = true; http_thread.schedule(Batch.from(&this.async_http.task)); } @@ -2024,19 +2033,19 @@ pub const AsyncHTTP = struct { this.client.unix_socket_path = val; } if (options.disable_timeout) |val| { - this.client.disable_timeout = val; + this.client.flags.disable_timeout = val; } if (options.verbose) |val| { this.client.verbose = val; } if (options.disable_decompression) |val| { - this.client.disable_decompression = val; + this.client.flags.disable_decompression = val; } if (options.disable_keepalive) |val| { - this.client.disable_keepalive = val; + this.client.flags.disable_keepalive = val; } if (options.reject_unauthorized) |val| { - this.client.reject_unauthorized = val; + this.client.flags.reject_unauthorized = val; } if (options.tls_props) |val| { this.client.tls_props = val; @@ -2115,7 +2124,7 @@ pub const AsyncHTTP = struct { if (this.http_proxy) |proxy| { //TODO: need to understand how is possible to reuse Proxy with TSL, so disable keepalive if url is HTTPS - this.client.disable_keepalive = this.url.isHTTPS(); + this.client.flags.disable_keepalive = this.url.isHTTPS(); // Username between 0 and 4096 chars if (proxy.username.len > 0 and proxy.username.len < 4096) { // Password between 0 and 4096 chars @@ -2216,7 +2225,7 @@ pub const AsyncHTTP = struct { // TODO: this condition seems wrong: if we started with a non-default value, we might // report a redirect even if none happened - this.redirected = this.client.remaining_redirect_count != default_redirect_count; + this.redirected = this.client.flags.redirected; if (result.isSuccess()) { this.err = null; if (result.metadata) |metadata| { @@ -2306,7 +2315,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { hashHeaderConst("Content-Length"), => continue, hashHeaderConst("if-modified-since") => { - if (this.force_last_modified and this.if_modified_since.len == 0) { + if (this.flags.force_last_modified and this.if_modified_since.len == 0) { this.if_modified_since = this.headerStr(header_values[i]); } }, @@ -2367,7 +2376,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { header_count += 1; } - if (!override_accept_encoding and !this.disable_decompression) { + if (!override_accept_encoding and !this.flags.disable_decompression) { request_headers_buf[header_count] = accept_encoding_header; header_count += 1; @@ -2397,30 +2406,31 @@ pub fn doRedirect( ) void { this.unix_socket_path.deinit(); this.unix_socket_path = JSC.ZigString.Slice.empty; - const request_body = if (this.state.resend_request_body_on_redirect and this.state.original_request_body == .bytes) + const request_body = if (this.state.flags.resend_request_body_on_redirect and this.state.original_request_body == .bytes) this.state.original_request_body.bytes else ""; this.state.response_message_buffer.deinit(); + const body_out_str = this.state.body_out_str.?; + this.remaining_redirect_count -|= 1; + this.flags.redirected = true; + assert(this.redirect_type == FetchRedirect.follow); + // we need to clean the client reference before closing the socket because we are going to reuse the same ref in a another request if (this.isKeepAlivePossible()) { assert(this.connected_url.hostname.len > 0); ctx.releaseSocket( socket, - this.did_have_handshaking_error and !this.reject_unauthorized, + this.flags.did_have_handshaking_error and !this.flags.reject_unauthorized, this.connected_url.hostname, this.connected_url.getPortAuto(), ); } else { NewHTTPContext(is_ssl).closeSocket(socket); } - this.connected_url = URL{}; - const body_out_str = this.state.body_out_str.?; - this.remaining_redirect_count -|= 1; - assert(this.redirect_type == FetchRedirect.follow); // TODO: should this check be before decrementing the redirect count? // the current logic will allow one less redirect than requested @@ -2430,7 +2440,7 @@ pub fn doRedirect( } this.state.reset(this.allocator); // also reset proxy to redirect - this.proxy_tunneling = false; + this.flags.proxy_tunneling = false; if (this.proxy_tunnel != null) { var tunnel = this.proxy_tunnel.?; tunnel.deinit(); @@ -2537,7 +2547,7 @@ pub fn onPreconnect(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCon const ctx = if (comptime is_ssl) &http_thread.https_context else &http_thread.http_context; ctx.releaseSocket( socket, - this.did_have_handshaking_error and !this.reject_unauthorized, + this.flags.did_have_handshaking_error and !this.flags.reject_unauthorized, this.url.hostname, this.url.getPortAuto(), ); @@ -2546,7 +2556,7 @@ pub fn onPreconnect(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCon this.state.response_stage = .done; this.state.request_stage = .done; this.state.stage = .done; - this.proxy_tunneling = false; + this.flags.proxy_tunneling = false; this.result_callback.run(@fieldParentPtr("client", this), HTTPClientResult{ .fail = null, .metadata = null, .has_more = false }); } @@ -2557,7 +2567,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s } if (comptime FeatureFlags.is_fetch_preconnect_supported) { - if (this.is_preconnect_only) { + if (this.flags.is_preconnect_only) { this.onPreconnect(is_ssl, socket); return; } @@ -2579,7 +2589,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s if (this.url.isHTTPS()) { //DO the tunneling! - this.proxy_tunneling = true; + this.flags.proxy_tunneling = true; writeProxyConnect(@TypeOf(writer), writer, this) catch { this.closeAndFail(error.OutOfMemory, is_ssl, socket); return; @@ -2610,7 +2620,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s const headers_len = list.items.len; assert(list.items.len == writer.context.items.len); - if (this.state.request_body.len > 0 and list.capacity - list.items.len > 0 and !this.proxy_tunneling) { + if (this.state.request_body.len > 0 and list.capacity - list.items.len > 0 and !this.flags.proxy_tunneling) { var remain = list.items.ptr[list.items.len..list.capacity]; const wrote = @min(remain.len, this.state.request_body.len); assert(wrote > 0); @@ -2643,7 +2653,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s const has_sent_headers = this.state.request_sent_len >= headers_len; if (has_sent_headers and this.verbose != .none) { - printRequest(request, this.url.href, !this.reject_unauthorized, this.state.request_body, this.verbose == .curl); + printRequest(request, this.url.href, !this.flags.reject_unauthorized, this.state.request_body, this.verbose == .curl); } if (has_sent_headers and this.state.request_body.len > 0) { @@ -2661,7 +2671,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s } if (has_sent_headers) { - if (this.proxy_tunneling) { + if (this.flags.proxy_tunneling) { this.state.request_stage = .proxy_handshake; } else { this.state.request_stage = .body; @@ -2982,10 +2992,10 @@ pub fn onData(this: *HTTPClient, comptime is_ssl: bool, incoming_data: []const u return; }; - if (this.state.content_encoding_i < response.headers.len and !this.state.did_set_content_encoding) { + if (this.state.content_encoding_i < response.headers.len and !this.state.flags.did_set_content_encoding) { // if it compressed with this header, it is no longer because we will decompress it const mutable_headers = std.ArrayListUnmanaged(picohttp.Header){ .items = response.headers, .capacity = response.headers.len }; - this.state.did_set_content_encoding = true; + this.state.flags.did_set_content_encoding = true; response.headers = mutable_headers.items; this.state.content_encoding_i = std.math.maxInt(@TypeOf(this.state.content_encoding_i)); // we need to reset the pending response because we removed a header @@ -2993,7 +3003,7 @@ pub fn onData(this: *HTTPClient, comptime is_ssl: bool, incoming_data: []const u } if (should_continue == .finished) { - if (this.state.is_redirect_pending) { + if (this.state.flags.is_redirect_pending) { this.doRedirect(is_ssl, ctx, socket); return; } @@ -3001,14 +3011,14 @@ pub fn onData(this: *HTTPClient, comptime is_ssl: bool, incoming_data: []const u // clone metadata and return the progress at this point this.cloneMetadata(); // if is chuncked but no body is expected we mark the last chunk - this.state.received_last_chunk = true; + this.state.flags.received_last_chunk = true; // if is not we ignore the content_length this.state.content_length = 0; this.progressUpdate(is_ssl, ctx, socket); return; } - if (this.proxy_tunneling and this.proxy_tunnel == null) { + if (this.flags.proxy_tunneling and this.proxy_tunnel == null) { // we are proxing we dont need to cloneMetadata yet this.startProxyHandshake(is_ssl, socket); return; @@ -3143,7 +3153,7 @@ pub fn onData(this: *HTTPClient, comptime is_ssl: bool, incoming_data: []const u defer data.deinit(); const decoded_data = data.slice(); if (decoded_data.len == 0) return; - this.proxy_tunneling = false; + this.flags.proxy_tunneling = false; this.state.response_stage = .proxy_decoded_headers; //actual do the header parsing! this.onData(is_ssl, decoded_data, ctx, socket); @@ -3184,7 +3194,7 @@ fn fail(this: *HTTPClient, err: anyerror) void { const callback = this.result_callback; const result = this.toResult(); this.state.reset(this.allocator); - this.proxy_tunneling = false; + this.flags.proxy_tunneling = false; callback.run(@fieldParentPtr("client", this), result); } @@ -3223,7 +3233,7 @@ fn cloneMetadata(this: *HTTPClient) void { } pub fn setTimeout(this: *HTTPClient, socket: anytype, minutes: c_uint) void { - if (this.disable_timeout) { + if (this.flags.disable_timeout) { socket.timeout(0); socket.setTimeoutMinutes(0); return; @@ -3235,16 +3245,17 @@ pub fn setTimeout(this: *HTTPClient, socket: anytype, minutes: c_uint) void { pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPContext(is_ssl), socket: NewHTTPContext(is_ssl).HTTPSocket) void { if (this.state.stage != .done and this.state.stage != .fail) { - const out_str = this.state.body_out_str.?; - const body = out_str.*; - const result = this.toResult(); - const is_done = !result.has_more; - if (this.state.is_redirect_pending and this.state.fail == null) { + if (this.state.flags.is_redirect_pending and this.state.fail == null) { if (this.state.isDone()) { this.doRedirect(is_ssl, ctx, socket); } return; } + const out_str = this.state.body_out_str.?; + const body = out_str.*; + const result = this.toResult(); + const is_done = !result.has_more; + if (this.signals.aborted != null and is_done) { _ = socket_async_http_abort_tracker.swapRemove(this.async_http_id); } @@ -3257,7 +3268,7 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon if (this.isKeepAlivePossible() and !socket.isClosedOrHasError()) { ctx.releaseSocket( socket, - this.did_have_handshaking_error and !this.reject_unauthorized, + this.flags.did_have_handshaking_error and !this.flags.reject_unauthorized, this.connected_url.hostname, this.connected_url.getPortAuto(), ); @@ -3269,7 +3280,7 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon this.state.response_stage = .done; this.state.request_stage = .done; this.state.stage = .done; - this.proxy_tunneling = false; + this.flags.proxy_tunneling = false; } result.body.?.* = body; @@ -3290,6 +3301,8 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon pub const HTTPClientResult = struct { body: ?*MutableString = null, has_more: bool = false, + redirected: bool = false, + fail: ?anyerror = null, /// Owns the response metadata aka headers, url and status code @@ -3300,7 +3313,6 @@ pub const HTTPClientResult = struct { /// If chunked encoded this will represent the total received size (ignoring the chunk headers) /// If is not chunked encoded and Content-Length is not provided this will be unknown body_size: BodySize = .unknown, - redirected: bool = false, certificate_info: ?CertificateInfo = null, pub const BodySize = union(enum) { @@ -3368,7 +3380,7 @@ pub fn toResult(this: *HTTPClient) HTTPClientResult { return HTTPClientResult{ .metadata = metadata, .body = this.state.body_out_str, - .redirected = this.remaining_redirect_count != default_redirect_count, + .redirected = this.flags.redirected, .fail = this.state.fail, // check if we are reporting cert errors, do not have a fail state and we are not done .has_more = this.state.fail == null and !this.state.isDone(), @@ -3379,6 +3391,7 @@ pub fn toResult(this: *HTTPClient) HTTPClientResult { return HTTPClientResult{ .body = this.state.body_out_str, .metadata = null, + .redirected = this.flags.redirected, .fail = this.state.fail, // check if we are reporting cert errors, do not have a fail state and we are not done .has_more = certificate_info != null or (this.state.fail == null and !this.state.isDone()), @@ -3417,7 +3430,7 @@ fn handleResponseBodyFromSinglePacket(this: *HTTPClient, incoming_data: []const } } // we can ignore the body data in redirects - if (this.state.is_redirect_pending) return; + if (this.state.flags.is_redirect_pending) return; if (this.state.encoding.isCompressed()) { try this.state.decompressBytes(incoming_data, this.state.body_out_str.?, true); @@ -3449,7 +3462,7 @@ fn handleResponseBodyFromMultiplePackets(this: *HTTPClient, incoming_data: []con } // we can ignore the body data in redirects - if (!this.state.is_redirect_pending) { + if (!this.state.flags.is_redirect_pending) { if (buffer.list.items.len == 0 and incoming_data.len < preallocate_max) { buffer.list.ensureTotalCapacityPrecise(buffer.allocator, incoming_data.len) catch {}; } @@ -3473,7 +3486,7 @@ fn handleResponseBodyFromMultiplePackets(this: *HTTPClient, incoming_data: []con // We can only use the libdeflate fast path when we are not streaming // If we ever call processBodyBuffer again, it cannot go through the fast path. - this.state.is_libdeflate_fast_path_disabled = true; + this.state.flags.is_libdeflate_fast_path_disabled = true; if (this.progress_node) |progress| { progress.activate(); @@ -3539,7 +3552,7 @@ fn handleResponseBodyChunkedEncodingFromMultiplePackets( // streaming chunks if (this.signals.get(.body_streaming)) { // If we're streaming, we cannot use the libdeflate fast path - this.state.is_libdeflate_fast_path_disabled = true; + this.state.flags.is_libdeflate_fast_path_disabled = true; return try this.state.processBodyBuffer(buffer, false); } @@ -3547,7 +3560,7 @@ fn handleResponseBodyChunkedEncodingFromMultiplePackets( }, // Done else => { - this.state.received_last_chunk = true; + this.state.flags.received_last_chunk = true; _ = try this.state.processBodyBuffer( buffer, true, @@ -3620,7 +3633,7 @@ fn handleResponseBodyChunkedEncodingFromSinglePacket( // streaming chunks if (this.signals.get(.body_streaming)) { // If we're streaming, we cannot use the libdeflate fast path - this.state.is_libdeflate_fast_path_disabled = true; + this.state.flags.is_libdeflate_fast_path_disabled = true; return try this.state.processBodyBuffer(body_buffer.*, true); } @@ -3629,7 +3642,7 @@ fn handleResponseBodyChunkedEncodingFromSinglePacket( }, // Done else => { - this.state.received_last_chunk = true; + this.state.flags.received_last_chunk = true; try this.handleResponseBodyFromSinglePacket(buffer); assert(this.state.body_out_str.?.list.items.ptr != buffer.ptr); if (this.progress_node) |progress| { @@ -3674,7 +3687,7 @@ pub fn handleResponseMetadata( } }, hashHeaderConst("Content-Encoding") => { - if (!this.disable_decompression) { + if (!this.flags.disable_decompression) { if (strings.eqlComptime(header.value, "gzip")) { this.state.encoding = Encoding.gzip; this.state.content_encoding_i = @as(u8, @truncate(header_i)); @@ -3689,15 +3702,15 @@ pub fn handleResponseMetadata( }, hashHeaderConst("Transfer-Encoding") => { if (strings.eqlComptime(header.value, "gzip")) { - if (!this.disable_decompression) { + if (!this.flags.disable_decompression) { this.state.transfer_encoding = Encoding.gzip; } } else if (strings.eqlComptime(header.value, "deflate")) { - if (!this.disable_decompression) { + if (!this.flags.disable_decompression) { this.state.transfer_encoding = Encoding.deflate; } } else if (strings.eqlComptime(header.value, "br")) { - if (!this.disable_decompression) { + if (!this.flags.disable_decompression) { this.state.transfer_encoding = .brotli; } } else if (strings.eqlComptime(header.value, "identity")) { @@ -3714,12 +3727,12 @@ pub fn handleResponseMetadata( hashHeaderConst("Connection") => { if (response.status_code >= 200 and response.status_code <= 299) { if (!strings.eqlComptime(header.value, "keep-alive")) { - this.state.allow_keepalive = false; + this.state.flags.allow_keepalive = false; } } }, hashHeaderConst("Last-Modified") => { - pretend_304 = this.force_last_modified and response.status_code > 199 and response.status_code < 300 and this.if_modified_since.len > 0 and strings.eql(this.if_modified_since, header.value); + pretend_304 = this.flags.force_last_modified and response.status_code > 199 and response.status_code < 300 and this.if_modified_since.len > 0 and strings.eql(this.if_modified_since, header.value); }, else => {}, @@ -3735,7 +3748,7 @@ pub fn handleResponseMetadata( } // Don't do this for proxies because those connections will be open for awhile. - if (!this.proxy_tunneling) { + if (!this.flags.proxy_tunneling) { // according to RFC 7230 section 3.3.3: // 1. Any response to a HEAD request and any response with a 1xx (Informational), @@ -3757,18 +3770,18 @@ pub fn handleResponseMetadata( // // but, we must only do this IF the status code allows it to contain a body. else if (this.state.content_length == null and this.state.transfer_encoding != .chunked) { - this.state.allow_keepalive = false; + this.state.flags.allow_keepalive = false; } } - if (this.proxy_tunneling and this.proxy_tunnel == null) { + if (this.flags.proxy_tunneling and this.proxy_tunnel == null) { if (response.status_code == 200) { // signal to continue the proxing return ShouldContinue.continue_streaming; } //proxy denied connection so return proxy result (407, 403 etc) - this.proxy_tunneling = false; + this.flags.proxy_tunneling = false; } const status_code = response.status_code; @@ -3973,9 +3986,9 @@ pub fn handleResponseMetadata( } } - this.state.is_redirect_pending = true; + this.state.flags.is_redirect_pending = true; if (this.method.hasRequestBody()) { - this.state.resend_request_body_on_redirect = true; + this.state.flags.resend_request_body_on_redirect = true; } }, else => {}, @@ -3994,7 +4007,7 @@ pub fn handleResponseMetadata( log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding}); } - if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events)) { + if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events)) { return ShouldContinue.continue_streaming; } else { return ShouldContinue.finished; diff --git a/src/install/install.zig b/src/install/install.zig index 8d6bb0a0bf..0f73dffe9d 100644 --- a/src/install/install.zig +++ b/src/install/install.zig @@ -429,7 +429,7 @@ const NetworkTask = struct { this.http = AsyncHTTP.init(allocator, .GET, url, header_builder.entries, header_builder.content.ptr.?[0..header_builder.content.len], &this.response_buffer, "", this.getCompletionCallback(), HTTP.FetchRedirect.follow, .{ .http_proxy = this.package_manager.httpProxy(url), }); - this.http.client.reject_unauthorized = this.package_manager.tlsRejectUnauthorized(); + this.http.client.flags.reject_unauthorized = this.package_manager.tlsRejectUnauthorized(); if (PackageManager.verbose_install) { this.http.client.verbose = .headers; @@ -449,7 +449,7 @@ const NetworkTask = struct { // Incase the ETag causes invalidation, we fallback to the last modified date. if (last_modified.len != 0 and bun.getRuntimeFeatureFlag("BUN_FEATURE_FLAG_LAST_MODIFIED_PRETEND_304")) { - this.http.client.force_last_modified = true; + this.http.client.flags.force_last_modified = true; this.http.client.if_modified_since = last_modified; } } @@ -513,7 +513,7 @@ const NetworkTask = struct { this.http = AsyncHTTP.init(allocator, .GET, url, header_builder.entries, header_buf, &this.response_buffer, "", this.getCompletionCallback(), HTTP.FetchRedirect.follow, .{ .http_proxy = this.package_manager.httpProxy(url), }); - this.http.client.reject_unauthorized = this.package_manager.tlsRejectUnauthorized(); + this.http.client.flags.reject_unauthorized = this.package_manager.tlsRejectUnauthorized(); if (PackageManager.verbose_install) { this.http.client.verbose = .headers; } diff --git a/src/js/node/http2.ts b/src/js/node/http2.ts index 384281b45a..8a17aa5fb2 100644 --- a/src/js/node/http2.ts +++ b/src/js/node/http2.ts @@ -852,9 +852,9 @@ class ClientHttp2Session extends Http2Session { process.nextTick(emitWantTrailersNT, self.#streams, streamId); } }, - goaway(self: ClientHttp2Session, errorCode: number, lastStreamId: number, opaqueData: Buffer) { + goaway(self: ClientHttp2Session, errorCode: number, lastStreamId: number, opaqueData?: Buffer) { if (!self) return; - self.emit("goaway", errorCode, lastStreamId, opaqueData); + self.emit("goaway", errorCode, lastStreamId, opaqueData || Buffer.allocUnsafe(0)); if (errorCode !== 0) { for (let [_, stream] of self.#streams) { stream.rstCode = errorCode; diff --git a/test/cli/hot/watch.test.ts b/test/cli/hot/watch.test.ts index db664ad1e7..1cd7d77ace 100644 --- a/test/cli/hot/watch.test.ts +++ b/test/cli/hot/watch.test.ts @@ -1,19 +1,18 @@ import { spawn } from "bun"; import { describe, expect, test } from "bun:test"; -import { writeFile } from "fs/promises"; +import { writeFile } from "node:fs/promises"; import { bunEnv, bunExe, forEachLine, tempDirWithFiles } from "harness"; -import { join } from "path"; +import { join } from "node:path"; describe("--watch works", async () => { - for (let watchedFile of ["tmp.js", "entry.js"]) { - test("with " + watchedFile, async () => { + for (const watchedFile of ["tmp.js", "entry.js"]) { + test(`with ${watchedFile}`, async () => { const tmpdir_ = tempDirWithFiles("watch-fixture", { "tmp.js": "console.log('hello #1')", "entry.js": "import './tmp.js'", "package.json": JSON.stringify({ name: "foo", version: "0.0.1" }), }); const tmpfile = join(tmpdir_, "tmp.js"); - await writeFile(tmpfile, "console.log('hello #1')"); const process = spawn({ cmd: [bunExe(), "--watch", join(tmpdir_, watchedFile)], cwd: tmpdir_, @@ -22,7 +21,7 @@ describe("--watch works", async () => { }); const { stdout } = process; - let iter = forEachLine(stdout); + const iter = forEachLine(stdout); let { value: line, done } = await iter.next(); expect(done).toBe(false); expect(line).toBe("hello #1"); @@ -43,7 +42,7 @@ describe("--watch works", async () => { ({ value: line } = await iter.next()); expect(line).toBe("hello #5"); - process.kill(); + process.kill("SIGKILL"); await process.exited; }); } diff --git a/test/integration/next-pages/test/dev-server-ssr-100.test.ts b/test/integration/next-pages/test/dev-server-ssr-100.test.ts index 5e5983b33d..cd204005da 100644 --- a/test/integration/next-pages/test/dev-server-ssr-100.test.ts +++ b/test/integration/next-pages/test/dev-server-ssr-100.test.ts @@ -96,7 +96,7 @@ beforeAll(async () => { stdin: "inherit", }); if (!install.success) { - const reason = installProcess.signalCode || `code ${installProcess.exitCode}`; + const reason = install.signalCode || `code ${install.exitCode}`; throw new Error(`Failed to install dependencies: ${reason}`); } diff --git a/test/js/node/http/node-http-response-write-encode-fixture.js b/test/js/node/http/node-http-response-write-encode-fixture.js new file mode 100644 index 0000000000..b47e988767 --- /dev/null +++ b/test/js/node/http/node-http-response-write-encode-fixture.js @@ -0,0 +1,82 @@ +import { createServer } from "node:http"; +import { expect } from "bun:test"; +function disableAggressiveGCScope() { + const gc = Bun.unsafe.gcAggressionLevel(0); + return { + [Symbol.dispose]() { + Bun.unsafe.gcAggressionLevel(gc); + }, + }; +} +// x = ascii +// รก = latin1 supplementary character +// ๐Ÿ“™ = emoji +// ๐Ÿ‘๐Ÿฝ = its a grapheme of ๐Ÿ‘ ๐ŸŸค +// "\u{1F600}" = utf16 +const chars = ["x", "รก", "๐Ÿ“™", "๐Ÿ‘๐Ÿฝ", "\u{1F600}"]; + +// 128 = small than waterMark, 256 = waterMark, 1024 = large than waterMark +// 8Kb = small than cork buffer +// 16Kb = cork buffer +// 32Kb = large than cork buffer +const start_size = 128; +const increment_step = 1024; +const end_size = 32 * 1024; +let expected = ""; + +const { promise, reject, resolve } = Promise.withResolvers(); + +async function finish(err) { + server.closeAllConnections(); + Bun.gc(true); + if (err) reject(err); + resolve(err); +} +const server = createServer((_, response) => { + response.write(expected); + response.write(""); + response.end(); +}).listen(0, "localhost", async (err, hostname, port) => { + try { + expect(err).toBeFalsy(); + expect(port).toBeGreaterThan(0); + + for (const char of chars) { + for (let size = start_size; size <= end_size; size += increment_step) { + expected = char + Buffer.alloc(size, "-").toString("utf8") + "x"; + + try { + const url = `http://${hostname}:${port}`; + const count = 20; + const all = []; + const batchSize = 20; + while (all.length < count) { + const batch = Array.from({ length: batchSize }, () => fetch(url).then(a => a.text())); + + all.push(...(await Promise.all(batch))); + } + + using _ = disableAggressiveGCScope(); + for (const result of all) { + expect(result).toBe(expected); + } + } catch (err) { + return finish(err); + } + } + + // still always run GC at the end here. + Bun.gc(true); + } + finish(); + } catch (err) { + finish(err); + } +}); + +promise + .then(() => process.exit(0)) + .catch(err => { + console.error(err); + process.exit(1); + }); diff --git a/test/js/node/http/node-http.test.ts b/test/js/node/http/node-http.test.ts index 883f30ba7b..ab76e1584d 100644 --- a/test/js/node/http/node-http.test.ts +++ b/test/js/node/http/node-http.test.ts @@ -1830,68 +1830,16 @@ if (process.platform !== "win32") { }); } -it("#10177 response.write with non-ascii latin1 should not cause duplicated character or segfault", done => { - // x = ascii - // รก = latin1 supplementary character - // ๐Ÿ“™ = emoji - // ๐Ÿ‘๐Ÿฝ = its a grapheme of ๐Ÿ‘ ๐ŸŸค - // "\u{1F600}" = utf16 - const chars = ["x", "รก", "๐Ÿ“™", "๐Ÿ‘๐Ÿฝ", "\u{1F600}"]; - - // 128 = small than waterMark, 256 = waterMark, 1024 = large than waterMark - // 8Kb = small than cork buffer - // 16Kb = cork buffer - // 32Kb = large than cork buffer - const start_size = 128; - const increment_step = 1024; - const end_size = 32 * 1024; - let expected = ""; - - function finish(err) { - server.closeAllConnections(); - Bun.gc(true); - done(err); - } - const server = require("http") - .createServer((_, response) => { - response.write(expected); - response.write(""); - response.end(); - }) - .listen(0, "localhost", async (err, hostname, port) => { - expect(err).toBeFalsy(); - expect(port).toBeGreaterThan(0); - - for (const char of chars) { - for (let size = start_size; size <= end_size; size += increment_step) { - expected = char + Buffer.alloc(size, "-").toString("utf8") + "x"; - - try { - const url = `http://${hostname}:${port}`; - const count = 20; - const all = []; - const batchSize = 20; - while (all.length < count) { - const batch = Array.from({ length: batchSize }, () => fetch(url).then(a => a.text())); - - all.push(...(await Promise.all(batch))); - } - - using _ = disableAggressiveGCScope(); - for (const result of all) { - expect(result).toBe(expected); - } - } catch (err) { - return finish(err); - } - } - - // still always run GC at the end here. - Bun.gc(true); - } - finish(); - }); -}, 20_000); +it("#10177 response.write with non-ascii latin1 should not cause duplicated character or segfault", () => { + // this can cause a segfault so we run it in a separate process + const { exitCode } = Bun.spawnSync({ + cmd: [bunExe(), "run", path.join(import.meta.dir, "node-http-response-write-encode-fixture.js")], + env: bunEnv, + stdout: "inherit", + stderr: "inherit", + }); + expect(exitCode).toBe(0); +}, 60_000); it("#11425 http no payload limit", done => { const server = Server((req, res) => { @@ -1944,7 +1892,7 @@ it("should emit events in the right order", async () => { it("destroy should end download", async () => { // just simulate some file that will take forever to download const payload = Buffer.from("X".repeat(128 * 1024)); - + let sendedByteLength = 0; using server = Bun.serve({ port: 0, async fetch(req) { @@ -1952,6 +1900,7 @@ it("destroy should end download", async () => { req.signal.onabort = () => (running = false); return new Response(async function* () { while (running) { + sendedByteLength += payload.byteLength; yield payload; await Bun.sleep(10); } @@ -1976,8 +1925,10 @@ it("destroy should end download", async () => { req.destroy(); await Bun.sleep(10); const initialByteLength = receivedByteLength; - expect(receivedByteLength).toBeLessThanOrEqual(payload.length * 3); + // we should receive the same amount of data we sent + expect(initialByteLength).toBeLessThanOrEqual(sendedByteLength); await Bun.sleep(10); + // we should not receive more data after destroy expect(initialByteLength).toBe(receivedByteLength); await Bun.sleep(10); } diff --git a/test/js/node/worker_threads/worker_threads.test.ts b/test/js/node/worker_threads/worker_threads.test.ts index 55d1afc389..22cb0b8690 100644 --- a/test/js/node/worker_threads/worker_threads.test.ts +++ b/test/js/node/worker_threads/worker_threads.test.ts @@ -133,7 +133,7 @@ test("threadId module and worker property is consistent", async () => { expect(worker1.threadId).toBeGreaterThan(0); expect(() => worker1.postMessage({ workerId: worker1.threadId })).not.toThrow(); const worker2 = new Worker(new URL("./worker-thread-id.ts", import.meta.url).href); - expect(worker2.threadId).toBe(worker1.threadId + 1); + expect(worker2.threadId).toBeGreaterThan(worker1.threadId); expect(() => worker2.postMessage({ workerId: worker2.threadId })).not.toThrow(); await worker1.terminate(); await worker2.terminate(); diff --git a/test/js/third_party/grpc-js/common.ts b/test/js/third_party/grpc-js/common.ts index 9796b1013d..f4d5a35930 100644 --- a/test/js/third_party/grpc-js/common.ts +++ b/test/js/third_party/grpc-js/common.ts @@ -40,6 +40,7 @@ export class TestServer { GRPC_TEST_USE_TLS: this.useTls ? "true" : "false", GRPC_TEST_OPTIONS: JSON.stringify(this.#options), GRPC_SERVICE_TYPE: this.service_type.toString(), + "grpc-node.max_session_memory": 1024, }); this.address = result.address as AddressInfo; this.url = result.url as string; diff --git a/test/js/web/fetch/fetch-leak-test-fixture-4.js b/test/js/web/fetch/fetch-leak-test-fixture-4.js index 687424fa74..73c8ccd5d6 100644 --- a/test/js/web/fetch/fetch-leak-test-fixture-4.js +++ b/test/js/web/fetch/fetch-leak-test-fixture-4.js @@ -7,6 +7,7 @@ function getHeapStats() { const server = process.argv[2]; const batch = 50; const iterations = 10; +const threshold = batch * 2 + batch / 2; try { for (let i = 0; i < iterations; i++) { @@ -20,9 +21,10 @@ try { { Bun.gc(true); + await Bun.sleep(10); const stats = getHeapStats(); - expect(stats.Response || 0).toBeLessThanOrEqual(batch + 5); - expect(stats.Promise || 0).toBeLessThanOrEqual(batch + 5); + expect(stats.Response || 0).toBeLessThanOrEqual(threshold); + expect(stats.Promise || 0).toBeLessThanOrEqual(threshold); } } process.exit(0); diff --git a/test/js/web/fetch/fetch-leak.test.js b/test/js/web/fetch/fetch-leak.test.js index dfcadeb51c..c449560509 100644 --- a/test/js/web/fetch/fetch-leak.test.js +++ b/test/js/web/fetch/fetch-leak.test.js @@ -70,7 +70,7 @@ describe("fetch doesn't leak", () => { } if (compressed) { - env.COUNT = "5000"; + env.COUNT = "1000"; } const proc = Bun.spawn({ diff --git a/test/js/web/fetch/fetch.test.ts b/test/js/web/fetch/fetch.test.ts index b8ef539066..2e587e31cb 100644 --- a/test/js/web/fetch/fetch.test.ts +++ b/test/js/web/fetch/fetch.test.ts @@ -504,7 +504,7 @@ describe("fetch", () => { }); expect(response.status).toBe(302); expect(response.headers.get("location")).toBe("https://example.com"); - expect(response.redirected).toBe(true); + expect(response.redirected).toBe(false); // not redirected }); it('redirect: "follow"', async () => { diff --git a/test/js/web/websocket/websocket.test.js b/test/js/web/websocket/websocket.test.js index c323a555b7..abc3b3eb12 100644 --- a/test/js/web/websocket/websocket.test.js +++ b/test/js/web/websocket/websocket.test.js @@ -529,8 +529,8 @@ describe("WebSocket", () => { await openAndCloseWS(); if (i % 100 === 0) { current_websocket_count = getWebSocketCount(); - // if we have more than 20 websockets open, we have a problem - expect(current_websocket_count).toBeLessThanOrEqual(20); + // if we have more than 1 batch of websockets open, we have a problem + expect(current_websocket_count).toBeLessThanOrEqual(100); if (initial_websocket_count === 0) { initial_websocket_count = current_websocket_count; } diff --git a/test/js/web/workers/worker.test.ts b/test/js/web/workers/worker.test.ts index 609eb37dfc..48f6d0f636 100644 --- a/test/js/web/workers/worker.test.ts +++ b/test/js/web/workers/worker.test.ts @@ -3,8 +3,6 @@ import { bunEnv, bunExe, isWindows } from "harness"; import path from "path"; import wt from "worker_threads"; -const todoIfWindows = isWindows ? test.todo : test; - describe("web worker", () => { async function waitForWorkerResult(worker: Worker, message: any): Promise { const promise = new Promise((resolve, reject) => { @@ -237,7 +235,7 @@ describe("worker_threads", () => { }); }); - todoIfWindows("worker terminate", async () => { + test("worker terminate", async () => { const worker = new wt.Worker(new URL("worker-fixture-hang.js", import.meta.url).href, { smol: true, }); @@ -245,7 +243,7 @@ describe("worker_threads", () => { expect(code).toBe(0); }); - todoIfWindows("worker with process.exit (delay) and terminate", async () => { + test("worker with process.exit (delay) and terminate", async () => { const worker = new wt.Worker(new URL("worker-fixture-process-exit.js", import.meta.url).href, { smol: true, }); diff --git a/test/regression/issue/07500/07500.test.ts b/test/regression/issue/07500/07500.test.ts index 5a1c9a6682..896e8dd638 100644 --- a/test/regression/issue/07500/07500.test.ts +++ b/test/regression/issue/07500/07500.test.ts @@ -10,7 +10,7 @@ test("7500 - Bun.stdin.text() doesn't read all data", async () => { .split(" ") .join("\n"); await Bun.write(filename, text); - const cat = "cat"; + const cat = isWindows ? "Get-Content" : "cat"; const bunCommand = `${bunExe()} ${join(import.meta.dir, "07500.fixture.js")}`; const shellCommand = `${cat} ${filename} | ${bunCommand}`.replace(/\\/g, "\\\\");