fix(sockets) add socket wrapper and refactor context ownership handling in socket.zig (#13176)

This commit is contained in:
Ciro Spaciari
2024-08-10 01:34:17 +00:00
committed by GitHub
parent 24dbef7713
commit b9ead441c1
14 changed files with 369 additions and 313 deletions

View File

@@ -128,7 +128,7 @@ void us_internal_socket_context_link_listen_socket(struct us_socket_context_t *c
context->head_listen_sockets->s.prev = &ls->s;
}
context->head_listen_sockets = ls;
context->ref_count++;
us_socket_context_ref(0, context);
}
/* We always add in the top, so we don't modify any s.next */
@@ -140,7 +140,7 @@ void us_internal_socket_context_link_socket(struct us_socket_context_t *context,
context->head_sockets->prev = s;
}
context->head_sockets = s;
context->ref_count++;
us_socket_context_ref(0, context);
}
struct us_loop_t *us_socket_context_loop(int ssl, struct us_socket_context_t *context) {
@@ -277,9 +277,8 @@ struct us_bun_verify_error_t us_socket_verify_error(int ssl, struct us_socket_t
return (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL };
}
void us_internal_socket_context_free(int ssl, struct us_socket_context_t *context) {
#ifndef LIBUS_NO_SSL
if (ssl) {
/* This function will call us again with SSL=false */
@@ -300,9 +299,10 @@ void us_internal_socket_context_free(int ssl, struct us_socket_context_t *contex
void us_socket_context_ref(int ssl, struct us_socket_context_t *context) {
context->ref_count++;
}
void us_socket_context_unref(int ssl, struct us_socket_context_t *context) {
if (--context->ref_count == 0) {
uint32_t ref_count = context->ref_count;
context->ref_count--;
if (ref_count == 1) {
us_internal_socket_context_free(ssl, context);
}
}
@@ -481,6 +481,7 @@ void *us_socket_context_connect(int ssl, struct us_socket_context_t *context, co
struct us_connecting_socket_t *c = us_calloc(1, sizeof(struct us_connecting_socket_t) + socket_ext_size);
c->socket_ext_size = socket_ext_size;
c->context = context;
us_socket_context_ref(ssl, context);
c->options = options;
c->ssl = ssl > 0;
c->timeout = 255;
@@ -548,7 +549,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) {
c->pending_resolve_callback = 0;
// if the socket was closed while we were resolving the address, free it
if (c->closed) {
us_connecting_socket_free(c);
us_connecting_socket_free(c->ssl, c);
return;
}
struct addrinfo_result *result = Bun__addrinfo_getRequestResult(c->addrinfo_req);
@@ -556,7 +557,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) {
c->error = result->error;
c->context->on_connect_error(c, result->error);
Bun__addrinfo_freeRequest(c->addrinfo_req, 0);
us_connecting_socket_close(0, c);
us_connecting_socket_close(c->ssl, c);
return;
}
@@ -567,7 +568,7 @@ void us_internal_socket_after_resolve(struct us_connecting_socket_t *c) {
c->error = ECONNREFUSED;
c->context->on_connect_error(c, ECONNREFUSED);
Bun__addrinfo_freeRequest(c->addrinfo_req, 1);
us_connecting_socket_close(0, c);
us_connecting_socket_close(c->ssl, c);
return;
}
}
@@ -638,7 +639,7 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) {
c->error = ECONNREFUSED;
c->context->on_connect_error(c, error);
Bun__addrinfo_freeRequest(c->addrinfo_req, ECONNREFUSED);
us_connecting_socket_close(0, c);
us_connecting_socket_close(c->ssl, c);
}
}
} else {
@@ -667,7 +668,7 @@ void us_internal_socket_after_open(struct us_socket_t *s, int error) {
}
// now that the socket is open, we can release the associated us_connecting_socket_t if it exists
Bun__addrinfo_freeRequest(c->addrinfo_req, 0);
us_connecting_socket_free(c);
us_connecting_socket_free(c->ssl, c);
s->connect_state = NULL;
}

View File

@@ -1820,6 +1820,7 @@ ssl_wrapped_context_on_close(struct us_internal_ssl_socket_t *s, int code,
wrapped_context->old_events.on_close((struct us_socket_t *)s, code, reason);
}
us_socket_context_unref(0, wrapped_context->tcp_context);
return s;
}
@@ -1976,6 +1977,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls(
}
struct us_socket_context_t *old_context = us_socket_context(0, s);
us_socket_context_ref(0,old_context);
struct us_socket_context_t *context = us_create_bun_socket_context(
1, old_context->loop, sizeof(struct us_wrapped_socket_context_t),
@@ -1998,6 +2000,7 @@ struct us_internal_ssl_socket_t *us_internal_ssl_socket_wrap_with_tls(
};
wrapped_context->old_events = old_events;
wrapped_context->events = events;
wrapped_context->tcp_context = old_context;
// no need to wrap open because socket is already open (only new context will
// be called so we can configure hostname and ssl stuff normally here before

View File

@@ -189,6 +189,7 @@ struct us_connecting_socket_t {
};
struct us_wrapped_socket_context_t {
struct us_socket_context_t* tcp_context;
struct us_socket_events_t events;
struct us_socket_events_t old_events;
};

View File

@@ -314,7 +314,7 @@ struct us_socket_t *us_socket_context_connect_unix(int ssl, us_socket_context_r
* Can also be used to determine if a socket is a listen_socket or not, but you probably know that already. */
int us_socket_is_established(int ssl, us_socket_r s) nonnull_fn_decl;
void us_connecting_socket_free(struct us_connecting_socket_t *c) nonnull_fn_decl;
void us_connecting_socket_free(int ssl, struct us_connecting_socket_t *c) nonnull_fn_decl;
/* Cancel a connecting socket. Can be used together with us_socket_timeout to limit connection times.
* Entirely destroys the socket - this function works like us_socket_close but does not trigger on_close event since

View File

@@ -129,19 +129,20 @@ int us_socket_is_established(int ssl, struct us_socket_t *s) {
return us_internal_poll_type((struct us_poll_t *) s) != POLL_TYPE_SEMI_SOCKET;
}
void us_connecting_socket_free(struct us_connecting_socket_t *c) {
void us_connecting_socket_free(int ssl, struct us_connecting_socket_t *c) {
// we can't just free c immediately, as it may be enqueued in the dns_ready_head list
// instead, we move it to a close list and free it after the iteration
c->next = c->context->loop->data.closed_connecting_head;
c->context->loop->data.closed_connecting_head = c;
us_socket_context_unref(ssl, c->context);
}
void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) {
if (c->closed) return;
c->closed = 1;
for (struct us_socket_t *s = c->connecting_head; s; s = s->connect_next) {
us_internal_socket_context_unlink_socket(ssl, s->context, s);
us_poll_stop((struct us_poll_t *) s, s->context->loop);
bsd_close_socket(us_poll_fd((struct us_poll_t *) s));
@@ -156,7 +157,7 @@ void us_connecting_socket_close(int ssl, struct us_connecting_socket_t *c) {
// we can only schedule the socket to be freed if there is no pending callback
// otherwise, the callback will see that the socket is closed and will free it
if (!c->pending_resolve_callback) {
us_connecting_socket_free(c);
us_connecting_socket_free(ssl, c);
}
}

View File

@@ -361,4 +361,4 @@ public:
}
#endif // UWS_ASYNCSOCKET_H
#endif // UWS_ASYNCSOCKET_H

File diff suppressed because it is too large Load Diff

View File

@@ -2131,7 +2131,9 @@ pub const Subprocess = struct {
if (subprocess.ipc_data) |*ipc_data| {
if (Environment.isPosix) {
posix_ipc_info.ext(*Subprocess).* = subprocess;
if (posix_ipc_info.ext(*Subprocess)) |ctx| {
ctx.* = subprocess;
}
} else {
if (ipc_data.configureServer(
Subprocess,

View File

@@ -83,9 +83,16 @@ pub const InternalLoopData = extern struct {
pub const InternalSocket = union(enum) {
done: *Socket,
connecting: *ConnectingSocket,
detached: void,
pub fn isDetached(this: InternalSocket) bool {
return this == .detached;
}
pub fn detach(this: *InternalSocket) void {
this.* = .detached;
}
pub fn close(this: InternalSocket, comptime is_ssl: bool, code: CloseCode) void {
switch (this) {
.detached => {},
.done => |socket| {
debug("us_socket_close({d})", .{@intFromPtr(socket)});
_ = us_socket_close(
@@ -109,6 +116,7 @@ pub const InternalSocket = union(enum) {
return switch (this) {
.done => |socket| us_socket_is_closed(@intFromBool(is_ssl), socket) > 0,
.connecting => |socket| us_connecting_socket_is_closed(@intFromBool(is_ssl), socket) > 0,
.detached => true,
};
}
@@ -116,6 +124,7 @@ pub const InternalSocket = union(enum) {
return switch (this) {
.done => this.done,
.connecting => null,
.detached => null,
};
}
@@ -123,12 +132,16 @@ pub const InternalSocket = union(enum) {
return switch (this) {
.done => switch (other) {
.done => this.done == other.done,
.connecting => false,
.connecting, .detached => false,
},
.connecting => switch (other) {
.done => false,
.done, .detached => false,
.connecting => this.connecting == other.connecting,
},
.detached => switch (other) {
.detached => true,
.done, .connecting => false,
},
};
}
};
@@ -138,7 +151,13 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const ssl_int: i32 = @intFromBool(is_ssl);
socket: InternalSocket,
const ThisSocket = @This();
pub const detached: NewSocketHandler(is_ssl) = NewSocketHandler(is_ssl){ .socket = .{ .detached = {} } };
pub fn detach(this: *ThisSocket) void {
this.socket.detach();
}
pub fn isDetached(this: ThisSocket) bool {
return this.socket.isDetached();
}
pub fn verifyError(this: ThisSocket) us_bun_verify_error_t {
const socket = this.socket.get() orelse return std.mem.zeroes(us_bun_verify_error_t);
const ssl_error: us_bun_verify_error_t = uws.us_socket_verify_error(comptime ssl_int, socket);
@@ -154,6 +173,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
switch (this.socket) {
.done => |socket| us_socket_timeout(comptime ssl_int, socket, seconds),
.connecting => |socket| us_connecting_socket_timeout(comptime ssl_int, socket, seconds),
.detached => {},
}
}
@@ -177,6 +197,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
us_connecting_socket_long_timeout(comptime ssl_int, socket, 0);
}
},
.detached => {},
}
}
@@ -190,19 +211,23 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
us_connecting_socket_timeout(comptime ssl_int, socket, 0);
us_connecting_socket_long_timeout(comptime ssl_int, socket, minutes);
},
.detached => {},
}
}
pub fn startTLS(this: ThisSocket, is_client: bool) void {
const socket = this.socket.get() orelse @panic("socket is not open");
const socket = this.socket.get() orelse return;
_ = us_socket_open(comptime ssl_int, socket, @intFromBool(is_client), null, 0);
}
pub fn ssl(this: ThisSocket) *BoringSSL.SSL {
pub fn ssl(this: ThisSocket) ?*BoringSSL.SSL {
if (comptime is_ssl) {
return @as(*BoringSSL.SSL, @ptrCast(this.getNativeHandle()));
if(this.getNativeHandle()) |handle| {
return @as(*BoringSSL.SSL, @ptrCast(handle));
}
return null;
}
@panic("socket is not a TLS socket");
return null;
}
// Note: this assumes that the socket is non-TLS and will be adopted and wrapped with a new TLS context
@@ -229,7 +254,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
}
if (comptime deref_) {
return (TLSSocket.from(socket)).ext(ContextType).*;
return (TLSSocket.from(socket)).ext(ContextType).?.*;
}
return (TLSSocket.from(socket)).ext(ContextType);
@@ -289,7 +314,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
}
pub fn on_connect_error(socket: *Socket, code: i32) callconv(.C) ?*Socket {
Fields.onConnectError(
TLSSocket.from(socket).ext(ContextType).*,
TLSSocket.from(socket).ext(ContextType).?.*,
TLSSocket.from(socket),
code,
);
@@ -328,7 +353,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
.on_long_timeout = SocketHandler.on_long_timeout,
};
const this_socket = this.socket.get() orelse @panic("socket is not open");
const this_socket = this.socket.get() orelse return null;
const socket = us_socket_wrap_with_tls(ssl_int, this_socket, options, events, socket_ext_size) orelse return null;
return NewSocketHandler(true).from(socket);
@@ -338,6 +363,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
return @ptrCast(switch (this.socket) {
.done => |socket| us_socket_get_native_handle(comptime ssl_int, socket),
.connecting => |socket| us_connecting_socket_get_native_handle(comptime ssl_int, socket),
.detached => null,
} orelse return null);
}
@@ -361,7 +387,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
us_socket_sendfile_needs_more(socket);
}
pub fn ext(this: ThisSocket, comptime ContextType: type) *ContextType {
pub fn ext(this: ThisSocket, comptime ContextType: type) ?*ContextType {
const alignment = if (ContextType == *anyopaque)
@sizeOf(usize)
else
@@ -370,6 +396,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const ptr = switch (this.socket) {
.done => |sock| us_socket_ext(comptime ssl_int, sock),
.connecting => |sock| us_connecting_socket_ext(comptime ssl_int, sock),
.detached => return null,
};
return @as(*align(alignment) ContextType, @ptrCast(@alignCast(ptr)));
@@ -380,6 +407,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
switch (this.socket) {
.done => |socket| return us_socket_context(comptime ssl_int, socket),
.connecting => |socket| return us_connecting_socket_context(comptime ssl_int, socket),
.detached => return null,
}
}
@@ -434,6 +462,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
socket,
);
},
.detached => {},
}
}
@@ -453,6 +482,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
socket,
);
},
.detached => {},
}
}
@@ -470,6 +500,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
socket,
) > 0;
},
.detached => return true,
}
}
@@ -495,6 +526,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
socket,
);
},
.detached => return 0,
}
}
@@ -642,8 +674,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
) ?ThisSocket {
const socket_ = ThisSocket{ .socket = .{ .done = us_socket_from_fd(ctx, @sizeOf(*anyopaque), bun.socketcast(handle)) orelse return null } };
const holder = socket_.ext(*anyopaque);
holder.* = this;
if(socket_.ext(*anyopaque)) |holder| {
holder.* = this;
}
if (comptime socket_field_name) |field| {
@field(this, field) = socket_;
@@ -679,8 +712,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
return error.FailedToOpenSocket;
const socket_ = ThisSocket{ .socket = .{ .done = socket } };
const holder = socket_.ext(*anyopaque);
holder.* = ctx;
if(socket_.ext(*anyopaque)) |holder| {
holder.* = ctx;
}
return socket_;
}
@@ -721,9 +755,9 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
ThisSocket{
.socket = .{ .connecting = @ptrCast(socket_ptr) },
};
const holder = socket.ext(*anyopaque);
holder.* = ptr;
if(socket.ext(*anyopaque)) |holder| {
holder.* = ptr;
}
return socket;
}
@@ -751,7 +785,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
}
if (comptime deref_) {
return (SocketHandlerType.from(socket)).ext(ContextType).*;
return (SocketHandlerType.from(socket)).ext(ContextType).?.*;
}
return (SocketHandlerType.from(socket)).ext(ContextType);
@@ -806,7 +840,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const val = if (comptime ContextType == anyopaque)
us_connecting_socket_ext(comptime ssl_int, socket)
else if (comptime deref_)
SocketHandlerType.fromConnecting(socket).ext(ContextType).*
SocketHandlerType.fromConnecting(socket).ext(ContextType).?.*
else
SocketHandlerType.fromConnecting(socket).ext(ContextType);
Fields.onConnectError(
@@ -820,7 +854,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const val = if (comptime ContextType == anyopaque)
us_socket_ext(comptime ssl_int, socket)
else if (comptime deref_)
SocketHandlerType.from(socket).ext(ContextType).*
SocketHandlerType.from(socket).ext(ContextType).?.*
else
SocketHandlerType.from(socket).ext(ContextType);
Fields.onConnectError(
@@ -883,7 +917,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
}
if (comptime deref_) {
return (ThisSocket.from(socket)).ext(ContextType).*;
return (ThisSocket.from(socket)).ext(ContextType).?.*;
}
return (ThisSocket.from(socket)).ext(ContextType);
@@ -945,7 +979,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const val = if (comptime ContextType == anyopaque)
us_connecting_socket_ext(comptime ssl_int, socket)
else if (comptime deref_)
ThisSocket.fromConnecting(socket).ext(ContextType).*
ThisSocket.fromConnecting(socket).ext(ContextType).?.*
else
ThisSocket.fromConnecting(socket).ext(ContextType);
Fields.onConnectError(
@@ -959,7 +993,7 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const val = if (comptime ContextType == anyopaque)
us_socket_ext(comptime ssl_int, socket)
else if (comptime deref_)
ThisSocket.from(socket).ext(ContextType).*
ThisSocket.from(socket).ext(ContextType).?.*
else
ThisSocket.from(socket).ext(ContextType);
@@ -1033,14 +1067,14 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
const new_socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, -1) orelse return false;
bun.assert(new_socket == socket);
var adopted = ThisSocket.from(new_socket);
const holder = adopted.ext(*anyopaque);
holder.* = ctx;
if(adopted.ext(*anyopaque)) |holder| {
holder.* = ctx;
}
@field(ctx, socket_field_name) = adopted;
return true;
}
};
}
pub const SocketTCP = NewSocketHandler(false);
pub const SocketTLS = NewSocketHandler(true);
@@ -2312,7 +2346,6 @@ pub fn NewApp(comptime ssl: bool) type {
}
}
pub fn onAborted(res: *Response, comptime UserDataType: type, comptime handler: fn (UserDataType, *Response) void, opcional_data: UserDataType) void {
const Wrapper = struct {
pub fn handle(this: *uws_res, user_data: ?*anyopaque) callconv(.C) void {
if (comptime UserDataType == void) {

View File

@@ -342,7 +342,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
};
pub fn markSocketAsDead(socket: HTTPSocket) void {
socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr());
if (socket.ext(**anyopaque)) |ctx| {
ctx.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr());
}
}
fn terminateSocket(socket: HTTPSocket) void {
@@ -461,7 +463,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
if (hostname.len <= MAX_KEEPALIVE_HOSTNAME and !socket.isClosedOrHasError() and socket.isEstablished()) {
if (this.pending_sockets.get()) |pending| {
socket.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr());
if (socket.ext(**anyopaque)) |ctx| {
ctx.* = bun.cast(**anyopaque, ActiveSocket.init(pending).ptr());
}
socket.flush();
socket.timeout(0);
socket.setTimeoutMinutes(5);
@@ -517,9 +521,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
// handshake completed but we may have ssl errors
client.flags.did_have_handshaking_error = handshake_error.error_no != 0;
if (handshake_success) {
if(client.flags.reject_unauthorized) {
if (client.flags.reject_unauthorized) {
// only reject the connection if reject_unauthorized == true
if(client.flags.did_have_handshaking_error) {
if (client.flags.did_have_handshaking_error) {
client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket);
return;
}
@@ -537,7 +541,7 @@ fn NewHTTPContext(comptime ssl: bool) type {
} else {
// if we are here is because server rejected us, and the error_no is the cause of this
// if we set reject_unauthorized == false this means the server requires custom CA aka NODE_EXTRA_CA_CERTS
if(client.flags.did_have_handshaking_error) {
if (client.flags.did_have_handshaking_error) {
client.closeAndFail(BoringSSL.getCertErrorFromNo(handshake_error.error_no), comptime ssl, socket);
return;
}
@@ -545,7 +549,6 @@ fn NewHTTPContext(comptime ssl: bool) type {
client.closeAndFail(error.ConnectionRefused, comptime ssl, socket);
return;
}
}
if (socket.isClosed()) {
@@ -745,7 +748,9 @@ fn NewHTTPContext(comptime ssl: bool) type {
if (client.isKeepAlivePossible()) {
if (this.existingSocket(client.flags.reject_unauthorized, hostname, port)) |sock| {
sock.ext(**anyopaque).* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr());
if (sock.ext(**anyopaque)) |ctx| {
ctx.* = bun.cast(**anyopaque, ActiveSocket.init(client).ptr());
}
client.allow_retry = true;
client.onOpen(comptime ssl, sock);
if (comptime ssl) {
@@ -2858,10 +2863,10 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s
}
pub fn closeAndFail(this: *HTTPClient, err: anyerror, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void {
if (!socket.isClosed()) {
NewHTTPContext(is_ssl).terminateSocket(socket);
}
if (this.state.stage != .fail and this.state.stage != .done) {
if (!socket.isClosed()) {
NewHTTPContext(is_ssl).terminateSocket(socket);
}
log("closeAndFail: {s}", .{@errorName(err)});
this.fail(err);
}

View File

@@ -100,7 +100,7 @@ export async function expectMaxObjectTypeCount(
await Bun.sleep(wait);
gc();
}
expect(heapStats().objectTypeCounts[type]).toBeLessThanOrEqual(count);
expect(heapStats().objectTypeCounts[type] || 0).toBeLessThanOrEqual(count);
}
// we must ensure that finalizers are run

View File

@@ -126,9 +126,9 @@ async function calculateMemoryLeak(fn: () => Promise<void>) {
// If it was leaking the body, the memory usage would be at least 512 KB * 10_000 = 5 GB
// If it ends up around 280 MB, it's probably not leaking the body.
for (const test_info of [
["#10265 should not leak memory when ignoring the body", callIgnore, false, 48],
["should not leak memory when buffering the body", callBuffering, false, 48],
["should not leak memory when streaming the body", callStreaming, false, 48],
["#10265 should not leak memory when ignoring the body", callIgnore, false, 64],
["should not leak memory when buffering the body", callBuffering, false, 64],
["should not leak memory when streaming the body", callStreaming, false, 64],
["should not leak memory when streaming the body incompletely", callIncompleteStreaming, false, 64],
["should not leak memory when streaming the body and echoing it back", callStreamingEcho, false, 64],
] as const) {

View File

@@ -1,5 +1,5 @@
import { expect, it } from "bun:test";
import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows } from "harness";
import { bunEnv, bunExe, expectMaxObjectTypeCount, isWindows, tls } from "harness";
import { connect, fileURLToPath, SocketHandler, spawn } from "bun";
import type { Socket } from "bun";
it("should coerce '0' to 0", async () => {
@@ -354,7 +354,7 @@ it("it should not crash when returning a Error on client socket open", async ()
});
it("it should only call open once", async () => {
const server = Bun.listen({
using server = Bun.listen({
port: 0,
hostname: "localhost",
socket: {
@@ -381,7 +381,6 @@ it("it should only call open once", async () => {
expect().fail("connectError should not be called");
},
close(socket) {
server.stop();
resolve();
},
data(socket, data) {},
@@ -397,7 +396,7 @@ it.skipIf(isWindows)("should not leak file descriptors when connecting", async (
});
it("should not call open if the connection had an error", async () => {
const server = Bun.listen({
using server = Bun.listen({
port: 0,
hostname: "0.0.0.0",
socket: {
@@ -435,12 +434,11 @@ it("should not call open if the connection had an error", async () => {
await Bun.sleep(50);
await promise;
server.stop();
expect(hadError).toBe(true);
});
it("should connect directly when using an ip address", async () => {
const server = Bun.listen({
using server = Bun.listen({
port: 0,
hostname: "127.0.0.1",
socket: {
@@ -467,7 +465,6 @@ it("should connect directly when using an ip address", async () => {
expect().fail("connectError should not be called");
},
close(socket) {
server.stop();
resolve();
},
data(socket, data) {},
@@ -498,3 +495,77 @@ it("should not call drain before handshake", async () => {
await promise;
expect(socket.authorized).toBe(true);
});
it("should be able to upgrade to TLS", async () => {
using server = Bun.serve({
tls,
async fetch(req) {
return new Response("Hello World");
},
});
const { promise: tlsSocketPromise, resolve, reject } = Promise.withResolvers();
const { promise: rawSocketPromise, resolve: rawSocketResolve, reject: rawSocketReject } = Promise.withResolvers();
{
let body = "";
let rawBody = Buffer.alloc(0);
const socket = await Bun.connect({
hostname: "localhost",
port: server.port,
socket: {
data(socket, data) {
rawBody = Buffer.concat([rawBody, data]);
},
close() {
rawSocketResolve(rawBody);
},
error(err) {
rawSocketReject(err);
},
},
});
const result = socket.upgradeTLS({
data: Buffer.from("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n"),
tls,
socket: {
data(socket, data) {
body += data.toString("utf8");
if (body.includes("\r\n\r\n")) {
socket.end();
}
},
close() {
resolve(body);
},
drain(socket) {
while (socket.data.byteLength > 0) {
const written = socket.write(socket.data);
if (written === 0) {
break;
}
socket.data = socket.data.slice(written);
}
socket.flush();
},
error(err) {
reject(err);
},
},
});
const [raw, tls_socket] = result;
expect(raw).toBeDefined();
expect(tls_socket).toBeDefined();
}
const [tlsData, rawData] = await Promise.all([tlsSocketPromise, rawSocketPromise]);
expect(tlsData).toContain("HTTP/1.1 200 OK");
expect(tlsData).toContain("Content-Length: 11");
expect(tlsData).toContain("\r\nHello World");
expect(rawData.byteLength).toBeGreaterThanOrEqual(1980);
});
it("should not leak memory", async () => {
// assert we don't leak the sockets
// we expect 1 or 2 because that's the prototype / structure
await expectMaxObjectTypeCount(expect, "Listener", 2);
await expectMaxObjectTypeCount(expect, "TCPSocket", 2);
await expectMaxObjectTypeCount(expect, "TLSSocket", 2);
});

View File

@@ -15,7 +15,7 @@ it("remoteAddress works", async () => {
};
reject = reject1;
});
let server = Bun.listen({
using server = Bun.listen({
socket: {
open(ws) {
try {
@@ -25,8 +25,6 @@ it("remoteAddress works", async () => {
reject(e);
return;
} finally {
setTimeout(() => server.stop(true), 0);
}
},
close() {},
@@ -63,7 +61,7 @@ it("should not allow invalid tls option", () => {
[1, "string", Symbol("symbol")].forEach(value => {
expect(() => {
// @ts-ignore
const server = Bun.listen({
using server = Bun.listen({
socket: {
open(ws) {},
close() {},
@@ -73,7 +71,6 @@ it("should not allow invalid tls option", () => {
hostname: "localhost",
tls: value,
});
server.stop(true);
}).toThrow("tls option expects an object");
});
});
@@ -82,7 +79,7 @@ it("should allow using false, null or undefined tls option", () => {
[false, null, undefined].forEach(value => {
expect(() => {
// @ts-ignore
const server = Bun.listen({
using server = Bun.listen({
socket: {
open(ws) {},
close() {},
@@ -92,7 +89,6 @@ it("should allow using false, null or undefined tls option", () => {
hostname: "localhost",
tls: value,
});
server.stop(true);
}).not.toThrow("tls option expects an object");
});
});
@@ -167,7 +163,7 @@ it("echo server 1 on 1", async () => {
},
} as SocketHandler<any>;
var server: TCPSocketListener<any> | undefined = listen({
using server: TCPSocketListener<any> | undefined = listen({
socket: handlers,
hostname: "localhost",
port: 0,
@@ -186,8 +182,6 @@ it("echo server 1 on 1", async () => {
},
});
await Promise.all([prom, clientProm, serverProm]);
server.stop(true);
server = serverData = clientData = undefined;
})();
});
@@ -276,7 +270,7 @@ describe("tcp socket binaryType", () => {
binaryType: type,
} as SocketHandler<any>;
var server: TCPSocketListener<any> | undefined = listen({
using server: TCPSocketListener<any> | undefined = listen({
socket: handlers,
hostname: "localhost",
port: 0,
@@ -296,8 +290,6 @@ describe("tcp socket binaryType", () => {
});
await Promise.all([prom, clientProm, serverProm]);
server.stop(true);
server = serverData = clientData = undefined;
})();
});
}