ws: implement upgrade and unexpected-response events

Implements support for the `upgrade` and `unexpected-response` events in
the `ws` package polyfill. This enables Playwright's `chromium.connectOverCDP()`
and other tools that rely on these events to work correctly with Bun.

Changes:
- Add `upgradeStatusCode` property to native WebSocket that stores the HTTP
  status code from the upgrade handshake
- Pass the status code from the HTTP upgrade response through Zig to C++
- Update ws.js polyfill to emit `upgrade` event before `open` event with
  the actual status code from the native WebSocket
- Emit `unexpected-response` event on connection errors for compatibility
- Add TypeScript types for the new `upgradeStatusCode` property
- Add regression tests for the new events

The `upgrade` event provides a response object with `statusCode`,
`statusMessage`, and `headers` properties. Headers are currently empty
but can be populated in a future enhancement if needed.

Fixes #9911

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Claude Bot
2026-01-27 07:37:09 +00:00
parent bfe40e8760
commit 736d9917ef
8 changed files with 310 additions and 15 deletions

View File

@@ -3655,6 +3655,12 @@ declare module "bun" {
*/
readonly bufferedAmount: number;
/**
* The HTTP status code from the WebSocket upgrade handshake (typically 101).
* This is a Bun extension to support the ws package's 'upgrade' event.
*/
readonly upgradeStatusCode: number;
/**
* The protocol selected by the server
*/

View File

@@ -85,6 +85,7 @@ static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_URL);
static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_url);
static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_readyState);
static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_bufferedAmount);
static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_upgradeStatusCode);
static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_onopen);
static JSC_DECLARE_CUSTOM_SETTER(setJSWebSocket_onopen);
static JSC_DECLARE_CUSTOM_GETTER(jsWebSocket_onmessage);
@@ -382,6 +383,7 @@ static const HashTableValue JSWebSocketPrototypeTableValues[] = {
{ "url"_s, static_cast<unsigned>(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_url, 0 } },
{ "readyState"_s, static_cast<unsigned>(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_readyState, 0 } },
{ "bufferedAmount"_s, static_cast<unsigned>(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_bufferedAmount, 0 } },
{ "upgradeStatusCode"_s, static_cast<unsigned>(JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_upgradeStatusCode, 0 } },
{ "onopen"_s, static_cast<unsigned>(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_onopen, setJSWebSocket_onopen } },
{ "onmessage"_s, static_cast<unsigned>(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_onmessage, setJSWebSocket_onmessage } },
{ "onerror"_s, static_cast<unsigned>(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::DOMAttribute), NoIntrinsic, { HashTableValue::GetterSetterType, jsWebSocket_onerror, setJSWebSocket_onerror } },
@@ -501,6 +503,19 @@ JSC_DEFINE_CUSTOM_GETTER(jsWebSocket_bufferedAmount, (JSGlobalObject * lexicalGl
return IDLAttribute<JSWebSocket>::get<jsWebSocket_bufferedAmountGetter, CastedThisErrorBehavior::Assert>(*lexicalGlobalObject, thisValue, attributeName);
}
static inline JSValue jsWebSocket_upgradeStatusCodeGetter(JSGlobalObject& lexicalGlobalObject, JSWebSocket& thisObject)
{
auto& vm = JSC::getVM(&lexicalGlobalObject);
auto throwScope = DECLARE_THROW_SCOPE(vm);
auto& impl = thisObject.wrapped();
RELEASE_AND_RETURN(throwScope, (toJS<IDLUnsignedShort>(lexicalGlobalObject, throwScope, impl.upgradeStatusCode())));
}
JSC_DEFINE_CUSTOM_GETTER(jsWebSocket_upgradeStatusCode, (JSGlobalObject * lexicalGlobalObject, JSC::EncodedJSValue thisValue, PropertyName attributeName))
{
return IDLAttribute<JSWebSocket>::get<jsWebSocket_upgradeStatusCodeGetter, CastedThisErrorBehavior::Assert>(*lexicalGlobalObject, thisValue, attributeName);
}
static inline JSValue jsWebSocket_onopenGetter(JSGlobalObject& lexicalGlobalObject, JSWebSocket& thisObject)
{
UNUSED_PARAM(lexicalGlobalObject);

View File

@@ -1482,9 +1482,10 @@ void WebSocket::didClose(unsigned unhandledBufferedAmount, unsigned short code,
this->disablePendingActivity();
}
void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params, void* customSSLCtx)
void WebSocket::didConnect(us_socket_t* socket, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params, void* customSSLCtx, uint16_t upgradeStatusCode)
{
this->m_upgradeClient = nullptr;
this->m_upgradeStatusCode = upgradeStatusCode;
setExtensionsFromDeflateParams(deflate_params);
// Use TLS WebSocket client if connection type is TLS or ProxyTLS.
@@ -1696,9 +1697,10 @@ void WebSocket::updateHasPendingActivity()
extern "C" void* Bun__WebSocketClient__initWithTunnel(CppWebSocket* ws, void* tunnel, JSC::JSGlobalObject* globalObject, unsigned char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params);
extern "C" void WebSocketProxyTunnel__setConnectedWebSocket(void* tunnel, void* websocket);
void WebSocket::didConnectWithTunnel(void* tunnel, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params)
void WebSocket::didConnectWithTunnel(void* tunnel, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params, uint16_t upgradeStatusCode)
{
this->m_upgradeClient = nullptr;
this->m_upgradeStatusCode = upgradeStatusCode;
setExtensionsFromDeflateParams(deflate_params);
// For wss:// through HTTP proxy, we use a plain (non-TLS) WebSocket client
@@ -1724,14 +1726,14 @@ void WebSocket::didConnectWithTunnel(void* tunnel, char* bufferedData, size_t bu
} // namespace WebCore
extern "C" void WebSocket__didConnect(WebCore::WebSocket* webSocket, us_socket_t* socket, char* bufferedData, size_t len, const PerMessageDeflateParams* deflate_params, void* customSSLCtx)
extern "C" void WebSocket__didConnect(WebCore::WebSocket* webSocket, us_socket_t* socket, char* bufferedData, size_t len, const PerMessageDeflateParams* deflate_params, void* customSSLCtx, uint16_t upgradeStatusCode)
{
webSocket->didConnect(socket, bufferedData, len, deflate_params, customSSLCtx);
webSocket->didConnect(socket, bufferedData, len, deflate_params, customSSLCtx, upgradeStatusCode);
}
extern "C" void WebSocket__didConnectWithTunnel(WebCore::WebSocket* webSocket, void* tunnel, char* bufferedData, size_t len, const PerMessageDeflateParams* deflate_params)
extern "C" void WebSocket__didConnectWithTunnel(WebCore::WebSocket* webSocket, void* tunnel, char* bufferedData, size_t len, const PerMessageDeflateParams* deflate_params, uint16_t upgradeStatusCode)
{
webSocket->didConnectWithTunnel(tunnel, bufferedData, len, deflate_params);
webSocket->didConnectWithTunnel(tunnel, bufferedData, len, deflate_params, upgradeStatusCode);
}
extern "C" void WebSocket__didAbruptClose(WebCore::WebSocket* webSocket, Bun::WebSocketErrorCode errorCode)

View File

@@ -142,6 +142,8 @@ public:
String binaryType() const;
ExceptionOr<void> setBinaryType(const String&);
uint16_t upgradeStatusCode() const { return m_upgradeStatusCode; }
ScriptExecutionContext* scriptExecutionContext() const final;
using RefCounted::deref;
@@ -149,8 +151,8 @@ public:
void didConnect();
void disablePendingActivity();
void didClose(unsigned unhandledBufferedAmount, unsigned short code, const String& reason);
void didConnect(us_socket_t* socket, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params, void* customSSLCtx);
void didConnectWithTunnel(void* tunnel, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params);
void didConnect(us_socket_t* socket, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params, void* customSSLCtx, uint16_t upgradeStatusCode);
void didConnectWithTunnel(void* tunnel, char* bufferedData, size_t bufferedDataSize, const PerMessageDeflateParams* deflate_params, uint16_t upgradeStatusCode);
void didFailWithErrorCode(Bun::WebSocketErrorCode code);
void didReceiveMessage(String&& message);
@@ -248,6 +250,7 @@ private:
String m_subprotocol;
String m_extensions;
void* m_upgradeClient { nullptr };
uint16_t m_upgradeStatusCode { 0 };
ConnectionType m_connectionType { ConnectionType::Plain };
bool m_rejectUnauthorized { false };
AnyWebSocket m_connectedWebSocket { nullptr };

View File

@@ -15,6 +15,7 @@ pub const CppWebSocket = opaque {
buffered_len: usize,
deflate_params: ?*const WebSocketDeflate.Params,
custom_ssl_ctx: ?*uws.SocketContext,
upgrade_status_code: u16,
) void;
extern fn WebSocket__didConnectWithTunnel(
websocket_context: *CppWebSocket,
@@ -22,6 +23,7 @@ pub const CppWebSocket = opaque {
buffered_data: ?[*]u8,
buffered_len: usize,
deflate_params: ?*const WebSocketDeflate.Params,
upgrade_status_code: u16,
) void;
extern fn WebSocket__didAbruptClose(websocket_context: *CppWebSocket, reason: ErrorCode) void;
extern fn WebSocket__didClose(websocket_context: *CppWebSocket, code: u16, reason: *const bun.String) void;
@@ -58,17 +60,17 @@ pub const CppWebSocket = opaque {
defer loop.exit();
return WebSocket__rejectUnauthorized(this);
}
pub fn didConnect(this: *CppWebSocket, socket: *uws.Socket, buffered_data: ?[*]u8, buffered_len: usize, deflate_params: ?*const WebSocketDeflate.Params, custom_ssl_ctx: ?*uws.SocketContext) void {
pub fn didConnect(this: *CppWebSocket, socket: *uws.Socket, buffered_data: ?[*]u8, buffered_len: usize, deflate_params: ?*const WebSocketDeflate.Params, custom_ssl_ctx: ?*uws.SocketContext, upgrade_status_code: u16) void {
const loop = jsc.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didConnect(this, socket, buffered_data, buffered_len, deflate_params, custom_ssl_ctx);
WebSocket__didConnect(this, socket, buffered_data, buffered_len, deflate_params, custom_ssl_ctx, upgrade_status_code);
}
pub fn didConnectWithTunnel(this: *CppWebSocket, tunnel: *anyopaque, buffered_data: ?[*]u8, buffered_len: usize, deflate_params: ?*const WebSocketDeflate.Params) void {
pub fn didConnectWithTunnel(this: *CppWebSocket, tunnel: *anyopaque, buffered_data: ?[*]u8, buffered_len: usize, deflate_params: ?*const WebSocketDeflate.Params, upgrade_status_code: u16) void {
const loop = jsc.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didConnectWithTunnel(this, tunnel, buffered_data, buffered_len, deflate_params);
WebSocket__didConnectWithTunnel(this, tunnel, buffered_data, buffered_len, deflate_params, upgrade_status_code);
}
extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void;
extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void;

View File

@@ -953,7 +953,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
const ws = bun.take(&this.outgoing_websocket).?;
// Create the WebSocket client with the tunnel
ws.didConnectWithTunnel(tunnel, overflow.ptr, overflow.len, if (deflate_result.enabled) &deflate_result.params else null);
ws.didConnectWithTunnel(tunnel, overflow.ptr, overflow.len, if (deflate_result.enabled) &deflate_result.params else null, @intCast(response.status_code));
// Switch state to connected - handleData will forward to tunnel
this.state = .done;
@@ -987,7 +987,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
// Once again for the TCP socket.
defer this.deref();
if (socket.socket.get()) |native_socket| {
ws.didConnect(native_socket, overflow.ptr, overflow.len, if (deflate_result.enabled) &deflate_result.params else null, saved_custom_ssl_ctx);
ws.didConnect(native_socket, overflow.ptr, overflow.len, if (deflate_result.enabled) &deflate_result.params else null, saved_custom_ssl_ctx, @intCast(response.status_code));
} else {
this.terminate(ErrorCode.failed_to_connect);
}

View File

@@ -126,6 +126,8 @@ class BunWebSocket extends EventEmitter {
#binaryType = "nodebuffer";
// Bitset to track whether event handlers are set.
#eventId = 0;
// Track whether we've set up upgrade event emission
#upgradeEmitterSet = false;
constructor(url, protocols, options) {
super();
@@ -260,8 +262,27 @@ class BunWebSocket extends EventEmitter {
return ws;
}
#setupUpgradeEmitter() {
// Set up upgrade event emission only once
if (this.#upgradeEmitterSet) return;
this.#upgradeEmitterSet = true;
this.#ws.addEventListener("open", () => {
// Emit upgrade event before open event if there are upgrade listeners
if (this.listenerCount("upgrade") > 0) {
const statusCode = this.#ws.upgradeStatusCode || 101;
const response = {
statusCode: statusCode,
statusMessage: "Switching Protocols",
headers: {},
};
this.emit("upgrade", response);
}
});
}
#onOrOnce(event, listener, once) {
if (event === "unexpected-response" || event === "upgrade" || event === "redirect") {
if (event === "redirect") {
emitWarning(event, "ws.WebSocket '" + event + "' event is not implemented in bun");
}
const mask = 1 << eventIds[event];
@@ -276,6 +297,8 @@ class BunWebSocket extends EventEmitter {
this.#eventId |= mask;
}
if (event === "open") {
// Set up upgrade emitter so upgrade events fire before open
this.#setupUpgradeEmitter();
this.#ws.addEventListener(
"open",
() => {
@@ -283,6 +306,31 @@ class BunWebSocket extends EventEmitter {
},
once,
);
} else if (event === "upgrade") {
// The 'upgrade' event is emitted when the WebSocket handshake completes successfully.
// Set up the upgrade emitter to fire on the native 'open' event.
this.#setupUpgradeEmitter();
} else if (event === "unexpected-response") {
// The 'unexpected-response' event is emitted when the server responds with
// a non-101 status code during the handshake. We emit this on 'error' events.
this.#ws.addEventListener(
"error",
err => {
// Create mock request/response objects for compatibility
const mockRequest = {
method: "GET",
url: this.#ws.url,
headers: {},
};
const mockResponse = {
statusCode: 0,
statusMessage: err?.message || "Connection failed",
headers: {},
};
this.emit("unexpected-response", mockRequest, mockResponse);
},
once,
);
} else if (event === "close") {
this.#ws.addEventListener(
"close",

View File

@@ -0,0 +1,219 @@
import { describe, expect, test } from "bun:test";
import { bunEnv, bunExe } from "harness";
describe("ws upgrade and unexpected-response events (#9911)", () => {
test("ws WebSocket should not emit warnings for upgrade event", async () => {
using server = Bun.serve({
port: 0,
fetch(req, server) {
if (server.upgrade(req)) {
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
open() {},
message() {},
close() {},
},
});
await using proc = Bun.spawn({
cmd: [
bunExe(),
"-e",
`import WebSocket from "ws";
const ws = new WebSocket("ws://localhost:${server.port}");
ws.on("upgrade", () => {});
ws.on("open", () => ws.close());
ws.on("close", () => process.exit(0));`,
],
env: bunEnv,
stderr: "pipe",
});
const [stderr, exitCode] = await Promise.all([proc.stderr.text(), proc.exited]);
expect(stderr).not.toContain("'upgrade' event is not implemented");
expect(exitCode).toBe(0);
});
test("ws WebSocket should not emit warnings for unexpected-response event", async () => {
using server = Bun.serve({
port: 0,
fetch(req, server) {
if (server.upgrade(req)) {
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
open() {},
message() {},
close() {},
},
});
await using proc = Bun.spawn({
cmd: [
bunExe(),
"-e",
`import WebSocket from "ws";
const ws = new WebSocket("ws://localhost:${server.port}");
ws.on("unexpected-response", () => {});
ws.on("open", () => ws.close());
ws.on("close", () => process.exit(0));`,
],
env: bunEnv,
stderr: "pipe",
});
const [stderr, exitCode] = await Promise.all([proc.stderr.text(), proc.exited]);
expect(stderr).not.toContain("'unexpected-response' event is not implemented");
expect(exitCode).toBe(0);
});
test("ws WebSocket should emit upgrade event with response object", async () => {
using server = Bun.serve({
port: 0,
fetch(req, server) {
if (server.upgrade(req)) {
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
open() {},
message() {},
close() {},
},
});
const WebSocket = (await import("ws")).default;
const ws = new WebSocket(`ws://localhost:${server.port}`);
let upgradeReceived = false;
let upgradeResponse: any = null;
await new Promise<void>((resolve, reject) => {
ws.on("upgrade", (response: any) => {
upgradeReceived = true;
upgradeResponse = response;
});
ws.on("open", () => {
ws.close();
});
ws.on("close", () => {
resolve();
});
ws.on("error", reject);
});
expect(upgradeReceived).toBe(true);
expect(upgradeResponse).not.toBeNull();
expect(upgradeResponse.statusCode).toBe(101);
expect(upgradeResponse.statusMessage).toBe("Switching Protocols");
expect(typeof upgradeResponse.headers).toBe("object");
});
test("ws WebSocket upgrade event should be emitted before open event", async () => {
using server = Bun.serve({
port: 0,
fetch(req, server) {
if (server.upgrade(req)) {
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
open() {},
message() {},
close() {},
},
});
const WebSocket = (await import("ws")).default;
const ws = new WebSocket(`ws://localhost:${server.port}`);
const events: string[] = [];
await new Promise<void>((resolve, reject) => {
ws.on("upgrade", () => {
events.push("upgrade");
});
ws.on("open", () => {
events.push("open");
ws.close();
});
ws.on("close", () => {
resolve();
});
ws.on("error", reject);
});
expect(events).toEqual(["upgrade", "open"]);
});
test("ws WebSocket should work without upgrade listener (backward compatibility)", async () => {
using server = Bun.serve({
port: 0,
fetch(req, server) {
if (server.upgrade(req)) {
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
open() {},
message() {},
close() {},
},
});
const WebSocket = (await import("ws")).default;
const ws = new WebSocket(`ws://localhost:${server.port}`);
let openReceived = false;
await new Promise<void>((resolve, reject) => {
ws.on("open", () => {
openReceived = true;
ws.close();
});
ws.on("close", () => {
resolve();
});
ws.on("error", reject);
});
expect(openReceived).toBe(true);
});
test("native WebSocket should expose upgradeStatusCode property", async () => {
using server = Bun.serve({
port: 0,
fetch(req, server) {
if (server.upgrade(req)) {
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
open() {},
message() {},
close() {},
},
});
const ws = new WebSocket(`ws://localhost:${server.port}`);
await new Promise<void>((resolve, reject) => {
ws.addEventListener("open", () => {
expect((ws as any).upgradeStatusCode).toBe(101);
expect(typeof (ws as any).upgradeStatusCode).toBe("number");
ws.close();
resolve();
});
ws.addEventListener("error", reject);
});
});
});