From 837cbd60d5ffc5d1faf2131bbd8d968217bab2ff Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Mon, 1 Jan 2024 18:08:08 -0800 Subject: [PATCH] Fix crash in WebSocket client when handshaking fails or when the HTTP response is invalid (#7933) * Fix double-free in websocket client * Update test * Fix null pointer dereference * Fix missing protect() / unprotect() call * More careful checks --------- Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com> --- packages/bun-usockets/src/crypto/openssl.c | 11 ++- packages/bun-usockets/src/loop.c | 11 +-- src/bun.js/api/bun/socket.zig | 14 ++-- src/bun.js/api/server.zig | 10 ++- src/bun.zig | 18 ++++ src/deps/uws.zig | 18 ++++ src/http/websocket_http_client.zig | 97 +++++++++++++--------- test/js/web/websocket/websocket.test.js | 41 ++++++++- 8 files changed, 161 insertions(+), 59 deletions(-) diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index aeab551e97..b222946303 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -499,7 +499,7 @@ restart: s = (struct us_internal_ssl_socket_t *)context->sc.on_writable( &s->s); // cast here! // if we are closed here, then exit - if (us_socket_is_closed(0, &s->s)) { + if (!s || us_socket_is_closed(0, &s->s)) { return s; } } @@ -544,8 +544,13 @@ ssl_on_writable(struct us_internal_ssl_socket_t *s) { 0); // cast here! } - // should this one come before we have read? should it come always? spurious - // on_writable is okay + + // Do not call on_writable if the socket is closed. + // on close means the socket data is no longer accessible + if (!s || us_socket_is_closed(0, &s->s)) { + return 0; + } + s = context->on_writable(s); return s; diff --git a/packages/bun-usockets/src/loop.c b/packages/bun-usockets/src/loop.c index b2608b99de..8083eabbf4 100644 --- a/packages/bun-usockets/src/loop.c +++ b/packages/bun-usockets/src/loop.c @@ -207,8 +207,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) #endif } cb->cb(cb->cb_expects_the_loop ? (struct us_internal_callback_t *) cb->loop : (struct us_internal_callback_t *) &cb->p); + break; } - break; case POLL_TYPE_SEMI_SOCKET: { /* Both connect and listen sockets are semi-sockets * but they poll for different events */ @@ -220,6 +220,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) /* Emit error, close without emitting on_close */ s->context->on_connect_error(s, 0); us_socket_close_connecting(0, s); + s = NULL; } else { /* All sockets poll for readable */ us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE); @@ -274,8 +275,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) } while ((client_fd = bsd_accept_socket(us_poll_fd(p), &addr)) != LIBUS_SOCKET_ERROR); } } - } break; + } case POLL_TYPE_SOCKET_SHUT_DOWN: case POLL_TYPE_SOCKET: { /* We should only use s, no p after this point */ @@ -288,7 +289,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) s = s->context->on_writable(s); - if (us_socket_is_closed(0, s)) { + if (!s || us_socket_is_closed(0, s)) { return; } @@ -346,13 +347,13 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events) } /* Such as epollerr epollhup */ - if (error) { + if (error && s) { /* Todo: decide what code we give here */ s = us_socket_close(0, s, 0, NULL); return; } + break; } - break; } } diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index b5e688b624..a6445dab2d 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -435,7 +435,7 @@ pub const SocketConfig = struct { return null; } - const handlers = Handlers.fromJS(globalObject, opts.get(globalObject, "socket") orelse JSValue.zero, exception) orelse { + var handlers = Handlers.fromJS(globalObject, opts.get(globalObject, "socket") orelse JSValue.zero, exception) orelse { hostname_or_unix.deinit(); return null; }; @@ -444,6 +444,8 @@ pub const SocketConfig = struct { default_data = default_data_value; } + handlers.protect(); + return SocketConfig{ .hostname_or_unix = hostname_or_unix, .port = port, @@ -547,7 +549,7 @@ pub const Listener = struct { return .zero; }; - var prev_handlers = this.handlers; + var prev_handlers = &this.handlers; prev_handlers.unprotect(); this.handlers = handlers; // TODO: this is a memory leak this.handlers.protect(); @@ -579,7 +581,7 @@ pub const Listener = struct { var hostname_or_unix = socket_config.hostname_or_unix; const port = socket_config.port; var ssl = socket_config.ssl; - var handlers = socket_config.handlers; + var handlers = &socket_config.handlers; var protos: ?[]const u8 = null; const exclusive = socket_config.exclusive; handlers.is_server = true; @@ -714,7 +716,7 @@ pub const Listener = struct { }; var socket = Listener{ - .handlers = handlers, + .handlers = handlers.*, .connection = connection, .ssl = ssl_enabled, .socket_context = socket_context, @@ -837,6 +839,7 @@ pub const Listener = struct { this.poll_ref.unref(this.handlers.vm); std.debug.assert(this.listener == null); std.debug.assert(this.handlers.active_connections == 0); + this.handlers.unprotect(); if (this.socket_context) |ctx| { ctx.deinit(this.ssl); @@ -925,8 +928,6 @@ pub const Listener = struct { const ssl_enabled = ssl != null; defer if (ssl != null) ssl.?.deinit(); - handlers.protect(); - const ctx_opts: uws.us_bun_socket_context_options_t = JSC.API.ServerConfig.SSLConfig.asUSockets(socket_config.ssl); globalObject.bunVM().eventLoop().ensureWaker(); @@ -938,6 +939,7 @@ pub const Listener = struct { .code = if (port == null) bun.String.static("ENOENT") else bun.String.static("ECONNREFUSED"), }; exception.* = err.toErrorInstance(globalObject).asObjectRef(); + handlers.unprotect(); return .zero; }; diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 9b0e15cec6..5086074aea 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -3651,6 +3651,7 @@ pub const ServerWebSocket = struct { opened: bool = false, pub usingnamespace JSC.Codegen.JSServerWebSocket; + pub usingnamespace bun.New(ServerWebSocket); const log = Output.scoped(.WebSocketServer, false); @@ -3958,7 +3959,7 @@ pub const ServerWebSocket = struct { pub fn finalize(this: *ServerWebSocket) callconv(.C) void { log("finalize", .{}); - bun.default_allocator.destroy(this); + this.destroy(); } pub fn publish( @@ -5077,11 +5078,12 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp resp.clearAborted(); - const ws = this.vm.allocator.create(ServerWebSocket) catch return .zero; - ws.* = .{ + data_value.ensureStillAlive(); + const ws = ServerWebSocket.new(.{ .handler = &this.config.websocket.?.handler, .this_value = data_value, - }; + }); + data_value.ensureStillAlive(); var sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator); defer sec_websocket_protocol_str.deinit(); diff --git a/src/bun.zig b/src/bun.zig index 983e81d364..32884f9b9d 100644 --- a/src/bun.zig +++ b/src/bun.zig @@ -2835,7 +2835,13 @@ pub inline fn destroyWithAlloc(allocator: std.mem.Allocator, t: anytype) void { pub fn New(comptime T: type) type { return struct { + const allocation_logger = Output.scoped(.alloc, @hasDecl(T, "logAllocations")); + pub inline fn destroy(self: *T) void { + if (comptime Environment.allow_assert) { + allocation_logger("destroy({*})", .{self}); + } + if (comptime is_heap_breakdown_enabled) { HeapBreakdown.allocator(T).destroy(self); } else { @@ -2847,11 +2853,18 @@ pub fn New(comptime T: type) type { if (comptime is_heap_breakdown_enabled) { const ptr = HeapBreakdown.allocator(T).create(T) catch outOfMemory(); ptr.* = t; + if (comptime Environment.allow_assert) { + allocation_logger("new() = {*}", .{ptr}); + } return ptr; } const ptr = default_allocator.create(T) catch outOfMemory(); ptr.* = t; + + if (comptime Environment.allow_assert) { + allocation_logger("new() = {*}", .{ptr}); + } return ptr; } }; @@ -2874,9 +2887,12 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void) } return struct { + const allocation_logger = Output.scoped(.alloc, @hasDecl(T, "logAllocations")); + pub fn destroy(self: *T) void { if (comptime Environment.allow_assert) { std.debug.assert(self.ref_count == 0); + allocation_logger("destroy() = {*}", .{self}); } if (comptime is_heap_breakdown_enabled) { @@ -2909,6 +2925,7 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void) if (comptime Environment.allow_assert) { std.debug.assert(ptr.ref_count == 1); + allocation_logger("new() = {*}", .{ptr}); } return ptr; @@ -2919,6 +2936,7 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void) if (comptime Environment.allow_assert) { std.debug.assert(ptr.ref_count == 1); + allocation_logger("new() = {*}", .{ptr}); } return ptr; diff --git a/src/deps/uws.zig b/src/deps/uws.zig index 61315d1f7a..1e8ef310e8 100644 --- a/src/deps/uws.zig +++ b/src/deps/uws.zig @@ -764,6 +764,24 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type { @field(holder, socket_field_name) = adopted; return holder; } + + pub fn adoptPtr( + socket: *Socket, + socket_ctx: *SocketContext, + comptime Context: type, + comptime socket_field_name: []const u8, + ctx: *Context, + ) bool { + var adopted = ThisSocket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(*Context)) orelse return false }; + const holder = adopted.ext(*anyopaque) orelse { + if (comptime bun.Environment.allow_assert) unreachable; + _ = us_socket_close(comptime ssl_int, socket, 0, null); + return false; + }; + holder.* = ctx; + @field(ctx, socket_field_name) = adopted; + return true; + } }; } diff --git a/src/http/websocket_http_client.zig b/src/http/websocket_http_client.zig index d2ea819002..a7a5ce745a 100644 --- a/src/http/websocket_http_client.zig +++ b/src/http/websocket_http_client.zig @@ -196,7 +196,7 @@ const BodyBuf = BodyBufPool.Node; pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { return struct { pub const Socket = uws.NewSocketHandler(ssl); - tcp: Socket, + tcp: ?Socket = null, outgoing_websocket: ?*CppWebSocket, input_body_buf: []u8 = &[_]u8{}, client_protocol: []const u8 = "", @@ -207,9 +207,11 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { websocket_protocol: u64 = 0, hostname: [:0]const u8 = "", poll_ref: Async.KeepAlive = Async.KeepAlive.init(), + pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient"; pub const shim = JSC.Shimmer("Bun", name, @This()); + pub usingnamespace bun.New(@This()); const HTTPClient = @This(); @@ -227,8 +229,8 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { Socket.configure( ctx, - false, - HTTPClient, + true, + *HTTPClient, struct { pub const onOpen = handleOpen; pub const onClose = handleClose; @@ -272,12 +274,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { ) catch return null; var vm = global.bunVM(); - var client: HTTPClient = HTTPClient{ + var client = HTTPClient.new(.{ .tcp = undefined, .outgoing_websocket = websocket, .input_body_buf = body, .websocket_protocol = client_protocol_hash, - }; + }); + var host_ = host.toSlice(bun.default_allocator); defer host_.deinit(); const prev_start_server_on_next_tick = vm.eventLoop().start_server_on_next_tick; @@ -289,7 +292,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { else display_host_; - if (Socket.connect( + if (Socket.connectPtr( display_host, port, @as(*uws.SocketContext, @ptrCast(socket_ctx)), @@ -303,12 +306,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { } } - out.tcp.timeout(120); + out.tcp.?.timeout(120); return out; - } - vm.eventLoop().start_server_on_next_tick = prev_start_server_on_next_tick; + } else { + vm.eventLoop().start_server_on_next_tick = prev_start_server_on_next_tick; - client.clearData(); + client.clearData(); + client.destroy(); + } return null; } @@ -326,10 +331,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn cancel(this: *HTTPClient) callconv(.C) void { this.clearData(); - if (!this.tcp.isEstablished()) { - _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket); + var tcp = this.tcp orelse return; + this.tcp = null; + + if (!tcp.isEstablished()) { + _ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), tcp.socket); } else { - this.tcp.close(0, null); + tcp.shutdown(); + tcp.close(0, null); } } @@ -348,16 +357,20 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { log("onClose", .{}); JSC.markBinding(@src()); this.clearData(); + this.tcp = null; + if (this.outgoing_websocket) |ws| { this.outgoing_websocket = null; ws.didAbruptClose(ErrorCode.ended); } + + this.destroy(); } pub fn terminate(this: *HTTPClient, code: ErrorCode) void { this.fail(code); - if (!this.tcp.isClosed()) - this.tcp.close(0, null); + + // We cannot access the pointer after fail is called. } pub fn handleHandshake(this: *HTTPClient, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void { @@ -388,7 +401,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn handleOpen(this: *HTTPClient, socket: Socket) void { log("onOpen", .{}); - std.debug.assert(socket.socket == this.tcp.socket); + std.debug.assert(socket.socket == this.tcp.?.socket); std.debug.assert(this.input_body_buf.len > 0); std.debug.assert(this.to_send.len == 0); @@ -414,11 +427,11 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void { log("onData", .{}); defer JSC.VirtualMachine.get().drainMicrotasks(); - std.debug.assert(socket.socket == this.tcp.socket); if (this.outgoing_websocket == null) { this.clearData(); return; } + std.debug.assert(socket.socket == this.tcp.?.socket); if (comptime Environment.allow_assert) std.debug.assert(!socket.isShutdown()); @@ -458,7 +471,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { pub fn handleEnd(this: *HTTPClient, socket: Socket) void { log("onEnd", .{}); - std.debug.assert(socket.socket == this.tcp.socket); + std.debug.assert(socket.socket == this.tcp.?.socket); this.terminate(ErrorCode.ended); } @@ -570,17 +583,17 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.clearData(); JSC.markBinding(@src()); - this.tcp.timeout(0); + this.tcp.?.timeout(0); log("onDidConnect", .{}); - this.outgoing_websocket.?.didConnect(this.tcp.socket, overflow.ptr, overflow.len); + this.outgoing_websocket.?.didConnect(this.tcp.?.socket, overflow.ptr, overflow.len); } pub fn handleWritable( this: *HTTPClient, socket: Socket, ) void { - std.debug.assert(socket.socket == this.tcp.socket); + std.debug.assert(socket.socket == this.tcp.?.socket); if (this.to_send.len == 0) return; @@ -600,7 +613,9 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type { this.terminate(ErrorCode.timeout); } pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void { + this.tcp = null; this.terminate(ErrorCode.failed_to_connect); + this.destroy(); } pub const Export = shim.exportFunctions(.{ @@ -923,6 +938,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { const WebSocket = @This(); + pub usingnamespace bun.New(@This()); + pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void { const vm = global.bunVM(); const loop = @as(*uws.Loop, @ptrCast(@alignCast(loop_))); @@ -937,8 +954,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { Socket.configure( ctx, - false, - WebSocket, + true, + *WebSocket, struct { pub const onClose = handleClose; pub const onData = handleData; @@ -946,7 +963,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { pub const onTimeout = handleTimeout; pub const onConnectError = handleConnectError; pub const onEnd = handleEnd; - // just by adding it will fix ssl handshake pub const onHandshake = handleHandshake; }, ); @@ -1758,28 +1774,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { ) callconv(.C) ?*anyopaque { const tcp = @as(*uws.Socket, @ptrCast(input_socket)); const ctx = @as(*uws.SocketContext, @ptrCast(socket_ctx)); - var adopted = Socket.adopt( + var ws = WebSocket.new(WebSocket{ + .tcp = undefined, + .outgoing_websocket = outgoing, + .globalThis = globalThis, + .send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), + .receive_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), + }); + if (!Socket.adoptPtr( tcp, ctx, WebSocket, "tcp", - WebSocket{ - .tcp = undefined, - .outgoing_websocket = outgoing, - .globalThis = globalThis, - .send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), - .receive_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator), - }, - ) orelse return null; - adopted.send_buffer.ensureTotalCapacity(2048) catch return null; - adopted.receive_buffer.ensureTotalCapacity(2048) catch return null; - adopted.poll_ref.ref(globalThis.bunVM()); + ws, + )) { + ws.destroy(); + return null; + } + + ws.send_buffer.ensureTotalCapacity(2048) catch return null; + ws.receive_buffer.ensureTotalCapacity(2048) catch return null; + ws.poll_ref.ref(globalThis.bunVM()); const buffered_slice: []u8 = buffered_data[0..buffered_data_len]; if (buffered_slice.len > 0) { const initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable; initial_data.* = .{ - .adopted = adopted, + .adopted = ws, .slice = buffered_slice, .ws = outgoing, }; @@ -1793,7 +1814,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type { } return @as( *anyopaque, - @ptrCast(adopted), + @ptrCast(ws), ); } diff --git a/test/js/web/websocket/websocket.test.js b/test/js/web/websocket/websocket.test.js index da7a9e8fba..aab5bdc788 100644 --- a/test/js/web/websocket/websocket.test.js +++ b/test/js/web/websocket/websocket.test.js @@ -258,11 +258,10 @@ describe("WebSocket", () => { }); }); - it("should connect over http", done => { + it("should FAIL to connect over http when the status code is invalid", done => { const server = Bun.serve({ port: 0, fetch(req, server) { - done(); server.stop(); return new Response(); }, @@ -271,9 +270,45 @@ describe("WebSocket", () => { message(ws) { ws.close(); }, + close() {}, }, }); - new WebSocket(`http://${server.hostname}:${server.port}`, {}); + var ws = new WebSocket(`http://${server.hostname}:${server.port}`, {}); + ws.onopen = () => { + ws.send("Hello World!"); + }; + + ws.onclose = e => { + expect(e.code).toBe(1002); + done(); + }; + }); + + it("should connect over http ", done => { + const server = Bun.serve({ + port: 0, + fetch(req, server) { + server.upgrade(req); + server.stop(); + + return new Response(); + }, + websocket: { + open(ws) {}, + message(ws) { + ws.close(); + }, + close() {}, + }, + }); + var ws = new WebSocket(`http://${server.hostname}:${server.port}`, {}); + ws.onopen = () => { + ws.send("Hello World!"); + }; + + ws.onclose = () => { + done(); + }; }); describe("nodebuffer", () => { it("should support 'nodebuffer' binaryType", done => {