From aef0b5b4a678d56fefcf8954f1e1cbf760884d7e Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Mon, 15 Dec 2025 18:43:51 -0800 Subject: [PATCH] fix(usockets): safely handle socket reallocation during context adoption (#25361) ## Summary - Fix use-after-free vulnerability during socket adoption by properly tracking reallocated sockets - Add safety checks to prevent linking closed sockets to context lists - Properly track socket state with new `is_closed`, `adopted`, and `is_tls` flags ## What does this PR do? This PR improves event loop stability by addressing potential use-after-free issues that can occur when sockets are reallocated during adoption (e.g., when upgrading a TCP socket to TLS). ### Key Changes **Socket State Tracking ([internal.h](packages/bun-usockets/src/internal/internal.h))** - Added `is_closed` flag to explicitly track when a socket has been closed - Added `adopted` flag to mark sockets that were reallocated during context adoption - Added `is_tls` flag to track TLS socket state for proper low-priority queue handling **Safe Socket Adoption ([context.c](packages/bun-usockets/src/context.c))** - When `us_poll_resize()` returns a new pointer (reallocation occurred), the old socket is now: - Marked as closed (`is_closed = 1`) - Added to the closed socket cleanup list - Marked as adopted (`adopted = 1`) - Has its `prev` pointer set to the new socket for event redirection - Added guards to `us_internal_socket_context_link_socket/listen_socket/connecting_socket` to prevent linking already-closed sockets **Event Loop Handling ([loop.c](packages/bun-usockets/src/loop.c))** - After callbacks that can trigger socket adoption (`on_open`, `on_writable`, `on_data`), the event loop now checks if the socket was reallocated and redirects to the new socket - Low-priority socket handling now properly checks `is_closed` state and uses `is_tls` flag for correct SSL handling **Poll Resize Safety ([epoll_kqueue.c](packages/bun-usockets/src/eventing/epoll_kqueue.c))** - Changed `us_poll_resize()` to always allocate new memory with `us_calloc()` instead of `us_realloc()` to ensure the old pointer remains valid for cleanup - Now takes `old_ext_size` parameter to correctly calculate memory sizes - Re-enabled `us_internal_loop_update_pending_ready_polls()` call in `us_poll_change()` to ensure pending events are properly redirected ### How did you verify your code works? Run existing CI and existing socket upgrade tests under asan build --- packages/bun-usockets/src/context.c | 48 +++++++++++++++---- packages/bun-usockets/src/crypto/openssl.c | 18 ++++--- .../bun-usockets/src/eventing/epoll_kqueue.c | 38 +++++++++------ packages/bun-usockets/src/eventing/libuv.c | 22 ++++++++- packages/bun-usockets/src/internal/internal.h | 10 +++- packages/bun-usockets/src/libusockets.h | 6 +-- packages/bun-usockets/src/loop.c | 33 +++++++++++-- packages/bun-usockets/src/socket.c | 25 +++++----- packages/bun-uws/src/HttpResponse.h | 2 +- src/bun.js/api/bun/socket.zig | 9 ++-- src/deps/libuwsockets.cpp | 2 + src/deps/uws/SocketContext.zig | 6 +-- src/deps/uws/socket.zig | 8 ++-- src/js/node/net.ts | 2 +- test/js/bun/http/bun-websocket-cpu-fixture.js | 4 +- 15 files changed, 169 insertions(+), 64 deletions(-) diff --git a/packages/bun-usockets/src/context.c b/packages/bun-usockets/src/context.c index 6e2c3f3e18..043cdd05f9 100644 --- a/packages/bun-usockets/src/context.c +++ b/packages/bun-usockets/src/context.c @@ -54,8 +54,8 @@ void us_listen_socket_close(int ssl, struct us_listen_socket_t *ls) { s->next = loop->data.closed_head; loop->data.closed_head = s; - /* Any socket with prev = context is marked as closed */ - s->prev = (struct us_socket_t *) context; + /* Mark the socket as closed */ + s->flags.is_closed = 1; } /* We cannot immediately free a listen socket as we can be inside an accept loop */ @@ -154,7 +154,9 @@ void us_internal_socket_context_unlink_connecting_socket(int ssl, struct us_sock /* We always add in the top, so we don't modify any s.next */ void us_internal_socket_context_link_listen_socket(int ssl, struct us_socket_context_t *context, struct us_listen_socket_t *ls) { + struct us_socket_t* s = &ls->s; + if(us_socket_is_closed(ssl, s)) return; s->context = context; s->next = (struct us_socket_t *) context->head_listen_sockets; s->prev = 0; @@ -166,6 +168,8 @@ void us_internal_socket_context_link_listen_socket(int ssl, struct us_socket_con } void us_internal_socket_context_link_connecting_socket(int ssl, struct us_socket_context_t *context, struct us_connecting_socket_t *c) { + if(c->closed) return; + c->context = context; c->next_pending = context->head_connecting_sockets; c->prev_pending = 0; @@ -180,6 +184,8 @@ void us_internal_socket_context_link_connecting_socket(int ssl, struct us_socket /* We always add in the top, so we don't modify any s.next */ void us_internal_socket_context_link_socket(int ssl, struct us_socket_context_t *context, struct us_socket_t *s) { + if(us_socket_is_closed(ssl,s)) return; + s->context = context; s->next = context->head_sockets; s->prev = 0; @@ -386,6 +392,9 @@ struct us_listen_socket_t *us_socket_context_listen(int ssl, struct us_socket_co s->flags.low_prio_state = 0; s->flags.is_paused = 0; s->flags.is_ipc = 0; + s->flags.is_closed = 0; + s->flags.adopted = 0; + s->flags.is_tls = ssl; s->next = 0; s->flags.allow_half_open = (options & LIBUS_SOCKET_ALLOW_HALF_OPEN); us_internal_socket_context_link_listen_socket(ssl, context, ls); @@ -422,6 +431,9 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock s->flags.allow_half_open = (options & LIBUS_SOCKET_ALLOW_HALF_OPEN); s->flags.is_paused = 0; s->flags.is_ipc = 0; + s->flags.is_closed = 0; + s->flags.adopted = 0; + s->flags.is_tls = ssl; s->next = 0; us_internal_socket_context_link_listen_socket(ssl, context, ls); @@ -430,7 +442,7 @@ struct us_listen_socket_t *us_socket_context_listen_unix(int ssl, struct us_sock return ls; } -struct us_socket_t* us_socket_context_connect_resolved_dns(struct us_socket_context_t *context, struct sockaddr_storage* addr, int options, int socket_ext_size) { +struct us_socket_t* us_socket_context_connect_resolved_dns(int ssl, struct us_socket_context_t *context, struct sockaddr_storage* addr, int options, int socket_ext_size) { LIBUS_SOCKET_DESCRIPTOR connect_socket_fd = bsd_create_connect_socket(addr, options); if (connect_socket_fd == LIBUS_SOCKET_ERROR) { return NULL; @@ -453,6 +465,9 @@ struct us_socket_t* us_socket_context_connect_resolved_dns(struct us_socket_cont socket->flags.allow_half_open = (options & LIBUS_SOCKET_ALLOW_HALF_OPEN); socket->flags.is_paused = 0; socket->flags.is_ipc = 0; + socket->flags.is_closed = 0; + socket->flags.adopted = 0; + socket->flags.is_tls = ssl; socket->connect_state = NULL; socket->connect_next = NULL; @@ -514,7 +529,7 @@ void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, co struct sockaddr_storage addr; if (try_parse_ip(host, port, &addr)) { *has_dns_resolved = 1; - return us_socket_context_connect_resolved_dns(context, &addr, options, socket_ext_size); + return us_socket_context_connect_resolved_dns(ssl, context, &addr, options, socket_ext_size); } struct addrinfo_request* ai_req; @@ -534,7 +549,7 @@ void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, co struct sockaddr_storage addr; init_addr_with_port(&entries->info, port, &addr); *has_dns_resolved = 1; - struct us_socket_t *s = us_socket_context_connect_resolved_dns(context, &addr, options, socket_ext_size); + struct us_socket_t *s = us_socket_context_connect_resolved_dns(ssl, context, &addr, options, socket_ext_size); Bun__addrinfo_freeRequest(ai_req, s == NULL); return s; } @@ -583,6 +598,9 @@ int start_connections(struct us_connecting_socket_t *c, int count) { flags->allow_half_open = (c->options & LIBUS_SOCKET_ALLOW_HALF_OPEN); flags->is_paused = 0; flags->is_ipc = 0; + flags->is_closed = 0; + flags->adopted = 0; + flags->is_tls = c->ssl; /* Link it into context so that timeout fires properly */ us_internal_socket_context_link_socket(0, context, s); @@ -760,6 +778,9 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, struct us_socket_con connect_socket->flags.allow_half_open = (options & LIBUS_SOCKET_ALLOW_HALF_OPEN); connect_socket->flags.is_paused = 0; connect_socket->flags.is_ipc = 0; + connect_socket->flags.is_closed = 0; + connect_socket->flags.adopted = 0; + connect_socket->flags.is_tls = ssl; connect_socket->connect_state = NULL; connect_socket->connect_next = NULL; us_internal_socket_context_link_socket(ssl, context, connect_socket); @@ -780,10 +801,10 @@ struct us_socket_context_t *us_create_child_socket_context(int ssl, struct us_so } /* Note: This will set timeout to 0 */ -struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_context_t *context, struct us_socket_t *s, int ext_size) { +struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_context_t *context, struct us_socket_t *s, int old_ext_size, int ext_size) { #ifndef LIBUS_NO_SSL if (ssl) { - return (struct us_socket_t *) us_internal_ssl_socket_context_adopt_socket((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t *) s, ext_size); + return (struct us_socket_t *) us_internal_ssl_socket_context_adopt_socket((struct us_internal_ssl_socket_context_t *) context, (struct us_internal_ssl_socket_t *) s, old_ext_size, ext_size); } #endif @@ -807,7 +828,18 @@ struct us_socket_t *us_socket_context_adopt_socket(int ssl, struct us_socket_con struct us_socket_t *new_s = s; if (ext_size != -1) { struct us_poll_t *pool_ref = &s->p; - new_s = (struct us_socket_t *) us_poll_resize(pool_ref, loop, sizeof(struct us_socket_t) + ext_size); + new_s = (struct us_socket_t *) us_poll_resize(pool_ref, loop, sizeof(struct us_socket_t) + old_ext_size, sizeof(struct us_socket_t) + ext_size); + if(new_s != s) { + /* Mark the old socket as closed */ + s->flags.is_closed = 1; + /* Link this socket to the close-list and let it be deleted after this iteration */ + s->next = s->context->loop->data.closed_head; + s->context->loop->data.closed_head = s; + /* Mark the old socket as adopted (reallocated) */ + s->flags.adopted = 1; + /* Tell the event loop what is the new socket so we can process to send info to the right place and callbacks like more data and EOF*/ + s->prev = new_s; + } if (c) { c->connecting_head = new_s; c->context = context; diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index 348819d0e8..db597cf44f 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -396,7 +396,7 @@ void us_internal_update_handshake(struct us_internal_ssl_socket_t *s) { } int result = SSL_do_handshake(s->ssl); - + if (SSL_get_shutdown(s->ssl) & SSL_RECEIVED_SHUTDOWN) { us_internal_ssl_socket_close(s, 0, NULL); return; @@ -417,6 +417,7 @@ void us_internal_update_handshake(struct us_internal_ssl_socket_t *s) { } s->handshake_state = HANDSHAKE_PENDING; s->ssl_write_wants_read = 1; + s->s.context->loop->data.last_write_failed = 1; return; } @@ -434,6 +435,7 @@ ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { struct us_internal_ssl_socket_t * ret = context->on_close(s, code, reason); SSL_free(s->ssl); // free SSL after on_close s->ssl = NULL; // set to NULL + return ret; } @@ -1855,15 +1857,16 @@ void us_internal_ssl_socket_shutdown(struct us_internal_ssl_socket_t *s) { struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket( struct us_internal_ssl_socket_context_t *context, - struct us_internal_ssl_socket_t *s, int ext_size) { + struct us_internal_ssl_socket_t *s, int old_ext_size, int ext_size) { // todo: this is completely untested + int new_old_ext_size = sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + old_ext_size; int new_ext_size = ext_size; if (ext_size != -1) { new_ext_size = sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + ext_size; } return (struct us_internal_ssl_socket_t *)us_socket_context_adopt_socket( 0, &context->sc, &s->s, - new_ext_size); + new_old_ext_size, new_ext_size); } struct us_internal_ssl_socket_t * @@ -1920,10 +1923,11 @@ ssl_wrapped_context_on_data(struct us_internal_ssl_socket_t *s, char *data, struct us_wrapped_socket_context_t *wrapped_context = (struct us_wrapped_socket_context_t *)us_internal_ssl_socket_context_ext( context); - // raw data if needed + // raw data if needed if (wrapped_context->old_events.on_data) { wrapped_context->old_events.on_data((struct us_socket_t *)s, data, length); } + // ssl wrapped data return ssl_on_data(s, data, length); } @@ -2028,7 +2032,7 @@ us_internal_ssl_socket_open(struct us_internal_ssl_socket_t *s, int is_client, // already opened if (s->ssl) return s; - + // start SSL open return ssl_on_open(s, is_client, ip, ip_length, NULL); } @@ -2040,6 +2044,7 @@ struct us_socket_t *us_socket_upgrade_to_tls(us_socket_r s, us_socket_context_r struct us_internal_ssl_socket_t *socket = (struct us_internal_ssl_socket_t *)us_socket_context_adopt_socket( 0, new_context, s, + sizeof(void*), (sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t)) + sizeof(void*)); socket->ssl = NULL; socket->ssl_write_wants_read = 0; @@ -2058,7 +2063,7 @@ struct us_socket_t *us_socket_upgrade_to_tls(us_socket_r s, us_socket_context_r struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( struct us_socket_t *s, struct us_bun_socket_context_options_t options, - struct us_socket_events_t events, int socket_ext_size) { + struct us_socket_events_t events, int old_socket_ext_size, int socket_ext_size) { /* Cannot wrap a closed socket */ if (us_socket_is_closed(0, s)) { return NULL; @@ -2163,6 +2168,7 @@ us_socket_context_on_socket_connect_error( struct us_internal_ssl_socket_t *socket = (struct us_internal_ssl_socket_t *)us_socket_context_adopt_socket( 0, context, s, + old_socket_ext_size, sizeof(struct us_internal_ssl_socket_t) - sizeof(struct us_socket_t) + socket_ext_size); socket->ssl = NULL; diff --git a/packages/bun-usockets/src/eventing/epoll_kqueue.c b/packages/bun-usockets/src/eventing/epoll_kqueue.c index e796b16c05..5c68c028ed 100644 --- a/packages/bun-usockets/src/eventing/epoll_kqueue.c +++ b/packages/bun-usockets/src/eventing/epoll_kqueue.c @@ -325,7 +325,7 @@ void us_internal_loop_update_pending_ready_polls(struct us_loop_t *loop, struct int num_entries_possibly_remaining = 1; #else /* Ready polls may contain same poll twice under kqueue, as one poll may hold two filters */ - int num_entries_possibly_remaining = 2;//((old_events & LIBUS_SOCKET_READABLE) ? 1 : 0) + ((old_events & LIBUS_SOCKET_WRITABLE) ? 1 : 0); + int num_entries_possibly_remaining = 2; #endif /* Todo: for kqueue if we track things in us_change_poll it is possible to have a fast path with no seeking in cases of: @@ -377,22 +377,30 @@ int kqueue_change(int kqfd, int fd, int old_events, int new_events, void *user_d } #endif -struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, unsigned int ext_size) { - int events = us_poll_events(p); - +struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, unsigned int old_ext_size, unsigned int ext_size) { - struct us_poll_t *new_p = us_realloc(p, sizeof(struct us_poll_t) + ext_size); - if (p != new_p) { + unsigned int old_size = sizeof(struct us_poll_t) + old_ext_size; + unsigned int new_size = sizeof(struct us_poll_t) + ext_size; + if(new_size <= old_size) return p; + + struct us_poll_t *new_p = us_calloc(1, new_size); + memcpy(new_p, p, old_size); + + /* Increment poll count for the new poll - the old poll will be freed separately + * which decrements the count, keeping the total correct */ + loop->num_polls++; + + int events = us_poll_events(p); #ifdef LIBUS_USE_EPOLL - /* Hack: forcefully update poll by stripping away already set events */ - new_p->state.poll_type = us_internal_poll_type(new_p); - us_poll_change(new_p, loop, events); + /* Hack: forcefully update poll by stripping away already set events */ + new_p->state.poll_type = us_internal_poll_type(new_p); + us_poll_change(new_p, loop, events); #else - /* Forcefully update poll by resetting them with new_p as user data */ - kqueue_change(loop->fd, new_p->state.fd, 0, LIBUS_SOCKET_WRITABLE | LIBUS_SOCKET_READABLE, new_p); -#endif /* This is needed for epoll also (us_change_poll doesn't update the old poll) */ - us_internal_loop_update_pending_ready_polls(loop, p, new_p, events, events); - } + /* Forcefully update poll by resetting them with new_p as user data */ + kqueue_change(loop->fd, new_p->state.fd, 0, LIBUS_SOCKET_WRITABLE | LIBUS_SOCKET_READABLE, new_p); +#endif + /* This is needed for epoll also (us_change_poll doesn't update the old poll) */ + us_internal_loop_update_pending_ready_polls(loop, p, new_p, events, events); return new_p; } @@ -444,7 +452,7 @@ void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) { kqueue_change(loop->fd, p->state.fd, old_events, events, p); #endif /* Set all removed events to null-polls in pending ready poll list */ - // us_internal_loop_update_pending_ready_polls(loop, p, p, old_events, events); + us_internal_loop_update_pending_ready_polls(loop, p, p, old_events, events); } } diff --git a/packages/bun-usockets/src/eventing/libuv.c b/packages/bun-usockets/src/eventing/libuv.c index 72358e4c14..c71b6e13ec 100644 --- a/packages/bun-usockets/src/eventing/libuv.c +++ b/packages/bun-usockets/src/eventing/libuv.c @@ -71,6 +71,11 @@ void us_poll_init(struct us_poll_t *p, LIBUS_SOCKET_DESCRIPTOR fd, } void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) { + // poll was adopted and dont own uv_poll_t anymore + if(!p->uv_p) { + free(p); + return; + } /* The idea here is like so; in us_poll_stop we call uv_close after setting * data of uv-poll to 0. This means that in close_cb_free we call free on 0 * with does nothing, since us_poll_stop should not really free the poll. @@ -86,6 +91,7 @@ void us_poll_free(struct us_poll_t *p, struct us_loop_t *loop) { } void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) { + if(!p->uv_p) return; p->poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0); @@ -99,6 +105,7 @@ void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) { } void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) { + if(!p->uv_p) return; if (us_poll_events(p) != events) { p->poll_type = us_internal_poll_type(p) | @@ -109,6 +116,7 @@ void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) { } void us_poll_stop(struct us_poll_t *p, struct us_loop_t *loop) { + if(!p->uv_p) return; uv_poll_stop(p->uv_p); /* We normally only want to close the poll here, not free it. But if we stop @@ -217,10 +225,20 @@ struct us_poll_t *us_create_poll(struct us_loop_t *loop, int fallthrough, /* If we update our block position we have to update the uv_poll data to point * to us */ struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, - unsigned int ext_size) { + unsigned int old_ext_size, unsigned int ext_size) { + + // cannot resize if we dont own uv_poll_t + if(!p->uv_p) return p; + + unsigned int old_size = sizeof(struct us_poll_t) + old_ext_size; + unsigned int new_size = sizeof(struct us_poll_t) + ext_size; + if(new_size <= old_size) return p; + + struct us_poll_t *new_p = calloc(1, new_size); + memcpy(new_p, p, old_size); - struct us_poll_t *new_p = realloc(p, sizeof(struct us_poll_t) + ext_size); new_p->uv_p->data = new_p; + p->uv_p = NULL; return new_p; } diff --git a/packages/bun-usockets/src/internal/internal.h b/packages/bun-usockets/src/internal/internal.h index 7ee718e723..5799f9e9a9 100644 --- a/packages/bun-usockets/src/internal/internal.h +++ b/packages/bun-usockets/src/internal/internal.h @@ -170,6 +170,12 @@ struct us_socket_flags { unsigned char low_prio_state: 2; /* If true, the socket should be read using readmsg to support receiving file descriptors */ bool is_ipc: 1; + /* If true, the socket has been closed */ + bool is_closed: 1; + /* If true, the socket was reallocated during adoption */ + bool adopted: 1; + /* If true, the socket is a TLS socket */ + bool is_tls: 1; } __attribute__((packed)); @@ -435,11 +441,11 @@ void us_internal_ssl_socket_shutdown(us_internal_ssl_socket_r s); struct us_internal_ssl_socket_t *us_internal_ssl_socket_context_adopt_socket( us_internal_ssl_socket_context_r context, - us_internal_ssl_socket_r s, int ext_size); + us_internal_ssl_socket_r s, int old_ext_size, int ext_size); struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls( us_socket_r s, struct us_bun_socket_context_options_t options, - struct us_socket_events_t events, int socket_ext_size); + struct us_socket_events_t events, int old_socket_ext_size, int socket_ext_size); struct us_internal_ssl_socket_context_t * us_internal_create_child_ssl_socket_context( us_internal_ssl_socket_context_r context, int context_ext_size); diff --git a/packages/bun-usockets/src/libusockets.h b/packages/bun-usockets/src/libusockets.h index 0e746a0388..3d59b7502c 100644 --- a/packages/bun-usockets/src/libusockets.h +++ b/packages/bun-usockets/src/libusockets.h @@ -349,7 +349,7 @@ struct us_loop_t *us_socket_context_loop(int ssl, us_socket_context_r context) n /* Invalidates passed socket, returning a new resized socket which belongs to a different socket context. * Used mainly for "socket upgrades" such as when transitioning from HTTP to WebSocket. */ -struct us_socket_t *us_socket_context_adopt_socket(int ssl, us_socket_context_r context, us_socket_r s, int ext_size); +struct us_socket_t *us_socket_context_adopt_socket(int ssl, us_socket_context_r context, us_socket_r s, int old_ext_size, int ext_size); struct us_socket_t *us_socket_upgrade_to_tls(us_socket_r s, us_socket_context_r new_context, const char *sni); @@ -411,7 +411,7 @@ void *us_poll_ext(us_poll_r p) nonnull_fn_decl; LIBUS_SOCKET_DESCRIPTOR us_poll_fd(us_poll_r p) nonnull_fn_decl; /* Resize an active poll */ -struct us_poll_t *us_poll_resize(us_poll_r p, us_loop_r loop, unsigned int ext_size) nonnull_fn_decl; +struct us_poll_t *us_poll_resize(us_poll_r p, us_loop_r loop, unsigned int old_ext_size, unsigned int ext_size) nonnull_fn_decl; /* Public interfaces for sockets */ @@ -470,7 +470,7 @@ void us_socket_local_address(int ssl, us_socket_r s, char *nonnull_arg buf, int /* Bun extras */ struct us_socket_t *us_socket_pair(struct us_socket_context_t *ctx, int socket_ext_size, LIBUS_SOCKET_DESCRIPTOR* fds); struct us_socket_t *us_socket_from_fd(struct us_socket_context_t *ctx, int socket_ext_size, LIBUS_SOCKET_DESCRIPTOR fd, int ipc); -struct us_socket_t *us_socket_wrap_with_tls(int ssl, us_socket_r s, struct us_bun_socket_context_options_t options, struct us_socket_events_t events, int socket_ext_size); +struct us_socket_t *us_socket_wrap_with_tls(int ssl, us_socket_r s, struct us_bun_socket_context_options_t options, struct us_socket_events_t events, int old_socket_ext_size, int socket_ext_size); int us_socket_raw_write(int ssl, us_socket_r s, const char *data, int length); struct us_socket_t* us_socket_open(int ssl, struct us_socket_t * s, int is_client, char* ip, int ip_length); int us_raw_root_certs(struct us_cert_string_t**out); diff --git a/packages/bun-usockets/src/loop.c b/packages/bun-usockets/src/loop.c index 1f4c232474..ec92d1a30e 100644 --- a/packages/bun-usockets/src/loop.c +++ b/packages/bun-usockets/src/loop.c @@ -193,9 +193,17 @@ void us_internal_handle_low_priority_sockets(struct us_loop_t *loop) { loop_data->low_prio_head = s->next; if (s->next) s->next->prev = 0; s->next = 0; + int ssl = s->flags.is_tls; + + if(us_socket_is_closed(ssl, s)) { + s->flags.low_prio_state = 2; + us_socket_context_unref(ssl, s->context); + continue; + } - us_internal_socket_context_link_socket(0, s->context, s); - us_poll_change(&s->p, us_socket_context(0, s)->loop, us_poll_events(&s->p) | LIBUS_SOCKET_READABLE); + us_internal_socket_context_link_socket(ssl, s->context, s); + us_socket_context_unref(ssl, s->context); + us_poll_change(&s->p, us_socket_context(ssl, s)->loop, us_poll_events(&s->p) | LIBUS_SOCKET_READABLE); s->flags.low_prio_state = 2; } @@ -243,6 +251,7 @@ void us_internal_free_closed_sockets(struct us_loop_t *loop) { /* Free all closed sockets (maybe it is better to reverse order?) */ for (struct us_socket_t *s = loop->data.closed_head; s; ) { struct us_socket_t *next = s->next; + s->prev = s->next = 0; us_poll_free((struct us_poll_t *) s, loop); s = next; } @@ -347,6 +356,9 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in s->flags.allow_half_open = listen_socket->s.flags.allow_half_open; s->flags.is_paused = 0; s->flags.is_ipc = 0; + s->flags.is_closed = 0; + s->flags.adopted = 0; + s->flags.is_tls = listen_socket->s.flags.is_tls; /* We always use nodelay */ bsd_socket_nodelay(client_fd, 1); @@ -354,7 +366,10 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in us_internal_socket_context_link_socket(0, listen_socket->s.context, s); listen_socket->s.context->on_open(s, 0, bsd_addr_get_ip(&addr), bsd_addr_get_ip_length(&addr)); - + /* After socket adoption, track the new socket; the old one becomes invalid */ + if(s && s->flags.adopted && s->prev) { + s = s->prev; + } /* Exit accept loop if listen socket was closed in on_open handler */ if (us_socket_is_closed(0, &listen_socket->s)) { break; @@ -369,6 +384,10 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in case POLL_TYPE_SOCKET: { /* We should only use s, no p after this point */ struct us_socket_t *s = (struct us_socket_t *) p; + /* After socket adoption, track the new socket; the old one becomes invalid */ + if(s && s->flags.adopted && s->prev) { + s = s->prev; + } /* The context can change after calling a callback but the loop is always the same */ struct us_loop_t* loop = s->context->loop; if (events & LIBUS_SOCKET_WRITABLE && !error) { @@ -381,6 +400,10 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in #endif s = s->context->on_writable(s); + /* After socket adoption, track the new socket; the old one becomes invalid */ + if(s && s->flags.adopted && s->prev) { + s = s->prev; + } if (!s || us_socket_is_closed(0, s)) { return; @@ -477,6 +500,10 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int eof, in if (length > 0) { s = s->context->on_data(s, loop->data.recv_buf + LIBUS_RECV_BUFFER_PADDING, length); + /* After socket adoption, track the new socket; the old one becomes invalid */ + if(s && s->flags.adopted && s->prev) { + s = s->prev; + } // loop->num_ready_polls isn't accessible on Windows. #ifndef WIN32 // rare case: we're reading a lot of data, there's more to be read, and either: diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index a4b02a7f42..b7577b8090 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -125,7 +125,7 @@ int us_socket_is_closed(int ssl, struct us_socket_t *s) { if(ssl) { return us_internal_ssl_socket_is_closed((struct us_internal_ssl_socket_t *) s); } - return s->prev == (struct us_socket_t *) s->context; + return s->flags.is_closed; } int us_connecting_socket_is_closed(int ssl, struct us_connecting_socket_t *c) { @@ -159,8 +159,8 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) { s->next = s->context->loop->data.closed_head; s->context->loop->data.closed_head = s; - /* Any socket with prev = context is marked as closed */ - s->prev = (struct us_socket_t *) s->context; + /* Mark the socket as closed */ + s->flags.is_closed = 1; } if(!c->error) { // if we have no error, we have to set that we were aborted aka we called close @@ -218,11 +218,10 @@ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, vo bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); + /* Mark the socket as closed */ + s->flags.is_closed = 1; - /* Any socket with prev = context is marked as closed */ - s->prev = (struct us_socket_t *) s->context; - - /* mark it as closed and call the callback */ + /* call the callback */ struct us_socket_t *res = s; if (!(us_internal_poll_type(&s->p) & POLL_TYPE_SEMI_SOCKET)) { res = s->context->on_close(s, code, reason); @@ -268,8 +267,8 @@ struct us_socket_t *us_socket_detach(int ssl, struct us_socket_t *s) { s->next = s->context->loop->data.closed_head; s->context->loop->data.closed_head = s; - /* Any socket with prev = context is marked as closed */ - s->prev = (struct us_socket_t *) s->context; + /* Mark the socket as closed */ + s->flags.is_closed = 1; return s; } @@ -321,8 +320,10 @@ struct us_socket_t *us_socket_from_fd(struct us_socket_context_t *ctx, int socke s->flags.low_prio_state = 0; s->flags.allow_half_open = 0; s->flags.is_paused = 0; - s->flags.is_ipc = 0; s->flags.is_ipc = ipc; + s->flags.is_closed = 0; + s->flags.adopted = 0; + s->flags.is_tls = 0; s->connect_state = NULL; /* We always use nodelay */ @@ -476,13 +477,13 @@ int us_connecting_socket_get_error(int ssl, struct us_connecting_socket_t *c) { Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context context ext will not be copied to the new context, new context will contain us_wrapped_socket_context_t on ext */ -struct us_socket_t *us_socket_wrap_with_tls(int ssl, struct us_socket_t *s, struct us_bun_socket_context_options_t options, struct us_socket_events_t events, int socket_ext_size) { +struct us_socket_t *us_socket_wrap_with_tls(int ssl, struct us_socket_t *s, struct us_bun_socket_context_options_t options, struct us_socket_events_t events, int old_socket_ext_size, int socket_ext_size) { // only accepts non-TLS sockets if (ssl) { return NULL; } - return(struct us_socket_t *) us_internal_ssl_socket_wrap_with_tls(s, options, events, socket_ext_size); + return(struct us_socket_t *) us_internal_ssl_socket_wrap_with_tls(s, options, events, old_socket_ext_size, socket_ext_size); } // if a TLS socket calls this, it will start SSL call open event and TLS handshake if required diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h index 6d40523508..d1ecb53c2c 100644 --- a/packages/bun-uws/src/HttpResponse.h +++ b/packages/bun-uws/src/HttpResponse.h @@ -331,7 +331,7 @@ public: /* Adopting a socket invalidates it, do not rely on it directly to carry any data */ - us_socket_t *usSocket = us_socket_context_adopt_socket(SSL, (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(WebSocketData) + sizeof(UserData)); + us_socket_t *usSocket = us_socket_context_adopt_socket(SSL, (us_socket_context_t *) webSocketContext, (us_socket_t *) this, sizeof(HttpResponseData), sizeof(WebSocketData) + sizeof(UserData)); WebSocket *webSocket = (WebSocket *) usSocket; /* For whatever reason we were corked, update cork to the new socket */ diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index be324693b9..5422a69b07 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -528,9 +528,10 @@ pub fn NewSocket(comptime ssl: bool) type { }; } - pub fn onHandshake(this: *This, _: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + pub fn onHandshake(this: *This, s: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { jsc.markBinding(@src()); this.flags.handshake_complete = true; + this.socket = s; if (this.socket.isDetached()) return; const handlers = this.getHandlers(); log("onHandshake {s} ({d})", .{ if (handlers.is_server) "S" else "C", success }); @@ -642,8 +643,9 @@ pub fn NewSocket(comptime ssl: bool) type { }; } - pub fn onData(this: *This, _: Socket, data: []const u8) void { + pub fn onData(this: *This, s: Socket, data: []const u8) void { jsc.markBinding(@src()); + this.socket = s; if (this.socket.isDetached()) return; const handlers = this.getHandlers(); log("onData {s} ({d})", .{ if (handlers.is_server) "S" else "C", data.len }); @@ -1433,7 +1435,6 @@ pub fn NewSocket(comptime ssl: bool) type { } const options = socket_config.asUSockets(); - const ext_size = @sizeOf(WrappedSocket); const handlers_ptr = bun.handleOom(handlers.vm.allocator.create(Handlers)); handlers_ptr.* = handlers; @@ -1463,7 +1464,7 @@ pub fn NewSocket(comptime ssl: bool) type { // reconfigure context to use the new wrapper handlers Socket.unsafeConfigure(this.socket.context().?, true, true, WrappedSocket, TCPHandler); const TLSHandler = NewWrappedHandler(true); - const new_socket = this.socket.wrapTLS(options, ext_size, true, WrappedSocket, TLSHandler) orelse { + const new_socket = this.socket.wrapTLS(options, @sizeOf(*anyopaque), @sizeOf(WrappedSocket), true, WrappedSocket, TLSHandler) orelse { const err = BoringSSL.ERR_get_error(); defer if (err != 0) BoringSSL.ERR_clear_error(); tls.wrapped = .none; diff --git a/src/deps/libuwsockets.cpp b/src/deps/libuwsockets.cpp index 991a66af0d..9fb39434e5 100644 --- a/src/deps/libuwsockets.cpp +++ b/src/deps/libuwsockets.cpp @@ -1700,6 +1700,7 @@ size_t uws_req_get_header(uws_req_t *res, const char *lower_case_header, void us_socket_mark_needs_more_not_ssl(uws_res_r res) { us_socket_r s = (us_socket_t *)res; + if(us_socket_is_closed(s->flags.is_tls, s)) return; s->context->loop->data.last_write_failed = 1; us_poll_change(&s->p, s->context->loop, LIBUS_SOCKET_READABLE | LIBUS_SOCKET_WRITABLE); @@ -1864,6 +1865,7 @@ __attribute__((callback (corker, ctx))) } void us_socket_sendfile_needs_more(us_socket_r s) { + if(us_socket_is_closed(s->flags.is_tls, s)) return; s->context->loop->data.last_write_failed = 1; us_poll_change(&s->p, s->context->loop, LIBUS_SOCKET_READABLE | LIBUS_SOCKET_WRITABLE); } diff --git a/src/deps/uws/SocketContext.zig b/src/deps/uws/SocketContext.zig index caf8b40ef7..3338280885 100644 --- a/src/deps/uws/SocketContext.zig +++ b/src/deps/uws/SocketContext.zig @@ -192,8 +192,8 @@ pub const SocketContext = opaque { c.us_socket_context_remove_server_name(@intFromBool(ssl), this, hostname_pattern); } - pub fn adoptSocket(this: *SocketContext, ssl: bool, s: *us_socket_t, ext_size: i32) ?*us_socket_t { - return c.us_socket_context_adopt_socket(@intFromBool(ssl), this, s, ext_size); + pub fn adoptSocket(this: *SocketContext, ssl: bool, s: *us_socket_t, old_ext_size: i32, ext_size: i32) ?*us_socket_t { + return c.us_socket_context_adopt_socket(@intFromBool(ssl), this, s, old_ext_size, ext_size); } pub fn connect(this: *SocketContext, ssl: bool, host: [*:0]const u8, port: i32, options: i32, socket_ext_size: i32, has_dns_resolved: *i32) ?*anyopaque { @@ -252,7 +252,7 @@ pub const c = struct { pub extern fn us_create_bun_nossl_socket_context(loop: ?*Loop, ext_size: i32) ?*SocketContext; pub extern fn us_create_bun_ssl_socket_context(loop: ?*Loop, ext_size: i32, options: SocketContext.BunSocketContextOptions, err: *create_bun_socket_error_t) ?*SocketContext; pub extern fn us_create_child_socket_context(ssl: i32, context: ?*SocketContext, context_ext_size: i32) ?*SocketContext; - pub extern fn us_socket_context_adopt_socket(ssl: i32, context: *SocketContext, s: *us_socket_t, ext_size: i32) ?*us_socket_t; + pub extern fn us_socket_context_adopt_socket(ssl: i32, context: *SocketContext, s: *us_socket_t, old_ext_size: i32, ext_size: i32) ?*us_socket_t; pub extern fn us_socket_context_close(ssl: i32, ctx: *anyopaque) void; pub extern fn us_socket_context_connect(ssl: i32, context: *SocketContext, host: [*:0]const u8, port: i32, options: i32, socket_ext_size: i32, has_dns_resolved: *i32) ?*anyopaque; pub extern fn us_socket_context_connect_unix(ssl: i32, context: *SocketContext, path: [*:0]const u8, pathlen: usize, options: i32, socket_ext_size: i32) ?*us_socket_t; diff --git a/src/deps/uws/socket.zig b/src/deps/uws/socket.zig index bab59f189c..2e8720bfec 100644 --- a/src/deps/uws/socket.zig +++ b/src/deps/uws/socket.zig @@ -125,6 +125,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { pub fn wrapTLS( this: ThisSocket, options: SocketContext.BunSocketContextOptions, + old_socket_ext_size: i32, socket_ext_size: i32, comptime deref: bool, comptime ContextType: type, @@ -280,7 +281,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { const this_socket = this.socket.get() orelse return null; - const socket = c.us_socket_wrap_with_tls(ssl_int, this_socket, options, events, socket_ext_size) orelse return null; + const socket = c.us_socket_wrap_with_tls(ssl_int, this_socket, options, events, old_socket_ext_size, socket_ext_size) orelse return null; return NewSocketHandler(true).from(socket); } @@ -1066,7 +1067,8 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { ) bool { // ext_size of -1 means we want to keep the current ext size // in particular, we don't want to allocate a new socket - const new_socket = socket_ctx.adoptSocket(comptime is_ssl, socket, -1) orelse return false; + // old_ext_size is irrelevant when ext_size is -1 (no resize occurs) + const new_socket = socket_ctx.adoptSocket(comptime is_ssl, socket, -1, -1) orelse return false; bun.assert(new_socket == socket); var adopted = ThisSocket.from(new_socket); if (adopted.ext(*anyopaque)) |holder| { @@ -1326,7 +1328,7 @@ const c = struct { on_connect_error_connecting_socket: ?*const fn (*ConnectingSocket, i32) callconv(.c) ?*ConnectingSocket = null, on_handshake: ?*const fn (*us_socket_t, i32, uws.us_bun_verify_error_t, ?*anyopaque) callconv(.c) void = null, }; - pub extern fn us_socket_wrap_with_tls(ssl: i32, s: *uws.us_socket_t, options: uws.SocketContext.BunSocketContextOptions, events: c.us_socket_events_t, socket_ext_size: i32) ?*uws.us_socket_t; + pub extern fn us_socket_wrap_with_tls(ssl: i32, s: *uws.us_socket_t, options: uws.SocketContext.BunSocketContextOptions, events: c.us_socket_events_t, old_socket_ext_size: i32, socket_ext_size: i32) ?*uws.us_socket_t; }; const debug = bun.Output.scoped(.uws, .visible); diff --git a/src/js/node/net.ts b/src/js/node/net.ts index 9f28b94367..bb339ab35b 100644 --- a/src/js/node/net.ts +++ b/src/js/node/net.ts @@ -467,7 +467,7 @@ const ServerHandlers: SocketHandler = { } } SocketHandlers.error(socket, error, true); - data.server.emit("clientError", error, data); + this.server?.emit("clientError", error, data); }, timeout(socket) { SocketHandlers.timeout(socket); diff --git a/test/js/bun/http/bun-websocket-cpu-fixture.js b/test/js/bun/http/bun-websocket-cpu-fixture.js index a2d03babfe..078de7851a 100644 --- a/test/js/bun/http/bun-websocket-cpu-fixture.js +++ b/test/js/bun/http/bun-websocket-cpu-fixture.js @@ -23,7 +23,9 @@ const server = Bun.serve({ }); const ws = new WebSocket(`wss://${server.hostname}:${server.port}`, { tls: { rejectUnauthorized: false } }); -await Bun.sleep(1000); +const { promise: openWS, resolve: onWSOpen } = Promise.withResolvers(); +ws.onopen = onWSOpen; +await openWS; for (let i = 0; i < 1000; i++) { ws.send("hello"); }