mirror of
https://github.com/oven-sh/bun
synced 2026-02-20 07:42:30 +00:00
Compare commits
4 Commits
claude/fix
...
claude/opt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76582063b8 | ||
|
|
d6a5cfa69e | ||
|
|
9bd4ad7166 | ||
|
|
d6d41d58d1 |
108
bench/snippets/microtask-throughput.mjs
Normal file
108
bench/snippets/microtask-throughput.mjs
Normal file
@@ -0,0 +1,108 @@
|
||||
import { AsyncLocalStorage } from "node:async_hooks";
|
||||
import { bench, group, run } from "../runner.mjs";
|
||||
|
||||
// Benchmark 1: queueMicrotask throughput
|
||||
// Tests the BunPerformMicrotaskJob handler path directly.
|
||||
// The optimization removes the JS trampoline and uses callMicrotask.
|
||||
group("queueMicrotask throughput", () => {
|
||||
bench("queueMicrotask 1k", () => {
|
||||
return new Promise(resolve => {
|
||||
let remaining = 1000;
|
||||
const tick = () => {
|
||||
if (--remaining === 0) resolve();
|
||||
else queueMicrotask(tick);
|
||||
};
|
||||
queueMicrotask(tick);
|
||||
});
|
||||
});
|
||||
|
||||
bench("queueMicrotask 10k", () => {
|
||||
return new Promise(resolve => {
|
||||
let remaining = 10000;
|
||||
const tick = () => {
|
||||
if (--remaining === 0) resolve();
|
||||
else queueMicrotask(tick);
|
||||
};
|
||||
queueMicrotask(tick);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Benchmark 2: Promise.resolve chain
|
||||
// Each .then() queues a microtask via the promise machinery.
|
||||
// Benefits from smaller QueuedTask (better cache locality in the Deque).
|
||||
group("Promise.resolve chain", () => {
|
||||
bench("Promise chain 1k", () => {
|
||||
let p = Promise.resolve();
|
||||
for (let i = 0; i < 1000; i++) {
|
||||
p = p.then(() => {});
|
||||
}
|
||||
return p;
|
||||
});
|
||||
|
||||
bench("Promise chain 10k", () => {
|
||||
let p = Promise.resolve();
|
||||
for (let i = 0; i < 10000; i++) {
|
||||
p = p.then(() => {});
|
||||
}
|
||||
return p;
|
||||
});
|
||||
});
|
||||
|
||||
// Benchmark 3: Promise.all (many simultaneous resolves)
|
||||
// All promises resolve at once, flooding the microtask queue.
|
||||
// Smaller QueuedTask = less memory, better cache utilization.
|
||||
group("Promise.all simultaneous", () => {
|
||||
bench("Promise.all 1k", () => {
|
||||
const promises = [];
|
||||
for (let i = 0; i < 1000; i++) {
|
||||
promises.push(Promise.resolve(i));
|
||||
}
|
||||
return Promise.all(promises);
|
||||
});
|
||||
|
||||
bench("Promise.all 10k", () => {
|
||||
const promises = [];
|
||||
for (let i = 0; i < 10000; i++) {
|
||||
promises.push(Promise.resolve(i));
|
||||
}
|
||||
return Promise.all(promises);
|
||||
});
|
||||
});
|
||||
|
||||
// Benchmark 4: queueMicrotask with AsyncLocalStorage
|
||||
// Tests the inlined async context save/restore path.
|
||||
// Previously went through performMicrotaskFunction JS trampoline.
|
||||
group("queueMicrotask + AsyncLocalStorage", () => {
|
||||
const als = new AsyncLocalStorage();
|
||||
|
||||
bench("ALS.run + queueMicrotask 1k", () => {
|
||||
return als.run({ id: 1 }, () => {
|
||||
return new Promise(resolve => {
|
||||
let remaining = 1000;
|
||||
const tick = () => {
|
||||
als.getStore(); // force context read
|
||||
if (--remaining === 0) resolve();
|
||||
else queueMicrotask(tick);
|
||||
};
|
||||
queueMicrotask(tick);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Benchmark 5: async/await (each await queues microtasks)
|
||||
group("async/await chain", () => {
|
||||
async function asyncChain(n) {
|
||||
let sum = 0;
|
||||
for (let i = 0; i < n; i++) {
|
||||
sum += await Promise.resolve(i);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
bench("async/await 1k", () => asyncChain(1000));
|
||||
bench("async/await 10k", () => asyncChain(10000));
|
||||
});
|
||||
|
||||
await run();
|
||||
@@ -6,7 +6,7 @@ option(WEBKIT_LOCAL "If a local version of WebKit should be used instead of down
|
||||
option(WEBKIT_BUILD_TYPE "The build type for local WebKit (defaults to CMAKE_BUILD_TYPE)")
|
||||
|
||||
if(NOT WEBKIT_VERSION)
|
||||
set(WEBKIT_VERSION 8af7958ff0e2a4787569edf64641a1ae7cfe074a)
|
||||
set(WEBKIT_VERSION preview-pr-160-8680a32c)
|
||||
endif()
|
||||
|
||||
# Use preview build URL for Windows ARM64 until the fix is merged to main
|
||||
|
||||
@@ -1061,9 +1061,7 @@ JSC_DEFINE_HOST_FUNCTION(functionQueueMicrotask,
|
||||
|
||||
auto* globalObject = defaultGlobalObject(lexicalGlobalObject);
|
||||
JSC::JSValue asyncContext = globalObject->m_asyncContextData.get()->getInternalField(0);
|
||||
auto function = globalObject->performMicrotaskFunction();
|
||||
#if ASSERT_ENABLED
|
||||
ASSERT_WITH_MESSAGE(function, "Invalid microtask function");
|
||||
ASSERT_WITH_MESSAGE(!callback.isEmpty(), "Invalid microtask callback");
|
||||
#endif
|
||||
|
||||
@@ -1071,10 +1069,8 @@ JSC_DEFINE_HOST_FUNCTION(functionQueueMicrotask,
|
||||
asyncContext = JSC::jsUndefined();
|
||||
}
|
||||
|
||||
// BunPerformMicrotaskJob accepts a variable number of arguments (up to: performMicrotask, job, asyncContext, arg0, arg1).
|
||||
// The runtime inspects argumentCount to determine which arguments are present, so callers may pass only the subset they need.
|
||||
// Here we pass: function, callback, asyncContext.
|
||||
JSC::QueuedTask task { nullptr, JSC::InternalMicrotask::BunPerformMicrotaskJob, 0, globalObject, function, callback, asyncContext };
|
||||
// BunPerformMicrotaskJob: callback, asyncContext
|
||||
JSC::QueuedTask task { nullptr, JSC::InternalMicrotask::BunPerformMicrotaskJob, 0, globalObject, callback, asyncContext };
|
||||
globalObject->vm().queueMicrotask(WTF::move(task));
|
||||
|
||||
return JSC::JSValue::encode(JSC::jsUndefined());
|
||||
@@ -1554,63 +1550,6 @@ extern "C" napi_env ZigGlobalObject__makeNapiEnvForFFI(Zig::GlobalObject* global
|
||||
return globalObject->makeNapiEnvForFFI();
|
||||
}
|
||||
|
||||
JSC_DEFINE_HOST_FUNCTION(jsFunctionPerformMicrotask, (JSGlobalObject * globalObject, CallFrame* callframe))
|
||||
{
|
||||
auto& vm = JSC::getVM(globalObject);
|
||||
auto scope = DECLARE_TOP_EXCEPTION_SCOPE(vm);
|
||||
|
||||
auto job = callframe->argument(0);
|
||||
if (!job || job.isUndefinedOrNull()) [[unlikely]] {
|
||||
return JSValue::encode(jsUndefined());
|
||||
}
|
||||
|
||||
auto callData = JSC::getCallData(job);
|
||||
MarkedArgumentBuffer arguments;
|
||||
|
||||
if (callData.type == CallData::Type::None) [[unlikely]] {
|
||||
return JSValue::encode(jsUndefined());
|
||||
}
|
||||
|
||||
JSValue result;
|
||||
WTF::NakedPtr<JSC::Exception> exceptionPtr;
|
||||
|
||||
JSValue restoreAsyncContext = {};
|
||||
InternalFieldTuple* asyncContextData = nullptr;
|
||||
auto setAsyncContext = callframe->argument(1);
|
||||
if (!setAsyncContext.isUndefined()) {
|
||||
asyncContextData = globalObject->m_asyncContextData.get();
|
||||
restoreAsyncContext = asyncContextData->getInternalField(0);
|
||||
asyncContextData->putInternalField(vm, 0, setAsyncContext);
|
||||
}
|
||||
|
||||
size_t argCount = callframe->argumentCount();
|
||||
switch (argCount) {
|
||||
case 3: {
|
||||
arguments.append(callframe->uncheckedArgument(2));
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
arguments.append(callframe->uncheckedArgument(2));
|
||||
arguments.append(callframe->uncheckedArgument(3));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
JSC::profiledCall(globalObject, ProfilingReason::API, job, callData, jsUndefined(), arguments, exceptionPtr);
|
||||
|
||||
if (asyncContextData) {
|
||||
asyncContextData->putInternalField(vm, 0, restoreAsyncContext);
|
||||
}
|
||||
|
||||
if (auto* exception = exceptionPtr.get()) {
|
||||
Bun__reportUnhandledError(globalObject, JSValue::encode(exception));
|
||||
}
|
||||
|
||||
return JSValue::encode(jsUndefined());
|
||||
}
|
||||
|
||||
JSC_DEFINE_HOST_FUNCTION(jsFunctionPerformMicrotaskVariadic, (JSGlobalObject * globalObject, CallFrame* callframe))
|
||||
{
|
||||
auto& vm = JSC::getVM(globalObject);
|
||||
@@ -1940,11 +1879,6 @@ void GlobalObject::finishCreation(VM& vm)
|
||||
scope.assertNoExceptionExceptTermination();
|
||||
init.set(subclassStructure);
|
||||
});
|
||||
m_performMicrotaskFunction.initLater(
|
||||
[](const Initializer<JSFunction>& init) {
|
||||
init.set(JSFunction::create(init.vm, init.owner, 4, "performMicrotask"_s, jsFunctionPerformMicrotask, ImplementationVisibility::Public));
|
||||
});
|
||||
|
||||
m_performMicrotaskVariadicFunction.initLater(
|
||||
[](const Initializer<JSFunction>& init) {
|
||||
init.set(JSFunction::create(init.vm, init.owner, 4, "performMicrotaskVariadic"_s, jsFunctionPerformMicrotaskVariadic, ImplementationVisibility::Public));
|
||||
|
||||
@@ -272,7 +272,6 @@ public:
|
||||
|
||||
JSC::JSObject* performanceObject() const { return m_performanceObject.getInitializedOnMainThread(this); }
|
||||
|
||||
JSC::JSFunction* performMicrotaskFunction() const { return m_performMicrotaskFunction.getInitializedOnMainThread(this); }
|
||||
JSC::JSFunction* performMicrotaskVariadicFunction() const { return m_performMicrotaskVariadicFunction.getInitializedOnMainThread(this); }
|
||||
|
||||
JSC::Structure* utilInspectOptionsStructure() const { return m_utilInspectOptionsStructure.getInitializedOnMainThread(this); }
|
||||
@@ -569,7 +568,6 @@ public:
|
||||
V(private, LazyPropertyOfGlobalObject<Structure>, m_jsonlParseResultStructure) \
|
||||
V(private, LazyPropertyOfGlobalObject<Structure>, m_pathParsedObjectStructure) \
|
||||
V(private, LazyPropertyOfGlobalObject<Structure>, m_pendingVirtualModuleResultStructure) \
|
||||
V(private, LazyPropertyOfGlobalObject<JSFunction>, m_performMicrotaskFunction) \
|
||||
V(private, LazyPropertyOfGlobalObject<JSFunction>, m_nativeMicrotaskTrampoline) \
|
||||
V(private, LazyPropertyOfGlobalObject<JSFunction>, m_performMicrotaskVariadicFunction) \
|
||||
V(private, LazyPropertyOfGlobalObject<JSFunction>, m_utilInspectFunction) \
|
||||
|
||||
@@ -3538,13 +3538,11 @@ void JSC__JSPromise__rejectOnNextTickWithHandled(JSC::JSPromise* promise, JSC::J
|
||||
|
||||
promise->internalField(JSC::JSPromise::Field::Flags).set(vm, promise, jsNumber(flags | JSC::JSPromise::isFirstResolvingFunctionCalledFlag));
|
||||
auto* globalObject = jsCast<Zig::GlobalObject*>(promise->globalObject());
|
||||
auto microtaskFunction = globalObject->performMicrotaskFunction();
|
||||
auto rejectPromiseFunction = globalObject->rejectPromiseFunction();
|
||||
|
||||
auto asyncContext = globalObject->m_asyncContextData.get()->getInternalField(0);
|
||||
|
||||
#if ASSERT_ENABLED
|
||||
ASSERT_WITH_MESSAGE(microtaskFunction, "Invalid microtask function");
|
||||
ASSERT_WITH_MESSAGE(rejectPromiseFunction, "Invalid microtask callback");
|
||||
ASSERT_WITH_MESSAGE(!value.isEmpty(), "Invalid microtask value");
|
||||
#endif
|
||||
@@ -3557,7 +3555,8 @@ void JSC__JSPromise__rejectOnNextTickWithHandled(JSC::JSPromise* promise, JSC::J
|
||||
value = jsUndefined();
|
||||
}
|
||||
|
||||
JSC::QueuedTask task { nullptr, JSC::InternalMicrotask::BunPerformMicrotaskJob, 0, globalObject, microtaskFunction, rejectPromiseFunction, globalObject->m_asyncContextData.get()->getInternalField(0), promise, value };
|
||||
// BunPerformMicrotaskJob: rejectPromiseFunction, asyncContext, promise, value
|
||||
JSC::QueuedTask task { nullptr, JSC::InternalMicrotask::BunPerformMicrotaskJob, 0, globalObject, rejectPromiseFunction, globalObject->m_asyncContextData.get()->getInternalField(0), promise, value };
|
||||
globalObject->vm().queueMicrotask(WTF::move(task));
|
||||
RETURN_IF_EXCEPTION(scope, );
|
||||
}
|
||||
@@ -5438,9 +5437,7 @@ extern "C" void JSC__JSGlobalObject__queueMicrotaskJob(JSC::JSGlobalObject* arg0
|
||||
if (microtaskArgs[3].isEmpty()) {
|
||||
microtaskArgs[3] = jsUndefined();
|
||||
}
|
||||
JSC::JSFunction* microTaskFunction = globalObject->performMicrotaskFunction();
|
||||
#if ASSERT_ENABLED
|
||||
ASSERT_WITH_MESSAGE(microTaskFunction, "Invalid microtask function");
|
||||
auto& vm = globalObject->vm();
|
||||
if (microtaskArgs[0].isCell()) {
|
||||
JSC::Integrity::auditCellFully(vm, microtaskArgs[0].asCell());
|
||||
@@ -5460,7 +5457,8 @@ extern "C" void JSC__JSGlobalObject__queueMicrotaskJob(JSC::JSGlobalObject* arg0
|
||||
|
||||
#endif
|
||||
|
||||
JSC::QueuedTask task { nullptr, JSC::InternalMicrotask::BunPerformMicrotaskJob, 0, globalObject, microTaskFunction, WTF::move(microtaskArgs[0]), WTF::move(microtaskArgs[1]), WTF::move(microtaskArgs[2]), WTF::move(microtaskArgs[3]) };
|
||||
// BunPerformMicrotaskJob: job, asyncContext, arg0, arg1
|
||||
JSC::QueuedTask task { nullptr, JSC::InternalMicrotask::BunPerformMicrotaskJob, 0, globalObject, WTF::move(microtaskArgs[0]), WTF::move(microtaskArgs[1]), WTF::move(microtaskArgs[2]), WTF::move(microtaskArgs[3]) };
|
||||
globalObject->vm().queueMicrotask(WTF::move(task));
|
||||
}
|
||||
|
||||
|
||||
@@ -43,10 +43,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
state: State = .initializing,
|
||||
subprotocols: bun.StringSet,
|
||||
|
||||
/// Expected Sec-WebSocket-Accept value for RFC 6455 handshake validation.
|
||||
/// This is SHA-1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") base64-encoded (always 28 bytes).
|
||||
expected_accept: [28]u8,
|
||||
|
||||
/// Proxy state (null when not using proxy)
|
||||
proxy: ?WebSocketProxy = null,
|
||||
|
||||
@@ -137,7 +133,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
}
|
||||
}
|
||||
|
||||
const build_result = buildRequestBody(
|
||||
const body = buildRequestBody(
|
||||
vm,
|
||||
pathname,
|
||||
ssl,
|
||||
@@ -147,7 +143,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
extra_headers,
|
||||
if (target_authorization) |auth| auth.slice() else null,
|
||||
) catch return null;
|
||||
const body = build_result.body;
|
||||
|
||||
// Build proxy state if using proxy
|
||||
// The CONNECT request is built using local variables for proxy_authorization and proxy_headers
|
||||
@@ -214,7 +209,6 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
.input_body_buf = if (using_proxy) connect_request else body,
|
||||
.state = .initializing,
|
||||
.proxy = proxy_state,
|
||||
.expected_accept = build_result.expected_accept,
|
||||
.subprotocols = brk: {
|
||||
var subprotocols = bun.StringSet.init(bun.default_allocator);
|
||||
var it = bun.http.HeaderValueIterator.init(protocol_for_subprotocols.slice());
|
||||
@@ -929,10 +923,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!strings.eql(websocket_accept_header.value, &this.expected_accept)) {
|
||||
this.terminate(ErrorCode.mismatch_websocket_accept_header);
|
||||
return;
|
||||
}
|
||||
// TODO: check websocket_accept_header.value
|
||||
|
||||
const overflow_len = remain_buf.len;
|
||||
var overflow: []u8 = &.{};
|
||||
@@ -1174,26 +1165,6 @@ fn buildConnectRequest(
|
||||
return buf.toOwnedSlice();
|
||||
}
|
||||
|
||||
const BuildRequestResult = struct {
|
||||
body: []u8,
|
||||
expected_accept: [28]u8,
|
||||
};
|
||||
|
||||
/// Compute the expected Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2:
|
||||
/// Base64(SHA-1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||||
fn computeExpectedAccept(key: []const u8) [28]u8 {
|
||||
const websocket_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
var hasher = bun.sha.SHA1.init();
|
||||
defer hasher.deinit();
|
||||
hasher.update(key);
|
||||
hasher.update(websocket_guid);
|
||||
var sha1_digest: bun.sha.SHA1.Digest = .{0} ** bun.sha.SHA1.digest;
|
||||
hasher.final(&sha1_digest);
|
||||
var result: [28]u8 = .{0} ** 28;
|
||||
_ = bun.base64.encode(&result, &sha1_digest);
|
||||
return result;
|
||||
}
|
||||
|
||||
fn buildRequestBody(
|
||||
vm: *jsc.VirtualMachine,
|
||||
pathname: *const jsc.ZigString,
|
||||
@@ -1203,7 +1174,7 @@ fn buildRequestBody(
|
||||
client_protocol: *const jsc.ZigString,
|
||||
extra_headers: NonUTF8Headers,
|
||||
target_authorization: ?[]const u8,
|
||||
) std.mem.Allocator.Error!BuildRequestResult {
|
||||
) std.mem.Allocator.Error![]u8 {
|
||||
const allocator = vm.allocator;
|
||||
|
||||
// Check for user overrides
|
||||
@@ -1250,9 +1221,6 @@ fn buildRequestBody(
|
||||
// Generate a new key if user key is invalid or not provided
|
||||
break :blk std.base64.standard.Encoder.encode(&encoded_buf, &vm.rareData().nextUUID().bytes);
|
||||
};
|
||||
|
||||
const expected_accept = computeExpectedAccept(key);
|
||||
|
||||
const protocol = if (user_protocol) |p| p.slice() else client_protocol.slice();
|
||||
|
||||
const pathname_ = pathname.toSlice(allocator);
|
||||
@@ -1305,26 +1273,7 @@ fn buildRequestBody(
|
||||
|
||||
// Build request with user overrides
|
||||
if (user_host) |h| {
|
||||
return .{
|
||||
.body = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {f}\r\n" ++
|
||||
"Connection: Upgrade\r\n" ++
|
||||
"Upgrade: websocket\r\n" ++
|
||||
"Sec-WebSocket-Version: 13\r\n" ++
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" ++
|
||||
"{f}" ++
|
||||
"{s}" ++
|
||||
"\r\n",
|
||||
.{ pathname_.slice(), h, pico_headers, extra_headers_buf.items },
|
||||
),
|
||||
.expected_accept = expected_accept,
|
||||
};
|
||||
}
|
||||
|
||||
return .{
|
||||
.body = try std.fmt.allocPrint(
|
||||
return try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {f}\r\n" ++
|
||||
@@ -1335,10 +1284,23 @@ fn buildRequestBody(
|
||||
"{f}" ++
|
||||
"{s}" ++
|
||||
"\r\n",
|
||||
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers_buf.items },
|
||||
),
|
||||
.expected_accept = expected_accept,
|
||||
};
|
||||
.{ pathname_.slice(), h, pico_headers, extra_headers_buf.items },
|
||||
);
|
||||
}
|
||||
|
||||
return try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {f}\r\n" ++
|
||||
"Connection: Upgrade\r\n" ++
|
||||
"Upgrade: websocket\r\n" ++
|
||||
"Sec-WebSocket-Version: 13\r\n" ++
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" ++
|
||||
"{f}" ++
|
||||
"{s}" ++
|
||||
"\r\n",
|
||||
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers_buf.items },
|
||||
);
|
||||
}
|
||||
|
||||
const log = Output.scoped(.WebSocketUpgradeClient, .visible);
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import { describe, expect, it, mock } from "bun:test";
|
||||
import crypto from "node:crypto";
|
||||
import net from "node:net";
|
||||
|
||||
describe("WebSocket Sec-WebSocket-Accept validation (RFC 6455 Section 4.1)", () => {
|
||||
function computeAcceptKey(websocketKey: string): string {
|
||||
return crypto
|
||||
.createHash("sha1")
|
||||
.update(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
.digest("base64");
|
||||
}
|
||||
|
||||
async function createFakeServer(
|
||||
getAcceptKey: (clientKey: string) => string,
|
||||
): Promise<{ port: number; [Symbol.asyncDispose]: () => Promise<void> }> {
|
||||
const server = net.createServer();
|
||||
let port: number;
|
||||
|
||||
await new Promise<void>(resolve => {
|
||||
server.listen(0, () => {
|
||||
port = (server.address() as any).port;
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
server.on("connection", socket => {
|
||||
let requestData = "";
|
||||
|
||||
socket.on("data", data => {
|
||||
requestData += data.toString();
|
||||
|
||||
if (requestData.includes("\r\n\r\n")) {
|
||||
const lines = requestData.split("\r\n");
|
||||
let websocketKey = "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("Sec-WebSocket-Key:")) {
|
||||
websocketKey = line.split(":")[1].trim();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const acceptKey = getAcceptKey(websocketKey);
|
||||
|
||||
const response = [
|
||||
"HTTP/1.1 101 Switching Protocols",
|
||||
"Upgrade: websocket",
|
||||
"Connection: Upgrade",
|
||||
`Sec-WebSocket-Accept: ${acceptKey}`,
|
||||
"\r\n",
|
||||
].join("\r\n");
|
||||
|
||||
socket.write(response);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
port: port!,
|
||||
[Symbol.asyncDispose]: async () => {
|
||||
server.close();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
it("should accept valid Sec-WebSocket-Accept header", async () => {
|
||||
await using server = await createFakeServer(key => computeAcceptKey(key));
|
||||
|
||||
const { promise, resolve, reject } = Promise.withResolvers();
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
|
||||
ws.onopen = () => resolve(undefined);
|
||||
ws.onerror = () => reject(new Error("connection failed"));
|
||||
|
||||
await promise;
|
||||
ws.close();
|
||||
});
|
||||
|
||||
it("should reject invalid Sec-WebSocket-Accept header", async () => {
|
||||
// Server returns a completely wrong accept key
|
||||
await using server = await createFakeServer(_key => "dGhlIHNhbXBsZSBub25jZQ==");
|
||||
|
||||
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
|
||||
const onopenMock = mock(() => {});
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
ws.onopen = onopenMock;
|
||||
ws.onclose = event => {
|
||||
resolve({ code: event.code, reason: event.reason });
|
||||
};
|
||||
|
||||
const result = await promise;
|
||||
expect(onopenMock).not.toHaveBeenCalled();
|
||||
expect(result.code).toBe(1002);
|
||||
expect(result.reason).toBe("Mismatch websocket accept header");
|
||||
});
|
||||
|
||||
it("should reject empty Sec-WebSocket-Accept value", async () => {
|
||||
// Server returns an empty accept key
|
||||
await using server = await createFakeServer(_key => "");
|
||||
|
||||
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
|
||||
const onopenMock = mock(() => {});
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
ws.onopen = onopenMock;
|
||||
ws.onclose = event => {
|
||||
resolve({ code: event.code, reason: event.reason });
|
||||
};
|
||||
|
||||
const result = await promise;
|
||||
expect(onopenMock).not.toHaveBeenCalled();
|
||||
// Empty value should be caught by either the missing header check or the accept validation
|
||||
expect(result.code).toBe(1002);
|
||||
});
|
||||
|
||||
it("should reject Sec-WebSocket-Accept with wrong key computation", async () => {
|
||||
// Server computes accept from a different key (simulating MitM)
|
||||
await using server = await createFakeServer(_key => {
|
||||
// Compute valid accept but for a different (attacker-chosen) key
|
||||
return computeAcceptKey("AAAAAAAAAAAAAAAAAAAAAA==");
|
||||
});
|
||||
|
||||
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
|
||||
const onopenMock = mock(() => {});
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
ws.onopen = onopenMock;
|
||||
ws.onclose = event => {
|
||||
resolve({ code: event.code, reason: event.reason });
|
||||
};
|
||||
|
||||
const result = await promise;
|
||||
expect(onopenMock).not.toHaveBeenCalled();
|
||||
expect(result.code).toBe(1002);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user