Fix abort handler in "ws" polyfill (#21867)

### What does this PR do?

This does two things:
1. Fix an ASAN use-after-poison on macOS involving `ws` module when
running websocket.test.js. This was caused by the `open` callback firing
before the `.upgrade` function call returns. We need to update the
`socket` value on the ServerWebSocket to ensure the `NodeHTTPResponse`
object is kept alive for as long as it should be, but the `us_socket_t`
address can, in theory, change due to `realloc` being used when adopting
the socket.
2. Fixes an "undefined is not a function" error when the websocket
upgrade fails. This occurred because the `_httpMessage` property is not
set when a socket is upgraded

### How did you verify your code works?

There is a test and the asan error no longer triggers

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jarred Sumner
2025-08-14 16:00:03 -07:00
committed by GitHub
parent 7b31393d44
commit ff372f44cb
6 changed files with 192 additions and 82 deletions

View File

@@ -122,13 +122,37 @@ pub fn getServerSocketValue(this: *NodeHTTPResponse) jsc.JSValue {
pub fn pauseSocket(this: *NodeHTTPResponse) void {
log("pauseSocket", .{});
if (this.flags.socket_closed or this.flags.upgraded) {
return;
}
this.raw_response.pause();
}
pub fn resumeSocket(this: *NodeHTTPResponse) void {
log("resumeSocket", .{});
if (this.flags.socket_closed or this.flags.upgraded) {
return;
}
this.raw_response.@"resume"();
}
const OnBeforeOpen = struct {
this: *NodeHTTPResponse,
socketValue: jsc.JSValue,
globalObject: *jsc.JSGlobalObject,
pub fn onBeforeOpen(ctx: *OnBeforeOpen, js_websocket: JSValue, socket: *uws.RawWebSocket) void {
Bun__setNodeHTTPServerSocketUsSocketValue(ctx.socketValue, socket.asSocket());
ServerWebSocket.js.gc.socket.set(js_websocket, ctx.globalObject, ctx.socketValue);
ctx.this.flags.upgraded = true;
defer ctx.this.js_ref.unref(ctx.globalObject.bunVM());
switch (ctx.this.raw_response) {
.SSL => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))),
.TCP => ctx.this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))),
}
}
};
pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_protocol: ZigString, sec_websocket_extensions: ZigString) bool {
const upgrade_ctx = this.upgrade_context.context orelse return false;
const ws_handler = this.server.webSocketHandler() orelse return false;
@@ -149,61 +173,18 @@ pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_proto
.this_value = data_value,
});
var new_socket: ?*uws.Socket = null;
defer if (new_socket) |socket| {
this.flags.upgraded = true;
Bun__setNodeHTTPServerSocketUsSocketValue(socketValue, socket);
ServerWebSocket.js.socketSetCached(ws.getThisValue(), ws_handler.globalObject, socketValue);
defer this.js_ref.unref(jsc.VirtualMachine.get());
switch (this.raw_response) {
.SSL => this.raw_response = uws.AnyResponse.init(uws.NewApp(true).Response.castRes(@alignCast(@ptrCast(socket)))),
.TCP => this.raw_response = uws.AnyResponse.init(uws.NewApp(false).Response.castRes(@alignCast(@ptrCast(socket)))),
}
};
if (this.upgrade_context.request) |request| {
this.upgrade_context = .{};
var sec_websocket_protocol_str: ?ZigString.Slice = null;
var sec_websocket_extensions_str: ?ZigString.Slice = null;
const sec_websocket_protocol_value = brk: {
if (sec_websocket_protocol.isEmpty()) {
break :brk request.header("sec-websocket-protocol") orelse "";
}
sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
break :brk sec_websocket_protocol_str.?.slice();
};
const sec_websocket_extensions_value = brk: {
if (sec_websocket_extensions.isEmpty()) {
break :brk request.header("sec-websocket-extensions") orelse "";
}
sec_websocket_extensions_str = sec_websocket_protocol.toSlice(bun.default_allocator);
break :brk sec_websocket_extensions_str.?.slice();
};
defer {
if (sec_websocket_protocol_str) |str| str.deinit();
if (sec_websocket_extensions_str) |str| str.deinit();
}
new_socket = this.raw_response.upgrade(
*ServerWebSocket,
ws,
request.header("sec-websocket-key") orelse "",
sec_websocket_protocol_value,
sec_websocket_extensions_value,
upgrade_ctx,
);
return true;
}
var sec_websocket_protocol_str: ?ZigString.Slice = null;
defer if (sec_websocket_protocol_str) |*str| str.deinit();
var sec_websocket_extensions_str: ?ZigString.Slice = null;
defer if (sec_websocket_extensions_str) |*str| str.deinit();
const sec_websocket_protocol_value = brk: {
if (sec_websocket_protocol.isEmpty()) {
break :brk this.upgrade_context.sec_websocket_protocol;
if (this.upgrade_context.request) |request| {
break :brk request.header("sec-websocket-protocol") orelse "";
} else {
break :brk this.upgrade_context.sec_websocket_protocol;
}
}
sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
break :brk sec_websocket_protocol_str.?.slice();
@@ -211,35 +192,48 @@ pub fn upgrade(this: *NodeHTTPResponse, data_value: JSValue, sec_websocket_proto
const sec_websocket_extensions_value = brk: {
if (sec_websocket_extensions.isEmpty()) {
break :brk this.upgrade_context.sec_websocket_extensions;
if (this.upgrade_context.request) |request| {
break :brk request.header("sec-websocket-extensions") orelse "";
} else {
break :brk this.upgrade_context.sec_websocket_extensions;
}
}
sec_websocket_extensions_str = sec_websocket_protocol.toSlice(bun.default_allocator);
sec_websocket_extensions_str = sec_websocket_extensions.toSlice(bun.default_allocator);
break :brk sec_websocket_extensions_str.?.slice();
};
defer {
if (sec_websocket_protocol_str) |str| str.deinit();
if (sec_websocket_extensions_str) |str| str.deinit();
}
new_socket = this.raw_response.upgrade(
*ServerWebSocket,
ws,
this.upgrade_context.sec_websocket_key,
sec_websocket_protocol_value,
sec_websocket_extensions_value,
upgrade_ctx,
);
const websocket_key = if (this.upgrade_context.request) |request|
request.header("sec-websocket-key") orelse ""
else
this.upgrade_context.sec_websocket_key;
var on_before_open = OnBeforeOpen{
.this = this,
.socketValue = socketValue,
.globalObject = this.server.globalThis(),
};
var on_before_open_ptr = WebSocketServerContext.Handler.OnBeforeOpen{
.ctx = &on_before_open,
.callback = @ptrCast(&OnBeforeOpen.onBeforeOpen),
};
this.server.webSocketHandler().?.onBeforeOpen = &on_before_open_ptr;
_ = this.raw_response.upgrade(*ServerWebSocket, ws, websocket_key, sec_websocket_protocol_value, sec_websocket_extensions_value, upgrade_ctx);
return true;
}
pub fn maybeStopReadingBody(this: *NodeHTTPResponse, vm: *jsc.VirtualMachine, thisValue: jsc.JSValue) void {
this.upgrade_context.deinit(); // we can discard the upgrade context now
if ((this.flags.socket_closed or this.flags.ended) and
if ((this.flags.upgraded or this.flags.socket_closed or this.flags.ended) and
(this.body_read_ref.has or this.body_read_state == .pending) and
(!this.flags.hasCustomOnData or js.onDataGetCached(thisValue) == null))
{
const had_ref = this.body_read_ref.has;
this.raw_response.clearOnData();
if (!this.flags.upgraded and !this.flags.socket_closed) {
this.raw_response.clearOnData();
}
this.body_read_ref.unref(vm);
this.body_read_state = .done;
@@ -578,7 +572,7 @@ pub fn onTimeout(this: *NodeHTTPResponse, _: uws.AnyResponse) void {
pub fn doPause(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame, thisValue: jsc.JSValue) bun.JSError!jsc.JSValue {
log("doPause", .{});
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) {
return .false;
}
if (this.body_read_ref.has and js.onDataGetCached(thisValue) == null) {
@@ -608,7 +602,7 @@ fn drainBufferedRequestBodyFromPause(this: *NodeHTTPResponse, globalObject: *jsc
pub fn doResume(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, _: *jsc.CallFrame) jsc.JSValue {
log("doResume", .{});
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.ended or this.flags.upgraded) {
return .false;
}
@@ -671,7 +665,7 @@ pub export fn Bun__NodeHTTPRequest__onReject(globalObject: *jsc.JSGlobalObject,
defer this.deref();
if (!this.flags.request_has_completed and !this.flags.socket_closed) {
if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) {
const this_value = this.getThisValue();
if (this_value != .zero) {
js.onAbortedSetCached(this_value, globalObject, .zero);
@@ -787,7 +781,7 @@ fn onDrain(this: *NodeHTTPResponse, offset: u64, response: uws.AnyResponse) bool
this.ref();
defer this.deref();
response.clearOnWritable();
if (this.flags.socket_closed or this.flags.request_has_completed) {
if (this.flags.socket_closed or this.flags.request_has_completed or this.flags.upgraded) {
// return false means we don't have anything to drain
return false;
}
@@ -963,14 +957,14 @@ pub fn getOnWritable(_: *NodeHTTPResponse, thisValue: jsc.JSValue, _: *jsc.JSGlo
}
pub fn getOnAbort(this: *NodeHTTPResponse, thisValue: jsc.JSValue, _: *jsc.JSGlobalObject) jsc.JSValue {
if (this.flags.socket_closed) {
if (this.flags.socket_closed or this.flags.upgraded) {
return .js_undefined;
}
return js.onAbortedGetCached(thisValue) orelse .js_undefined;
}
pub fn setOnAbort(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: JSValue) void {
if (this.flags.socket_closed) {
if (this.flags.socket_closed or this.flags.upgraded) {
return;
}
@@ -1002,7 +996,7 @@ fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalOb
if (thisValue != .zero) {
js.onDataSetCached(thisValue, globalObject, .js_undefined);
}
if (!this.flags.socket_closed)
if (!this.flags.socket_closed and !this.flags.upgraded)
this.raw_response.clearOnData();
if (this.body_read_state != .done) {
this.body_read_state = .done;
@@ -1011,7 +1005,7 @@ fn clearOnDataCallback(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalOb
}
pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject, value: JSValue) void {
if (value.isUndefined() or this.flags.ended or this.flags.socket_closed or this.body_read_state == .none or this.flags.is_data_buffered_during_pause_last) {
if (value.isUndefined() or this.flags.ended or this.flags.socket_closed or this.body_read_state == .none or this.flags.is_data_buffered_during_pause_last or this.flags.upgraded) {
js.onDataSetCached(thisValue, globalObject, .js_undefined);
defer {
if (this.body_read_ref.has) {
@@ -1020,7 +1014,7 @@ pub fn setOnData(this: *NodeHTTPResponse, thisValue: jsc.JSValue, globalObject:
}
switch (this.body_read_state) {
.pending, .done => {
if (!this.flags.request_has_completed and !this.flags.socket_closed) {
if (!this.flags.request_has_completed and !this.flags.socket_closed and !this.flags.upgraded) {
this.raw_response.clearOnData();
}
this.body_read_state = .done;
@@ -1048,7 +1042,7 @@ pub fn write(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callfra
}
pub fn flushHeaders(this: *NodeHTTPResponse, _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!jsc.JSValue {
if (!this.flags.socket_closed)
if (!this.flags.socket_closed and !this.flags.upgraded)
this.raw_response.flushHeaders();
return .js_undefined;
@@ -1074,7 +1068,7 @@ fn handleCorked(globalObject: *jsc.JSGlobalObject, function: jsc.JSValue, result
}
pub fn setTimeout(this: *NodeHTTPResponse, seconds: u8) void {
if (this.flags.request_has_completed or this.flags.socket_closed) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) {
return;
}
@@ -1087,7 +1081,7 @@ export fn NodeHTTPResponse__setTimeout(this: *NodeHTTPResponse, seconds: jsc.JSV
return false;
}
if (this.flags.request_has_completed or this.flags.socket_closed) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) {
return false;
}
@@ -1105,7 +1099,7 @@ pub fn cork(this: *NodeHTTPResponse, globalObject: *jsc.JSGlobalObject, callfram
return globalObject.throwInvalidArgumentTypeValue("cork", "function", arguments[0]);
}
if (this.flags.request_has_completed or this.flags.socket_closed) {
if (this.flags.request_has_completed or this.flags.socket_closed or this.flags.upgraded) {
return globalObject.ERR(.STREAM_ALREADY_FINISHED, "Stream is already ended", .{}).throw();
}
@@ -1163,6 +1157,7 @@ pub export fn Bun__NodeHTTPResponse_setClosed(response: *NodeHTTPResponse) void
const string = []const u8;
const WebSocketServerContext = @import("./WebSocketServerContext.zig");
const std = @import("std");
const bun = @import("bun");

View File

@@ -73,9 +73,20 @@ pub fn onOpen(this: *ServerWebSocket, ws: uws.AnyWebSocket) void {
js.dataSetCached(current_this, globalObject, value_to_cache);
}
if (onOpenHandler.isEmptyOrUndefinedOrNull()) return;
if (onOpenHandler.isEmptyOrUndefinedOrNull()) {
if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| {
// Only create the "this" value if needed.
const this_value = this.getThisValue();
on_before_open.callback(on_before_open.ctx, this_value, ws.raw());
}
return;
}
const this_value = this.getThisValue();
var args = [_]JSValue{this_value};
if (bun.take(&this.handler.onBeforeOpen)) |on_before_open| {
on_before_open.callback(on_before_open.ctx, this_value, ws.raw());
}
const loop = vm.eventLoop();
loop.enter();

View File

@@ -28,12 +28,25 @@ pub const Handler = struct {
globalObject: *jsc.JSGlobalObject = undefined,
active_connections: usize = 0,
/// Only used by NodeHTTPResponse.
///
/// Before we call into JavaScript and after the WebSocket is upgraded, we need to call a function in NodeHTTPResponse.
///
/// This is per-ServerWebSocket data, so it needs to be null'd on usage.
onBeforeOpen: ?*OnBeforeOpen = null,
/// used by publish()
flags: packed struct(u2) {
flags: packed struct(u8) {
ssl: bool = false,
publish_to_self: bool = false,
_: u6 = 0,
} = .{},
pub const OnBeforeOpen = struct {
ctx: *anyopaque,
callback: *const fn (*anyopaque, this_value: jsc.JSValue, socket: *uws.RawWebSocket) void,
};
pub fn runErrorCallback(this: *const Handler, vm: *jsc.VirtualMachine, globalObject: *jsc.JSGlobalObject, error_value: jsc.JSValue) void {
const onError = this.onError;
if (!onError.isEmptyOrUndefinedOrNull()) {

View File

@@ -72,6 +72,15 @@ pub const RawWebSocket = opaque {
pub fn memoryCost(this: *RawWebSocket, ssl_flag: i32) usize {
return c.uws_ws_memory_cost(ssl_flag, this);
}
/// They're the same memory address.
///
/// Equivalent to:
///
/// (struct us_socket_t *)socket
pub fn asSocket(this: *RawWebSocket) *uws.Socket {
return @as(*uws.Socket, @ptrCast(this));
}
};
pub const AnyWebSocket = union(enum) {

View File

@@ -1288,7 +1288,7 @@ class WebSocketServer extends EventEmitter {
*/
handleUpgrade(req, socket, head, cb) {
// socket is actually fake so we use internal http_res
const response = socket._httpMessage;
const response = socket._httpMessage || socket[kBunInternals];
// socket.on("error", socketOnError);

View File

@@ -553,6 +553,88 @@ it("WebSocketServer should handle backpressure", async () => {
}
});
it("should abort incorrect WebSocket handshake", async () => {
const { promise, resolve, reject } = Promise.withResolvers<void>();
const wss = new WebSocketServer({ port: 0 });
let connectionAttempted = false;
let testResolved = false;
wss.on("connection", () => {
connectionAttempted = true;
if (!testResolved) {
testResolved = true;
reject(new Error("Connection should not have been established"));
}
});
wss.on("error", error => {
// Server errors are expected for invalid handshakes
console.log("Server error (expected):", error.message);
});
try {
const net = require("node:net");
const port = (wss.address() as any).port;
const socket = net.createConnection(port, "localhost");
socket.on("connect", () => {
// Send an invalid WebSocket handshake request (invalid Sec-WebSocket-Key)
const invalidRequest = [
"GET / HTTP/1.1",
"Host: localhost",
"Connection: Upgrade",
"Upgrade: websocket",
"Sec-WebSocket-Key: invalid-key", // Invalid key format
"Sec-WebSocket-Version: 13",
"",
"",
].join("\r\n");
socket.write(invalidRequest);
});
let responseReceived = false;
socket.on("data", data => {
const response = data.toString();
responseReceived = true;
// Should receive a 400 Bad Request response for invalid handshake
if (response.includes("400") && !testResolved) {
testResolved = true;
resolve();
} else if (!testResolved) {
testResolved = true;
reject(new Error(`Expected 400 response, got: ${response}`));
}
socket.end();
});
socket.on("error", error => {
// Connection errors are also acceptable as the server may close the connection
if (!testResolved) {
testResolved = true;
resolve();
}
});
socket.on("close", () => {
// If we reach here without getting a proper response and connection wasn't attempted,
// the server properly rejected the invalid handshake
if (!responseReceived && !connectionAttempted && !testResolved) {
testResolved = true;
resolve();
}
});
await promise;
} finally {
wss.close();
}
expect(connectionAttempted).toBeFalse();
expect(testResolved).toBeTrue();
});
it("Server should be able to send empty pings", async () => {
// WebSocket frame creation function with masking
function createWebSocketFrame(message: string) {