mirror of
https://github.com/oven-sh/bun
synced 2026-02-15 05:12:29 +00:00
fix(sockets) add socket wrapper and refactor context ownership handling in socket.zig (#13176)
This commit is contained in:
@@ -128,7 +128,7 @@ void us_internal_socket_context_link_listen_socket(struct us_socket_context_t *c
|
||||
context->head_listen_sockets->s.prev = &ls->s;
|
||||
}
|
||||
context->head_listen_sockets = ls;
|
||||
context->ref_count++;
|
||||
us_socket_context_ref(0, context);
|
||||
}
|
||||
|
||||
/* We always add in the top, so we don't modify any s.next */
|
||||
@@ -140,7 +140,7 @@ void us_internal_socket_context_link_socket(struct us_socket_context_t *context,
|
||||
context->head_sockets->prev = s;
|
||||
}
|
||||
context->head_sockets = s;
|
||||
context->ref_count++;
|
||||
us_socket_context_ref(0, context);
|
||||
}
|
||||
|
||||
struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context) {
|
||||
@@ -277,9 +277,8 @@ struct us_bun_verify_error_t us_socket_verify_error(int ssl, struct us_socket_t
|
||||
return (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL };
|
||||
}
|
||||
|
||||
|
||||
|
||||
void us_internal_socket_context_free(int ssl, struct us_socket_context_t *context) {
|
||||
|
||||
#ifndef LIBUS_NO_SSL
|
||||
if (ssl) {
|
||||
/* This function will call us again with SSL=false */
|
||||
@@ -300,9 +299,10 @@ void us_internal_socket_context_free(int ssl, struct us_socket_context_t *contex
|
||||
void us_socket_context_ref(int ssl, struct us_socket_context_t *context) {
|
||||
context->ref_count++;
|
||||
}
|
||||
|
||||
void us_socket_context_unref(int ssl, struct us_socket_context_t *context) {
|
||||
if (--context->ref_count == 0) {
|
||||
uint32_t ref_count = context->ref_count;
|
||||
context->ref_count--;
|
||||
if (ref_count == 1) {
|
||||
us_internal_socket_context_free(ssl, context);
|
||||
}
|
||||
}
|
||||
@@ -481,6 +481,7 @@ void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, co
|
||||
struct us_connecting_socket_t *c = us_calloc(1, sizeof(struct us_connecting_socket_t) + socket_ext_size);
|
||||
c->socket_ext_size = socket_ext_size;
|
||||
c->context = context;
|
||||
us_socket_context_ref(ssl, context);
|
||||
c->options = options;
|
||||
c->ssl = ssl > 0;
|
||||
c->timeout = 255;
|
||||
@@ -548,7 +549,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) {
|
||||
c->pending_resolve_callback = 0;
|
||||
// if the socket was closed while we were resolving the address, free it
|
||||
if (c->closed) {
|
||||
us_connecting_socket_free(c);
|
||||
us_connecting_socket_free(c->ssl, c);
|
||||
return;
|
||||
}
|
||||
struct addrinfo_result *result = Bun__addrinfo_getRequestResult(c->addrinfo_req);
|
||||
@@ -556,7 +557,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) {
|
||||
c->error = result->error;
|
||||
c->context->on_connect_error(c, result->error);
|
||||
Bun__addrinfo_freeRequest(c->addrinfo_req, 0);
|
||||
us_connecting_socket_close(0, c);
|
||||
us_connecting_socket_close(c->ssl, c);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -567,7 +568,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) {
|
||||
c->error = ECONNREFUSED;
|
||||
c->context->on_connect_error(c, ECONNREFUSED);
|
||||
Bun__addrinfo_freeRequest(c->addrinfo_req, 1);
|
||||
us_connecting_socket_close(0, c);
|
||||
us_connecting_socket_close(c->ssl, c);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -638,7 +639,7 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) {
|
||||
c->error = ECONNREFUSED;
|
||||
c->context->on_connect_error(c, error);
|
||||
Bun__addrinfo_freeRequest(c->addrinfo_req, ECONNREFUSED);
|
||||
us_connecting_socket_close(0, c);
|
||||
us_connecting_socket_close(c->ssl, c);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -667,7 +668,7 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) {
|
||||
}
|
||||
// now that the socket is open, we can release the associated us_connecting_socket_t if it exists
|
||||
Bun__addrinfo_freeRequest(c->addrinfo_req, 0);
|
||||
us_connecting_socket_free(c);
|
||||
us_connecting_socket_free(c->ssl, c);
|
||||
s->connect_state = NULL;
|
||||
}
|
||||
|
||||
|
||||
@@ -1820,6 +1820,7 @@ ssl_wrapped_context_on_close(struct us_internal_ssl_socket_t *s, int code,
|
||||
wrapped_context->old_events.on_close((struct us_socket_t *)s, code, reason);
|
||||
}
|
||||
|
||||
us_socket_context_unref(0, wrapped_context->tcp_context);
|
||||
return s;
|
||||
}
|
||||
|
||||
@@ -1976,6 +1977,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls(
|
||||
}
|
||||
|
||||
struct us_socket_context_t *old_context = us_socket_context(0, s);
|
||||
us_socket_context_ref(0,old_context);
|
||||
|
||||
struct us_socket_context_t *context = us_create_bun_socket_context(
|
||||
1, old_context->loop, sizeof(struct us_wrapped_socket_context_t),
|
||||
@@ -1998,6 +2000,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls(
|
||||
};
|
||||
wrapped_context->old_events = old_events;
|
||||
wrapped_context->events = events;
|
||||
wrapped_context->tcp_context = old_context;
|
||||
|
||||
// no need to wrap open because socket is already open (only new context will
|
||||
// be called so we can configure hostname and ssl stuff normally here before
|
||||
|
||||
@@ -189,6 +189,7 @@ struct us_connecting_socket_t {
|
||||
};
|
||||
|
||||
struct us_wrapped_socket_context_t {
|
||||
struct us_socket_context_t* tcp_context;
|
||||
struct us_socket_events_t events;
|
||||
struct us_socket_events_t old_events;
|
||||
};
|
||||
|
||||
@@ -314,7 +314,7 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, us_socket_context_r
|
||||
* Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */
|
||||
int us_socket_is_established(int ssl, us_socket_r s) nonnull_fn_decl;
|
||||
|
||||
void us_connecting_socket_free(struct us_connecting_socket_t *c) nonnull_fn_decl;
|
||||
void us_connecting_socket_free(int ssl, struct us_connecting_socket_t *c) nonnull_fn_decl;
|
||||
|
||||
/* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times.
|
||||
* Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since
|
||||
|
||||
@@ -129,19 +129,20 @@ int us_socket_is_established(int ssl, struct us_socket_t *s) {
|
||||
return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET;
|
||||
}
|
||||
|
||||
void us_connecting_socket_free(struct us_connecting_socket_t *c) {
|
||||
void us_connecting_socket_free(int ssl, struct us_connecting_socket_t *c) {
|
||||
// we can't just free c immediately, as it may be enqueued in the dns_ready_head list
|
||||
// instead, we move it to a close list and free it after the iteration
|
||||
c->next = c->context->loop->data.closed_connecting_head;
|
||||
c->context->loop->data.closed_connecting_head = c;
|
||||
us_socket_context_unref(ssl, c->context);
|
||||
}
|
||||
|
||||
void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) {
|
||||
if (c->closed) return;
|
||||
c->closed = 1;
|
||||
|
||||
for (struct us_socket_t *s = c->connecting_head; s; s = s->connect_next) {
|
||||
us_internal_socket_context_unlink_socket(ssl, s->context, s);
|
||||
|
||||
us_poll_stop((struct us_poll_t *) s, s->context->loop);
|
||||
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
|
||||
|
||||
@@ -156,7 +157,7 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) {
|
||||
// we can only schedule the socket to be freed if there is no pending callback
|
||||
// otherwise, the callback will see that the socket is closed and will free it
|
||||
if (!c->pending_resolve_callback) {
|
||||
us_connecting_socket_free(c);
|
||||
us_connecting_socket_free(ssl, c);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -361,4 +361,4 @@ public:
|
||||
|
||||
}
|
||||
|
||||
#endif // UWS_ASYNCSOCKET_H
|
||||
#endif // UWS_ASYNCSOCKET_H
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2131,7 +2131,9 @@ pub const Subprocess = struct {
|
||||
|
||||
if (subprocess.ipc_data) |*ipc_data| {
|
||||
if (Environment.isPosix) {
|
||||
posix_ipc_info.ext(*Subprocess).* = subprocess;
|
||||
if (posix_ipc_info.ext(*Subprocess)) |ctx| {
|
||||
ctx.* = subprocess;
|
||||
}
|
||||
} else {
|
||||
if (ipc_data.configureServer(
|
||||
Subprocess,
|
||||
|
||||
@@ -83,9 +83,16 @@ pub const InternalLoopData = extern struct {
|
||||
pub const InternalSocket = union(enum) {
|
||||
done: *Socket,
|
||||
connecting: *ConnectingSocket,
|
||||
|
||||
detached: void,
|
||||
pub fn isDetached(this: InternalSocket) bool {
|
||||
return this == .detached;
|
||||
}
|
||||
pub fn detach(this: *InternalSocket) void {
|
||||
this.* = .detached;
|
||||
}
|
||||
pub fn close(this: InternalSocket, comptime is_ssl: bool, code: CloseCode) void {
|
||||
switch (this) {
|
||||
.detached => {},
|
||||
.done => |socket| {
|
||||
debug("us_socket_close({d})", .{@intFromPtr(socket)});
|
||||
_ = us_socket_close(
|
||||
@@ -109,6 +116,7 @@ pub const InternalSocket = union(enum) {
|
||||
return switch (this) {
|
||||
.done => |socket| us_socket_is_closed(@intFromBool(is_ssl), socket) > 0,
|
||||
.connecting => |socket| us_connecting_socket_is_closed(@intFromBool(is_ssl), socket) > 0,
|
||||
.detached => true,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -116,6 +124,7 @@ pub const InternalSocket = union(enum) {
|
||||
return switch (this) {
|
||||
.done => this.done,
|
||||
.connecting => null,
|
||||
.detached => null,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -123,12 +132,16 @@ pub const InternalSocket = union(enum) {
|
||||
return switch (this) {
|
||||
.done => switch (other) {
|
||||
.done => this.done == other.done,
|
||||
.connecting => false,
|
||||
.connecting, .detached => false,
|
||||
},
|
||||
.connecting => switch (other) {
|
||||
.done => false,
|
||||
.done, .detached => false,
|
||||
.connecting => this.connecting == other.connecting,
|
||||
},
|
||||
.detached => switch (other) {
|
||||
.detached => true,
|
||||
.done, .connecting => false,
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
@@ -138,7 +151,13 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const ssl_int: i32 = @intFromBool(is_ssl);
|
||||
socket: InternalSocket,
|
||||
const ThisSocket = @This();
|
||||
|
||||
pub const detached: NewSocketHandler(is_ssl) = NewSocketHandler(is_ssl){ .socket = .{ .detached = {} } };
|
||||
pub fn detach(this: *ThisSocket) void {
|
||||
this.socket.detach();
|
||||
}
|
||||
pub fn isDetached(this: ThisSocket) bool {
|
||||
return this.socket.isDetached();
|
||||
}
|
||||
pub fn verifyError(this: ThisSocket) us_bun_verify_error_t {
|
||||
const socket = this.socket.get() orelse return std.mem.zeroes(us_bun_verify_error_t);
|
||||
const ssl_error: us_bun_verify_error_t = uws.us_socket_verify_error(comptime ssl_int, socket);
|
||||
@@ -154,6 +173,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
switch (this.socket) {
|
||||
.done => |socket| us_socket_timeout(comptime ssl_int, socket, seconds),
|
||||
.connecting => |socket| us_connecting_socket_timeout(comptime ssl_int, socket, seconds),
|
||||
.detached => {},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,6 +197,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
us_connecting_socket_long_timeout(comptime ssl_int, socket, 0);
|
||||
}
|
||||
},
|
||||
.detached => {},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,19 +211,23 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
us_connecting_socket_timeout(comptime ssl_int, socket, 0);
|
||||
us_connecting_socket_long_timeout(comptime ssl_int, socket, minutes);
|
||||
},
|
||||
.detached => {},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn startTLS(this: ThisSocket, is_client: bool) void {
|
||||
const socket = this.socket.get() orelse @panic("socket is not open");
|
||||
const socket = this.socket.get() orelse return;
|
||||
_ = us_socket_open(comptime ssl_int, socket, @intFromBool(is_client), null, 0);
|
||||
}
|
||||
|
||||
pub fn ssl(this: ThisSocket) *BoringSSL.SSL {
|
||||
pub fn ssl(this: ThisSocket) ?*BoringSSL.SSL {
|
||||
if (comptime is_ssl) {
|
||||
return @as(*BoringSSL.SSL, @ptrCast(this.getNativeHandle()));
|
||||
if(this.getNativeHandle()) |handle| {
|
||||
return @as(*BoringSSL.SSL, @ptrCast(handle));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@panic("socket is not a TLS socket");
|
||||
return null;
|
||||
}
|
||||
|
||||
// Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context
|
||||
@@ -229,7 +254,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
}
|
||||
|
||||
if (comptime deref_) {
|
||||
return (TLSSocket.from(socket)).ext(ContextType).*;
|
||||
return (TLSSocket.from(socket)).ext(ContextType).?.*;
|
||||
}
|
||||
|
||||
return (TLSSocket.from(socket)).ext(ContextType);
|
||||
@@ -289,7 +314,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
}
|
||||
pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket {
|
||||
Fields.onConnectError(
|
||||
TLSSocket.from(socket).ext(ContextType).*,
|
||||
TLSSocket.from(socket).ext(ContextType).?.*,
|
||||
TLSSocket.from(socket),
|
||||
code,
|
||||
);
|
||||
@@ -328,7 +353,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
.on_long_timeout = SocketHandler.on_long_timeout,
|
||||
};
|
||||
|
||||
const this_socket = this.socket.get() orelse @panic("socket is not open");
|
||||
const this_socket = this.socket.get() orelse return null;
|
||||
|
||||
const socket = us_socket_wrap_with_tls(ssl_int, this_socket, options, events, socket_ext_size) orelse return null;
|
||||
return NewSocketHandler(true).from(socket);
|
||||
@@ -338,6 +363,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
return @ptrCast(switch (this.socket) {
|
||||
.done => |socket| us_socket_get_native_handle(comptime ssl_int, socket),
|
||||
.connecting => |socket| us_connecting_socket_get_native_handle(comptime ssl_int, socket),
|
||||
.detached => null,
|
||||
} orelse return null);
|
||||
}
|
||||
|
||||
@@ -361,7 +387,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
us_socket_sendfile_needs_more(socket);
|
||||
}
|
||||
|
||||
pub fn ext(this: ThisSocket, comptime ContextType: type) *ContextType {
|
||||
pub fn ext(this: ThisSocket, comptime ContextType: type) ?*ContextType {
|
||||
const alignment = if (ContextType == *anyopaque)
|
||||
@sizeOf(usize)
|
||||
else
|
||||
@@ -370,6 +396,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const ptr = switch (this.socket) {
|
||||
.done => |sock| us_socket_ext(comptime ssl_int, sock),
|
||||
.connecting => |sock| us_connecting_socket_ext(comptime ssl_int, sock),
|
||||
.detached => return null,
|
||||
};
|
||||
|
||||
return @as(*align(alignment) ContextType, @ptrCast(@alignCast(ptr)));
|
||||
@@ -380,6 +407,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
switch (this.socket) {
|
||||
.done => |socket| return us_socket_context(comptime ssl_int, socket),
|
||||
.connecting => |socket| return us_connecting_socket_context(comptime ssl_int, socket),
|
||||
.detached => return null,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,6 +462,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
socket,
|
||||
);
|
||||
},
|
||||
.detached => {},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,6 +482,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
socket,
|
||||
);
|
||||
},
|
||||
.detached => {},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -470,6 +500,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
socket,
|
||||
) > 0;
|
||||
},
|
||||
.detached => return true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,6 +526,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
socket,
|
||||
);
|
||||
},
|
||||
.detached => return 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -642,8 +674,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
) ?ThisSocket {
|
||||
const socket_ = ThisSocket{ .socket = .{ .done = us_socket_from_fd(ctx, @sizeOf(*anyopaque), bun.socketcast(handle)) orelse return null } };
|
||||
|
||||
const holder = socket_.ext(*anyopaque);
|
||||
holder.* = this;
|
||||
if(socket_.ext(*anyopaque)) |holder| {
|
||||
holder.* = this;
|
||||
}
|
||||
|
||||
if (comptime socket_field_name) |field| {
|
||||
@field(this, field) = socket_;
|
||||
@@ -679,8 +712,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
return error.FailedToOpenSocket;
|
||||
|
||||
const socket_ = ThisSocket{ .socket = .{ .done = socket } };
|
||||
const holder = socket_.ext(*anyopaque);
|
||||
holder.* = ctx;
|
||||
if(socket_.ext(*anyopaque)) |holder| {
|
||||
holder.* = ctx;
|
||||
}
|
||||
return socket_;
|
||||
}
|
||||
|
||||
@@ -721,9 +755,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
ThisSocket{
|
||||
.socket = .{ .connecting = @ptrCast(socket_ptr) },
|
||||
};
|
||||
|
||||
const holder = socket.ext(*anyopaque);
|
||||
holder.* = ptr;
|
||||
if(socket.ext(*anyopaque)) |holder| {
|
||||
holder.* = ptr;
|
||||
}
|
||||
return socket;
|
||||
}
|
||||
|
||||
@@ -751,7 +785,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
}
|
||||
|
||||
if (comptime deref_) {
|
||||
return (SocketHandlerType.from(socket)).ext(ContextType).*;
|
||||
return (SocketHandlerType.from(socket)).ext(ContextType).?.*;
|
||||
}
|
||||
|
||||
return (SocketHandlerType.from(socket)).ext(ContextType);
|
||||
@@ -806,7 +840,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const val = if (comptime ContextType == anyopaque)
|
||||
us_connecting_socket_ext(comptime ssl_int, socket)
|
||||
else if (comptime deref_)
|
||||
SocketHandlerType.fromConnecting(socket).ext(ContextType).*
|
||||
SocketHandlerType.fromConnecting(socket).ext(ContextType).?.*
|
||||
else
|
||||
SocketHandlerType.fromConnecting(socket).ext(ContextType);
|
||||
Fields.onConnectError(
|
||||
@@ -820,7 +854,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const val = if (comptime ContextType == anyopaque)
|
||||
us_socket_ext(comptime ssl_int, socket)
|
||||
else if (comptime deref_)
|
||||
SocketHandlerType.from(socket).ext(ContextType).*
|
||||
SocketHandlerType.from(socket).ext(ContextType).?.*
|
||||
else
|
||||
SocketHandlerType.from(socket).ext(ContextType);
|
||||
Fields.onConnectError(
|
||||
@@ -883,7 +917,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
}
|
||||
|
||||
if (comptime deref_) {
|
||||
return (ThisSocket.from(socket)).ext(ContextType).*;
|
||||
return (ThisSocket.from(socket)).ext(ContextType).?.*;
|
||||
}
|
||||
|
||||
return (ThisSocket.from(socket)).ext(ContextType);
|
||||
@@ -945,7 +979,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const val = if (comptime ContextType == anyopaque)
|
||||
us_connecting_socket_ext(comptime ssl_int, socket)
|
||||
else if (comptime deref_)
|
||||
ThisSocket.fromConnecting(socket).ext(ContextType).*
|
||||
ThisSocket.fromConnecting(socket).ext(ContextType).?.*
|
||||
else
|
||||
ThisSocket.fromConnecting(socket).ext(ContextType);
|
||||
Fields.onConnectError(
|
||||
@@ -959,7 +993,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const val = if (comptime ContextType == anyopaque)
|
||||
us_socket_ext(comptime ssl_int, socket)
|
||||
else if (comptime deref_)
|
||||
ThisSocket.from(socket).ext(ContextType).*
|
||||
ThisSocket.from(socket).ext(ContextType).?.*
|
||||
else
|
||||
ThisSocket.from(socket).ext(ContextType);
|
||||
|
||||
@@ -1033,14 +1067,14 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
|
||||
const new_socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, -1) orelse return false;
|
||||
bun.assert(new_socket == socket);
|
||||
var adopted = ThisSocket.from(new_socket);
|
||||
const holder = adopted.ext(*anyopaque);
|
||||
holder.* = ctx;
|
||||
if(adopted.ext(*anyopaque)) |holder| {
|
||||
holder.* = ctx;
|
||||
}
|
||||
@field(ctx, socket_field_name) = adopted;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub const SocketTCP = NewSocketHandler(false);
|
||||
pub const SocketTLS = NewSocketHandler(true);
|
||||
|
||||
@@ -2312,7 +2346,6 @@ pub fn NewApp(comptime ssl: bool) type {
|
||||
}
|
||||
}
|
||||
pub fn onAborted(res: *Response, comptime UserDataType: type, comptime handler: fn (UserDataType, *Response) void, opcional_data: UserDataType) void {
|
||||
|
||||
const Wrapper = struct {
|
||||
pub fn handle(this: *uws_res, user_data: ?*anyopaque) callconv(.C) void {
|
||||
if (comptime UserDataType == void) {
|
||||
|
||||
25
src/http.zig
25
src/http.zig
@@ -342,7 +342,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
|
||||
};
|
||||
|
||||
pub fn markSocketAsDead(socket: HTTPSocket) void {
|
||||
socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr());
|
||||
if (socket.ext(**anyopaque)) |ctx| {
|
||||
ctx.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr());
|
||||
}
|
||||
}
|
||||
|
||||
fn terminateSocket(socket: HTTPSocket) void {
|
||||
@@ -461,7 +463,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
|
||||
|
||||
if (hostname.len <= MAX_KEEPALIVE_HOSTNAME and !socket.isClosedOrHasError() and socket.isEstablished()) {
|
||||
if (this.pending_sockets.get()) |pending| {
|
||||
socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr());
|
||||
if (socket.ext(**anyopaque)) |ctx| {
|
||||
ctx.* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr());
|
||||
}
|
||||
socket.flush();
|
||||
socket.timeout(0);
|
||||
socket.setTimeoutMinutes(5);
|
||||
@@ -517,9 +521,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
|
||||
// handshake completed but we may have ssl errors
|
||||
client.flags.did_have_handshaking_error = handshake_error.error_no != 0;
|
||||
if (handshake_success) {
|
||||
if(client.flags.reject_unauthorized) {
|
||||
if (client.flags.reject_unauthorized) {
|
||||
// only reject the connection if reject_unauthorized == true
|
||||
if(client.flags.did_have_handshaking_error) {
|
||||
if (client.flags.did_have_handshaking_error) {
|
||||
client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket);
|
||||
return;
|
||||
}
|
||||
@@ -537,7 +541,7 @@ fn NewHTTPContext(comptime ssl: bool) type {
|
||||
} else {
|
||||
// if we are here is because server rejected us, and the error_no is the cause of this
|
||||
// if we set reject_unauthorized == false this means the server requires custom CA aka NODE_EXTRA_CA_CERTS
|
||||
if(client.flags.did_have_handshaking_error) {
|
||||
if (client.flags.did_have_handshaking_error) {
|
||||
client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket);
|
||||
return;
|
||||
}
|
||||
@@ -545,7 +549,6 @@ fn NewHTTPContext(comptime ssl: bool) type {
|
||||
client.closeAndFail(error.ConnectionRefused, comptime ssl, socket);
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (socket.isClosed()) {
|
||||
@@ -745,7 +748,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
|
||||
|
||||
if (client.isKeepAlivePossible()) {
|
||||
if (this.existingSocket(client.flags.reject_unauthorized, hostname, port)) |sock| {
|
||||
sock.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr());
|
||||
if (sock.ext(**anyopaque)) |ctx| {
|
||||
ctx.* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr());
|
||||
}
|
||||
client.allow_retry = true;
|
||||
client.onOpen(comptime ssl, sock);
|
||||
if (comptime ssl) {
|
||||
@@ -2858,10 +2863,10 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s
|
||||
}
|
||||
|
||||
pub fn closeAndFail(this: *HTTPClient, err: anyerror, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void {
|
||||
if (!socket.isClosed()) {
|
||||
NewHTTPContext(is_ssl).terminateSocket(socket);
|
||||
}
|
||||
if (this.state.stage != .fail and this.state.stage != .done) {
|
||||
if (!socket.isClosed()) {
|
||||
NewHTTPContext(is_ssl).terminateSocket(socket);
|
||||
}
|
||||
log("closeAndFail: {s}", .{@errorName(err)});
|
||||
this.fail(err);
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ export async function expectMaxObjectTypeCount(
|
||||
await Bun.sleep(wait);
|
||||
gc();
|
||||
}
|
||||
expect(heapStats().objectTypeCounts[type]).toBeLessThanOrEqual(count);
|
||||
expect(heapStats().objectTypeCounts[type] || 0).toBeLessThanOrEqual(count);
|
||||
}
|
||||
|
||||
// we must ensure that finalizers are run
|
||||
|
||||
@@ -126,9 +126,9 @@ async function calculateMemoryLeak(fn: () => Promise<void>) {
|
||||
// If it was leaking the body, the memory usage would be at least 512 KB * 10_000 = 5 GB
|
||||
// If it ends up around 280 MB, it's probably not leaking the body.
|
||||
for (const test_info of [
|
||||
["#10265 should not leak memory when ignoring the body", callIgnore, false, 48],
|
||||
["should not leak memory when buffering the body", callBuffering, false, 48],
|
||||
["should not leak memory when streaming the body", callStreaming, false, 48],
|
||||
["#10265 should not leak memory when ignoring the body", callIgnore, false, 64],
|
||||
["should not leak memory when buffering the body", callBuffering, false, 64],
|
||||
["should not leak memory when streaming the body", callStreaming, false, 64],
|
||||
["should not leak memory when streaming the body incompletely", callIncompleteStreaming, false, 64],
|
||||
["should not leak memory when streaming the body and echoing it back", callStreamingEcho, false, 64],
|
||||
] as const) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { expect, it } from "bun:test";
|
||||
import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows } from "harness";
|
||||
import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows, tls } from "harness";
|
||||
import { connect, fileURLToPath, SocketHandler, spawn } from "bun";
|
||||
import type { Socket } from "bun";
|
||||
it("should coerce '0' to 0", async () => {
|
||||
@@ -354,7 +354,7 @@ it("it should not crash when returning a Error on client socket open", async ()
|
||||
});
|
||||
|
||||
it("it should only call open once", async () => {
|
||||
const server = Bun.listen({
|
||||
using server = Bun.listen({
|
||||
port: 0,
|
||||
hostname: "localhost",
|
||||
socket: {
|
||||
@@ -381,7 +381,6 @@ it("it should only call open once", async () => {
|
||||
expect().fail("connectError should not be called");
|
||||
},
|
||||
close(socket) {
|
||||
server.stop();
|
||||
resolve();
|
||||
},
|
||||
data(socket, data) {},
|
||||
@@ -397,7 +396,7 @@ it.skipIf(isWindows)("should not leak file descriptors when connecting", async (
|
||||
});
|
||||
|
||||
it("should not call open if the connection had an error", async () => {
|
||||
const server = Bun.listen({
|
||||
using server = Bun.listen({
|
||||
port: 0,
|
||||
hostname: "0.0.0.0",
|
||||
socket: {
|
||||
@@ -435,12 +434,11 @@ it("should not call open if the connection had an error", async () => {
|
||||
|
||||
await Bun.sleep(50);
|
||||
await promise;
|
||||
server.stop();
|
||||
expect(hadError).toBe(true);
|
||||
});
|
||||
|
||||
it("should connect directly when using an ip address", async () => {
|
||||
const server = Bun.listen({
|
||||
using server = Bun.listen({
|
||||
port: 0,
|
||||
hostname: "127.0.0.1",
|
||||
socket: {
|
||||
@@ -467,7 +465,6 @@ it("should connect directly when using an ip address", async () => {
|
||||
expect().fail("connectError should not be called");
|
||||
},
|
||||
close(socket) {
|
||||
server.stop();
|
||||
resolve();
|
||||
},
|
||||
data(socket, data) {},
|
||||
@@ -498,3 +495,77 @@ it("should not call drain before handshake", async () => {
|
||||
await promise;
|
||||
expect(socket.authorized).toBe(true);
|
||||
});
|
||||
it("should be able to upgrade to TLS", async () => {
|
||||
using server = Bun.serve({
|
||||
tls,
|
||||
async fetch(req) {
|
||||
return new Response("Hello World");
|
||||
},
|
||||
});
|
||||
const { promise: tlsSocketPromise, resolve, reject } = Promise.withResolvers();
|
||||
const { promise: rawSocketPromise, resolve: rawSocketResolve, reject: rawSocketReject } = Promise.withResolvers();
|
||||
{
|
||||
let body = "";
|
||||
let rawBody = Buffer.alloc(0);
|
||||
const socket = await Bun.connect({
|
||||
hostname: "localhost",
|
||||
port: server.port,
|
||||
socket: {
|
||||
data(socket, data) {
|
||||
rawBody = Buffer.concat([rawBody, data]);
|
||||
},
|
||||
close() {
|
||||
rawSocketResolve(rawBody);
|
||||
},
|
||||
error(err) {
|
||||
rawSocketReject(err);
|
||||
},
|
||||
},
|
||||
});
|
||||
const result = socket.upgradeTLS({
|
||||
data: Buffer.from("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"),
|
||||
tls,
|
||||
socket: {
|
||||
data(socket, data) {
|
||||
body += data.toString("utf8");
|
||||
if (body.includes("\r\n\r\n")) {
|
||||
socket.end();
|
||||
}
|
||||
},
|
||||
close() {
|
||||
resolve(body);
|
||||
},
|
||||
drain(socket) {
|
||||
while (socket.data.byteLength > 0) {
|
||||
const written = socket.write(socket.data);
|
||||
if (written === 0) {
|
||||
break;
|
||||
}
|
||||
socket.data = socket.data.slice(written);
|
||||
}
|
||||
socket.flush();
|
||||
},
|
||||
error(err) {
|
||||
reject(err);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const [raw, tls_socket] = result;
|
||||
expect(raw).toBeDefined();
|
||||
expect(tls_socket).toBeDefined();
|
||||
}
|
||||
const [tlsData, rawData] = await Promise.all([tlsSocketPromise, rawSocketPromise]);
|
||||
expect(tlsData).toContain("HTTP/1.1 200 OK");
|
||||
expect(tlsData).toContain("Content-Length: 11");
|
||||
expect(tlsData).toContain("\r\nHello World");
|
||||
expect(rawData.byteLength).toBeGreaterThanOrEqual(1980);
|
||||
});
|
||||
|
||||
it("should not leak memory", async () => {
|
||||
// assert we don't leak the sockets
|
||||
// we expect 1 or 2 because that's the prototype / structure
|
||||
await expectMaxObjectTypeCount(expect, "Listener", 2);
|
||||
await expectMaxObjectTypeCount(expect, "TCPSocket", 2);
|
||||
await expectMaxObjectTypeCount(expect, "TLSSocket", 2);
|
||||
});
|
||||
|
||||
@@ -15,7 +15,7 @@ it("remoteAddress works", async () => {
|
||||
};
|
||||
reject = reject1;
|
||||
});
|
||||
let server = Bun.listen({
|
||||
using server = Bun.listen({
|
||||
socket: {
|
||||
open(ws) {
|
||||
try {
|
||||
@@ -25,8 +25,6 @@ it("remoteAddress works", async () => {
|
||||
reject(e);
|
||||
|
||||
return;
|
||||
} finally {
|
||||
setTimeout(() => server.stop(true), 0);
|
||||
}
|
||||
},
|
||||
close() {},
|
||||
@@ -63,7 +61,7 @@ it("should not allow invalid tls option", () => {
|
||||
[1, "string", Symbol("symbol")].forEach(value => {
|
||||
expect(() => {
|
||||
// @ts-ignore
|
||||
const server = Bun.listen({
|
||||
using server = Bun.listen({
|
||||
socket: {
|
||||
open(ws) {},
|
||||
close() {},
|
||||
@@ -73,7 +71,6 @@ it("should not allow invalid tls option", () => {
|
||||
hostname: "localhost",
|
||||
tls: value,
|
||||
});
|
||||
server.stop(true);
|
||||
}).toThrow("tls option expects an object");
|
||||
});
|
||||
});
|
||||
@@ -82,7 +79,7 @@ it("should allow using false, null or undefined tls option", () => {
|
||||
[false, null, undefined].forEach(value => {
|
||||
expect(() => {
|
||||
// @ts-ignore
|
||||
const server = Bun.listen({
|
||||
using server = Bun.listen({
|
||||
socket: {
|
||||
open(ws) {},
|
||||
close() {},
|
||||
@@ -92,7 +89,6 @@ it("should allow using false, null or undefined tls option", () => {
|
||||
hostname: "localhost",
|
||||
tls: value,
|
||||
});
|
||||
server.stop(true);
|
||||
}).not.toThrow("tls option expects an object");
|
||||
});
|
||||
});
|
||||
@@ -167,7 +163,7 @@ it("echo server 1 on 1", async () => {
|
||||
},
|
||||
} as SocketHandler<any>;
|
||||
|
||||
var server: TCPSocketListener<any> | undefined = listen({
|
||||
using server: TCPSocketListener<any> | undefined = listen({
|
||||
socket: handlers,
|
||||
hostname: "localhost",
|
||||
port: 0,
|
||||
@@ -186,8 +182,6 @@ it("echo server 1 on 1", async () => {
|
||||
},
|
||||
});
|
||||
await Promise.all([prom, clientProm, serverProm]);
|
||||
server.stop(true);
|
||||
server = serverData = clientData = undefined;
|
||||
})();
|
||||
});
|
||||
|
||||
@@ -276,7 +270,7 @@ describe("tcp socket binaryType", () => {
|
||||
binaryType: type,
|
||||
} as SocketHandler<any>;
|
||||
|
||||
var server: TCPSocketListener<any> | undefined = listen({
|
||||
using server: TCPSocketListener<any> | undefined = listen({
|
||||
socket: handlers,
|
||||
hostname: "localhost",
|
||||
port: 0,
|
||||
@@ -296,8 +290,6 @@ describe("tcp socket binaryType", () => {
|
||||
});
|
||||
|
||||
await Promise.all([prom, clientProm, serverProm]);
|
||||
server.stop(true);
|
||||
server = serverData = clientData = undefined;
|
||||
})();
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user