diff --git a/packages/bun-usockets/src/socket.c b/packages/bun-usockets/src/socket.c index eaa152005b..76713eb77b 100644 --- a/packages/bun-usockets/src/socket.c +++ b/packages/bun-usockets/src/socket.c @@ -22,7 +22,6 @@ #include #include #include - #ifndef WIN32 #include #endif @@ -175,6 +174,9 @@ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, vo return (struct us_socket_t *)us_internal_ssl_socket_close((struct us_internal_ssl_socket_t *) s, code, reason); } if (!us_socket_is_closed(0, s)) { + /* make sure the context is alive until the callback ends */ + us_socket_context_ref(ssl, s->context); + if (s->low_prio_state == 1) { /* Unlink this socket from the low-priority queue */ if (!s->prev) s->context->loop->data.low_prio_head = s->next; @@ -186,7 +188,6 @@ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, vo s->next = 0; s->low_prio_state = 0; us_socket_context_unref(ssl, s->context); - } else { us_internal_socket_context_unlink_socket(ssl, s->context, s); } @@ -207,16 +208,25 @@ struct us_socket_t *us_socket_close(int ssl, struct us_socket_t *s, int code, vo bsd_close_socket(us_poll_fd((struct us_poll_t *) s)); - /* Link this socket to the close-list and let it be deleted after this iteration */ - s->next = s->context->loop->data.closed_head; - s->context->loop->data.closed_head = s; /* Any socket with prev = context is marked as closed */ s->prev = (struct us_socket_t *) s->context; + /* mark it as closed and call the callback */ + struct us_socket_t *res = s; if (!(us_internal_poll_type(&s->p) & POLL_TYPE_SEMI_SOCKET)) { - return s->context->on_close(s, code, reason); + res = s->context->on_close(s, code, reason); } + + /* Link this socket to the close-list and let it be deleted after this iteration */ + s->next = s->context->loop->data.closed_head; + s->context->loop->data.closed_head = s; + + /* unref the context after the callback ends */ + us_socket_context_unref(ssl, s->context); + + /* preserve the return value from on_close if its called */ + return res; } return s;