Fix crash in WebSocket client when handshaking fails or when the HTTP response is invalid (#7933)

* Fix double-free in websocket client

* Update test

* Fix null pointer dereference

* Fix missing protect() / unprotect() call

* More careful checks

---------

Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
This commit is contained in:
Jarred Sumner
2024-01-01 18:08:08 -08:00
committed by GitHub
parent 9d6c0649a4
commit 837cbd60d5
8 changed files with 161 additions and 59 deletions

View File

@@ -499,7 +499,7 @@ restart:
s = (struct us_internal_ssl_socket_t *)context->sc.on_writable(
&s->s); // cast here!
// if we are closed here, then exit
if (us_socket_is_closed(0, &s->s)) {
if (!s || us_socket_is_closed(0, &s->s)) {
return s;
}
}
@@ -544,8 +544,13 @@ ssl_on_writable(struct us_internal_ssl_socket_t *s) {
0); // cast here!
}
// should this one come before we have read? should it come always? spurious
// on_writable is okay
// Do not call on_writable if the socket is closed.
// on close means the socket data is no longer accessible
if (!s || us_socket_is_closed(0, &s->s)) {
return 0;
}
s = context->on_writable(s);
return s;

View File

@@ -207,8 +207,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
#endif
}
cb->cb(cb->cb_expects_the_loop ? (struct us_internal_callback_t *) cb->loop : (struct us_internal_callback_t *) &cb->p);
break;
}
break;
case POLL_TYPE_SEMI_SOCKET: {
/* Both connect and listen sockets are semi-sockets
* but they poll for different events */
@@ -220,6 +220,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
/* Emit error, close without emitting on_close */
s->context->on_connect_error(s, 0);
us_socket_close_connecting(0, s);
s = NULL;
} else {
/* All sockets poll for readable */
us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE);
@@ -274,8 +275,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
} while ((client_fd = bsd_accept_socket(us_poll_fd(p), &addr)) != LIBUS_SOCKET_ERROR);
}
}
}
break;
}
case POLL_TYPE_SOCKET_SHUT_DOWN:
case POLL_TYPE_SOCKET: {
/* We should only use s, no p after this point */
@@ -288,7 +289,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
s = s->context->on_writable(s);
if (us_socket_is_closed(0, s)) {
if (!s || us_socket_is_closed(0, s)) {
return;
}
@@ -346,13 +347,13 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
}
/* Such as epollerr epollhup */
if (error) {
if (error && s) {
/* Todo: decide what code we give here */
s = us_socket_close(0, s, 0, NULL);
return;
}
break;
}
break;
}
}

View File

@@ -435,7 +435,7 @@ pub const SocketConfig = struct {
return null;
}
const handlers = Handlers.fromJS(globalObject, opts.get(globalObject, "socket") orelse JSValue.zero, exception) orelse {
var handlers = Handlers.fromJS(globalObject, opts.get(globalObject, "socket") orelse JSValue.zero, exception) orelse {
hostname_or_unix.deinit();
return null;
};
@@ -444,6 +444,8 @@ pub const SocketConfig = struct {
default_data = default_data_value;
}
handlers.protect();
return SocketConfig{
.hostname_or_unix = hostname_or_unix,
.port = port,
@@ -547,7 +549,7 @@ pub const Listener = struct {
return .zero;
};
var prev_handlers = this.handlers;
var prev_handlers = &this.handlers;
prev_handlers.unprotect();
this.handlers = handlers; // TODO: this is a memory leak
this.handlers.protect();
@@ -579,7 +581,7 @@ pub const Listener = struct {
var hostname_or_unix = socket_config.hostname_or_unix;
const port = socket_config.port;
var ssl = socket_config.ssl;
var handlers = socket_config.handlers;
var handlers = &socket_config.handlers;
var protos: ?[]const u8 = null;
const exclusive = socket_config.exclusive;
handlers.is_server = true;
@@ -714,7 +716,7 @@ pub const Listener = struct {
};
var socket = Listener{
.handlers = handlers,
.handlers = handlers.*,
.connection = connection,
.ssl = ssl_enabled,
.socket_context = socket_context,
@@ -837,6 +839,7 @@ pub const Listener = struct {
this.poll_ref.unref(this.handlers.vm);
std.debug.assert(this.listener == null);
std.debug.assert(this.handlers.active_connections == 0);
this.handlers.unprotect();
if (this.socket_context) |ctx| {
ctx.deinit(this.ssl);
@@ -925,8 +928,6 @@ pub const Listener = struct {
const ssl_enabled = ssl != null;
defer if (ssl != null) ssl.?.deinit();
handlers.protect();
const ctx_opts: uws.us_bun_socket_context_options_t = JSC.API.ServerConfig.SSLConfig.asUSockets(socket_config.ssl);
globalObject.bunVM().eventLoop().ensureWaker();
@@ -938,6 +939,7 @@ pub const Listener = struct {
.code = if (port == null) bun.String.static("ENOENT") else bun.String.static("ECONNREFUSED"),
};
exception.* = err.toErrorInstance(globalObject).asObjectRef();
handlers.unprotect();
return .zero;
};

View File

@@ -3651,6 +3651,7 @@ pub const ServerWebSocket = struct {
opened: bool = false,
pub usingnamespace JSC.Codegen.JSServerWebSocket;
pub usingnamespace bun.New(ServerWebSocket);
const log = Output.scoped(.WebSocketServer, false);
@@ -3958,7 +3959,7 @@ pub const ServerWebSocket = struct {
pub fn finalize(this: *ServerWebSocket) callconv(.C) void {
log("finalize", .{});
bun.default_allocator.destroy(this);
this.destroy();
}
pub fn publish(
@@ -5077,11 +5078,12 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp
resp.clearAborted();
const ws = this.vm.allocator.create(ServerWebSocket) catch return .zero;
ws.* = .{
data_value.ensureStillAlive();
const ws = ServerWebSocket.new(.{
.handler = &this.config.websocket.?.handler,
.this_value = data_value,
};
});
data_value.ensureStillAlive();
var sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
defer sec_websocket_protocol_str.deinit();

View File

@@ -2835,7 +2835,13 @@ pub inline fn destroyWithAlloc(allocator: std.mem.Allocator, t: anytype) void {
pub fn New(comptime T: type) type {
return struct {
const allocation_logger = Output.scoped(.alloc, @hasDecl(T, "logAllocations"));
pub inline fn destroy(self: *T) void {
if (comptime Environment.allow_assert) {
allocation_logger("destroy({*})", .{self});
}
if (comptime is_heap_breakdown_enabled) {
HeapBreakdown.allocator(T).destroy(self);
} else {
@@ -2847,11 +2853,18 @@ pub fn New(comptime T: type) type {
if (comptime is_heap_breakdown_enabled) {
const ptr = HeapBreakdown.allocator(T).create(T) catch outOfMemory();
ptr.* = t;
if (comptime Environment.allow_assert) {
allocation_logger("new() = {*}", .{ptr});
}
return ptr;
}
const ptr = default_allocator.create(T) catch outOfMemory();
ptr.* = t;
if (comptime Environment.allow_assert) {
allocation_logger("new() = {*}", .{ptr});
}
return ptr;
}
};
@@ -2874,9 +2887,12 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void)
}
return struct {
const allocation_logger = Output.scoped(.alloc, @hasDecl(T, "logAllocations"));
pub fn destroy(self: *T) void {
if (comptime Environment.allow_assert) {
std.debug.assert(self.ref_count == 0);
allocation_logger("destroy() = {*}", .{self});
}
if (comptime is_heap_breakdown_enabled) {
@@ -2909,6 +2925,7 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void)
if (comptime Environment.allow_assert) {
std.debug.assert(ptr.ref_count == 1);
allocation_logger("new() = {*}", .{ptr});
}
return ptr;
@@ -2919,6 +2936,7 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void)
if (comptime Environment.allow_assert) {
std.debug.assert(ptr.ref_count == 1);
allocation_logger("new() = {*}", .{ptr});
}
return ptr;

View File

@@ -764,6 +764,24 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
@field(holder, socket_field_name) = adopted;
return holder;
}
pub fn adoptPtr(
socket: *Socket,
socket_ctx: *SocketContext,
comptime Context: type,
comptime socket_field_name: []const u8,
ctx: *Context,
) bool {
var adopted = ThisSocket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(*Context)) orelse return false };
const holder = adopted.ext(*anyopaque) orelse {
if (comptime bun.Environment.allow_assert) unreachable;
_ = us_socket_close(comptime ssl_int, socket, 0, null);
return false;
};
holder.* = ctx;
@field(ctx, socket_field_name) = adopted;
return true;
}
};
}

View File

@@ -196,7 +196,7 @@ const BodyBuf = BodyBufPool.Node;
pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
return struct {
pub const Socket = uws.NewSocketHandler(ssl);
tcp: Socket,
tcp: ?Socket = null,
outgoing_websocket: ?*CppWebSocket,
input_body_buf: []u8 = &[_]u8{},
client_protocol: []const u8 = "",
@@ -207,9 +207,11 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
websocket_protocol: u64 = 0,
hostname: [:0]const u8 = "",
poll_ref: Async.KeepAlive = Async.KeepAlive.init(),
pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient";
pub const shim = JSC.Shimmer("Bun", name, @This());
pub usingnamespace bun.New(@This());
const HTTPClient = @This();
@@ -227,8 +229,8 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
Socket.configure(
ctx,
false,
HTTPClient,
true,
*HTTPClient,
struct {
pub const onOpen = handleOpen;
pub const onClose = handleClose;
@@ -272,12 +274,13 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
) catch return null;
var vm = global.bunVM();
var client: HTTPClient = HTTPClient{
var client = HTTPClient.new(.{
.tcp = undefined,
.outgoing_websocket = websocket,
.input_body_buf = body,
.websocket_protocol = client_protocol_hash,
};
});
var host_ = host.toSlice(bun.default_allocator);
defer host_.deinit();
const prev_start_server_on_next_tick = vm.eventLoop().start_server_on_next_tick;
@@ -289,7 +292,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
else
display_host_;
if (Socket.connect(
if (Socket.connectPtr(
display_host,
port,
@as(*uws.SocketContext, @ptrCast(socket_ctx)),
@@ -303,12 +306,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
}
out.tcp.timeout(120);
out.tcp.?.timeout(120);
return out;
}
vm.eventLoop().start_server_on_next_tick = prev_start_server_on_next_tick;
} else {
vm.eventLoop().start_server_on_next_tick = prev_start_server_on_next_tick;
client.clearData();
client.clearData();
client.destroy();
}
return null;
}
@@ -326,10 +331,14 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn cancel(this: *HTTPClient) callconv(.C) void {
this.clearData();
if (!this.tcp.isEstablished()) {
_ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), this.tcp.socket);
var tcp = this.tcp orelse return;
this.tcp = null;
if (!tcp.isEstablished()) {
_ = uws.us_socket_close_connecting(comptime @as(c_int, @intFromBool(ssl)), tcp.socket);
} else {
this.tcp.close(0, null);
tcp.shutdown();
tcp.close(0, null);
}
}
@@ -348,16 +357,20 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
log("onClose", .{});
JSC.markBinding(@src());
this.clearData();
this.tcp = null;
if (this.outgoing_websocket) |ws| {
this.outgoing_websocket = null;
ws.didAbruptClose(ErrorCode.ended);
}
this.destroy();
}
pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
this.fail(code);
if (!this.tcp.isClosed())
this.tcp.close(0, null);
// We cannot access the pointer after fail is called.
}
pub fn handleHandshake(this: *HTTPClient, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
@@ -388,7 +401,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn handleOpen(this: *HTTPClient, socket: Socket) void {
log("onOpen", .{});
std.debug.assert(socket.socket == this.tcp.socket);
std.debug.assert(socket.socket == this.tcp.?.socket);
std.debug.assert(this.input_body_buf.len > 0);
std.debug.assert(this.to_send.len == 0);
@@ -414,11 +427,11 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
log("onData", .{});
defer JSC.VirtualMachine.get().drainMicrotasks();
std.debug.assert(socket.socket == this.tcp.socket);
if (this.outgoing_websocket == null) {
this.clearData();
return;
}
std.debug.assert(socket.socket == this.tcp.?.socket);
if (comptime Environment.allow_assert)
std.debug.assert(!socket.isShutdown());
@@ -458,7 +471,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
pub fn handleEnd(this: *HTTPClient, socket: Socket) void {
log("onEnd", .{});
std.debug.assert(socket.socket == this.tcp.socket);
std.debug.assert(socket.socket == this.tcp.?.socket);
this.terminate(ErrorCode.ended);
}
@@ -570,17 +583,17 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.clearData();
JSC.markBinding(@src());
this.tcp.timeout(0);
this.tcp.?.timeout(0);
log("onDidConnect", .{});
this.outgoing_websocket.?.didConnect(this.tcp.socket, overflow.ptr, overflow.len);
this.outgoing_websocket.?.didConnect(this.tcp.?.socket, overflow.ptr, overflow.len);
}
pub fn handleWritable(
this: *HTTPClient,
socket: Socket,
) void {
std.debug.assert(socket.socket == this.tcp.socket);
std.debug.assert(socket.socket == this.tcp.?.socket);
if (this.to_send.len == 0)
return;
@@ -600,7 +613,9 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
this.terminate(ErrorCode.timeout);
}
pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void {
this.tcp = null;
this.terminate(ErrorCode.failed_to_connect);
this.destroy();
}
pub const Export = shim.exportFunctions(.{
@@ -923,6 +938,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
const WebSocket = @This();
pub usingnamespace bun.New(@This());
pub fn register(global: *JSC.JSGlobalObject, loop_: *anyopaque, ctx_: *anyopaque) callconv(.C) void {
const vm = global.bunVM();
const loop = @as(*uws.Loop, @ptrCast(@alignCast(loop_)));
@@ -937,8 +954,8 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
Socket.configure(
ctx,
false,
WebSocket,
true,
*WebSocket,
struct {
pub const onClose = handleClose;
pub const onData = handleData;
@@ -946,7 +963,6 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
pub const onTimeout = handleTimeout;
pub const onConnectError = handleConnectError;
pub const onEnd = handleEnd;
// just by adding it will fix ssl handshake
pub const onHandshake = handleHandshake;
},
);
@@ -1758,28 +1774,33 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
) callconv(.C) ?*anyopaque {
const tcp = @as(*uws.Socket, @ptrCast(input_socket));
const ctx = @as(*uws.SocketContext, @ptrCast(socket_ctx));
var adopted = Socket.adopt(
var ws = WebSocket.new(WebSocket{
.tcp = undefined,
.outgoing_websocket = outgoing,
.globalThis = globalThis,
.send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
.receive_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
});
if (!Socket.adoptPtr(
tcp,
ctx,
WebSocket,
"tcp",
WebSocket{
.tcp = undefined,
.outgoing_websocket = outgoing,
.globalThis = globalThis,
.send_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
.receive_buffer = bun.LinearFifo(u8, .Dynamic).init(bun.default_allocator),
},
) orelse return null;
adopted.send_buffer.ensureTotalCapacity(2048) catch return null;
adopted.receive_buffer.ensureTotalCapacity(2048) catch return null;
adopted.poll_ref.ref(globalThis.bunVM());
ws,
)) {
ws.destroy();
return null;
}
ws.send_buffer.ensureTotalCapacity(2048) catch return null;
ws.receive_buffer.ensureTotalCapacity(2048) catch return null;
ws.poll_ref.ref(globalThis.bunVM());
const buffered_slice: []u8 = buffered_data[0..buffered_data_len];
if (buffered_slice.len > 0) {
const initial_data = bun.default_allocator.create(InitialDataHandler) catch unreachable;
initial_data.* = .{
.adopted = adopted,
.adopted = ws,
.slice = buffered_slice,
.ws = outgoing,
};
@@ -1793,7 +1814,7 @@ pub fn NewWebSocketClient(comptime ssl: bool) type {
}
return @as(
*anyopaque,
@ptrCast(adopted),
@ptrCast(ws),
);
}

View File

@@ -258,11 +258,10 @@ describe("WebSocket", () => {
});
});
it("should connect over http", done => {
it("should FAIL to connect over http when the status code is invalid", done => {
const server = Bun.serve({
port: 0,
fetch(req, server) {
done();
server.stop();
return new Response();
},
@@ -271,9 +270,45 @@ describe("WebSocket", () => {
message(ws) {
ws.close();
},
close() {},
},
});
new WebSocket(`http://${server.hostname}:${server.port}`, {});
var ws = new WebSocket(`http://${server.hostname}:${server.port}`, {});
ws.onopen = () => {
ws.send("Hello World!");
};
ws.onclose = e => {
expect(e.code).toBe(1002);
done();
};
});
it("should connect over http ", done => {
const server = Bun.serve({
port: 0,
fetch(req, server) {
server.upgrade(req);
server.stop();
return new Response();
},
websocket: {
open(ws) {},
message(ws) {
ws.close();
},
close() {},
},
});
var ws = new WebSocket(`http://${server.hostname}:${server.port}`, {});
ws.onopen = () => {
ws.send("Hello World!");
};
ws.onclose = () => {
done();
};
});
describe("nodebuffer", () => {
it("should support 'nodebuffer' binaryType", done => {