Compare commits

..

3 Commits

Author SHA1 Message Date
Claude Bot
172bd045d8 refactor: extract computeExpectedAccept, use bun.SHA1 and bun.base64
Address review feedback:
- Extract the Sec-WebSocket-Accept computation into a separate
  `computeExpectedAccept` function
- Use `bun.sha.SHA1` (BoringSSL EVP) instead of `std.crypto.hash.Sha1`
- Use `bun.base64.encode` instead of `std.base64.standard.Encoder.encode`
- Replace `= undefined` field default with zero-initialized array to
  satisfy ban-words lint

Co-Authored-By: Claude <noreply@anthropic.com>
2026-02-12 07:07:13 +00:00
Claude
225a5cceab refactor: extract computeExpectedAccept, use bun.SHA1 and bun.base64
Extract SHA-1 + base64 accept header computation into a separate
computeExpectedAccept method. Switch from std.crypto.hash.Sha1 to
bun.sha.SHA1 and from std.base64.standard.Encoder to bun.base64.
Also remove `= undefined` default on expected_accept struct field
to fix the ban-words test.

https://claude.ai/code/session_01Rtii7UWFL1csaEkeGBdgfd
2026-02-12 06:57:48 +00:00
Claude Bot
ca6b28b2ac fix(websocket): validate Sec-WebSocket-Accept header per RFC 6455
The WebSocket upgrade client checked that the Sec-WebSocket-Accept
header was present but never validated its value against the expected
SHA-1 hash of the client's Sec-WebSocket-Key concatenated with the
RFC 6455 magic GUID. This allowed a MitM attacker to fake a WebSocket
handshake with any arbitrary accept value.

Store the expected accept value (computed during request construction)
on the client struct and validate it against the server's response
during the upgrade handshake.

Co-Authored-By: Claude <noreply@anthropic.com>
2026-02-12 04:49:38 +00:00
4 changed files with 227 additions and 156 deletions

View File

@@ -23,7 +23,6 @@ var print_every_i: usize = 0;
// we always rewrite the entire HTTP request when write() returns EAGAIN
// so we can reuse this buffer
var shared_request_headers_buf: [256]picohttp.Header = undefined;
var shared_request_headers_overflow: ?[]picohttp.Header = null;
// this doesn't need to be stack memory because it is immediately cloned after use
var shared_response_headers_buf: [256]picohttp.Header = undefined;
@@ -606,32 +605,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
var header_entries = this.header_entries.slice();
const header_names = header_entries.items(.name);
const header_values = header_entries.items(.value);
// The maximum number of headers is the user-provided headers plus up to
// 6 extra headers that may be added below (Connection, User-Agent,
// Accept, Host, Accept-Encoding, Content-Length/Transfer-Encoding).
const max_headers = header_names.len + 6;
const static_buf_len = shared_request_headers_buf.len;
// Use the static buffer for the common case, dynamically allocate for overflow.
// The overflow buffer is kept around for reuse to avoid repeated allocations.
var request_headers_buf: []picohttp.Header = if (max_headers <= static_buf_len)
&shared_request_headers_buf
else blk: {
if (shared_request_headers_overflow) |overflow| {
if (overflow.len >= max_headers) {
break :blk overflow;
}
bun.default_allocator.free(overflow);
shared_request_headers_overflow = null;
}
const buf = bun.default_allocator.alloc(picohttp.Header, max_headers) catch
// On allocation failure, fall back to the static buffer and
// truncate headers rather than writing out of bounds.
break :blk @as([]picohttp.Header, &shared_request_headers_buf);
shared_request_headers_overflow = buf;
break :blk buf;
};
var request_headers_buf = &shared_request_headers_buf;
var override_accept_encoding = false;
var override_accept_header = false;
@@ -693,32 +667,43 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
else => {},
}
if (header_count >= request_headers_buf.len) break;
request_headers_buf[header_count] = .{
.name = name,
.value = this.headerStr(header_values[i]),
};
// header_name_hashes[header_count] = hash;
// // ensure duplicate headers come after each other
// if (header_count > 2) {
// var head_i: usize = header_count - 1;
// while (head_i > 0) : (head_i -= 1) {
// if (header_name_hashes[head_i] == header_name_hashes[header_count]) {
// std.mem.swap(picohttp.Header, &header_name_hashes[header_count], &header_name_hashes[head_i + 1]);
// std.mem.swap(u64, &request_headers_buf[header_count], &request_headers_buf[head_i + 1]);
// break;
// }
// }
// }
header_count += 1;
}
if (!override_connection_header and !this.flags.disable_keepalive and header_count < request_headers_buf.len) {
if (!override_connection_header and !this.flags.disable_keepalive) {
request_headers_buf[header_count] = connection_header;
header_count += 1;
}
if (!override_user_agent and header_count < request_headers_buf.len) {
if (!override_user_agent) {
request_headers_buf[header_count] = getUserAgentHeader();
header_count += 1;
}
if (!override_accept_header and header_count < request_headers_buf.len) {
if (!override_accept_header) {
request_headers_buf[header_count] = accept_header;
header_count += 1;
}
if (!override_host_header and header_count < request_headers_buf.len) {
if (!override_host_header) {
request_headers_buf[header_count] = .{
.name = host_header_name,
.value = this.url.host,
@@ -726,33 +711,31 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
header_count += 1;
}
if (!override_accept_encoding and !this.flags.disable_decompression and header_count < request_headers_buf.len) {
if (!override_accept_encoding and !this.flags.disable_decompression) {
request_headers_buf[header_count] = accept_encoding_header;
header_count += 1;
}
if (header_count < request_headers_buf.len) {
if (body_len > 0 or this.method.hasRequestBody()) {
if (this.flags.is_streaming_request_body) {
if (add_transfer_encoding and this.flags.upgrade_state == .none) {
request_headers_buf[header_count] = chunked_encoded_header;
header_count += 1;
}
} else {
request_headers_buf[header_count] = .{
.name = content_length_header_name,
.value = std.fmt.bufPrint(&this.request_content_len_buf, "{d}", .{body_len}) catch "0",
};
if (body_len > 0 or this.method.hasRequestBody()) {
if (this.flags.is_streaming_request_body) {
if (add_transfer_encoding and this.flags.upgrade_state == .none) {
request_headers_buf[header_count] = chunked_encoded_header;
header_count += 1;
}
} else if (original_content_length) |content_length| {
} else {
request_headers_buf[header_count] = .{
.name = content_length_header_name,
.value = content_length,
.value = std.fmt.bufPrint(&this.request_content_len_buf, "{d}", .{body_len}) catch "0",
};
header_count += 1;
}
} else if (original_content_length) |content_length| {
request_headers_buf[header_count] = .{
.name = content_length_header_name,
.value = content_length,
};
header_count += 1;
}
return picohttp.Request{

View File

@@ -43,6 +43,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
state: State = .initializing,
subprotocols: bun.StringSet,
/// Expected Sec-WebSocket-Accept value for RFC 6455 handshake validation.
/// This is SHA-1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") base64-encoded (always 28 bytes).
expected_accept: [28]u8,
/// Proxy state (null when not using proxy)
proxy: ?WebSocketProxy = null,
@@ -133,7 +137,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
}
}
const body = buildRequestBody(
const build_result = buildRequestBody(
vm,
pathname,
ssl,
@@ -143,6 +147,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
extra_headers,
if (target_authorization) |auth| auth.slice() else null,
) catch return null;
const body = build_result.body;
// Build proxy state if using proxy
// The CONNECT request is built using local variables for proxy_authorization and proxy_headers
@@ -209,6 +214,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
.input_body_buf = if (using_proxy) connect_request else body,
.state = .initializing,
.proxy = proxy_state,
.expected_accept = build_result.expected_accept,
.subprotocols = brk: {
var subprotocols = bun.StringSet.init(bun.default_allocator);
var it = bun.http.HeaderValueIterator.init(protocol_for_subprotocols.slice());
@@ -923,7 +929,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
return;
}
// TODO: check websocket_accept_header.value
if (!strings.eql(websocket_accept_header.value, &this.expected_accept)) {
this.terminate(ErrorCode.mismatch_websocket_accept_header);
return;
}
const overflow_len = remain_buf.len;
var overflow: []u8 = &.{};
@@ -1165,6 +1174,26 @@ fn buildConnectRequest(
return buf.toOwnedSlice();
}
const BuildRequestResult = struct {
body: []u8,
expected_accept: [28]u8,
};
/// Compute the expected Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2:
/// Base64(SHA-1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
fn computeExpectedAccept(key: []const u8) [28]u8 {
const websocket_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
var hasher = bun.sha.SHA1.init();
defer hasher.deinit();
hasher.update(key);
hasher.update(websocket_guid);
var sha1_digest: bun.sha.SHA1.Digest = .{0} ** bun.sha.SHA1.digest;
hasher.final(&sha1_digest);
var result: [28]u8 = .{0} ** 28;
_ = bun.base64.encode(&result, &sha1_digest);
return result;
}
fn buildRequestBody(
vm: *jsc.VirtualMachine,
pathname: *const jsc.ZigString,
@@ -1174,7 +1203,7 @@ fn buildRequestBody(
client_protocol: *const jsc.ZigString,
extra_headers: NonUTF8Headers,
target_authorization: ?[]const u8,
) std.mem.Allocator.Error![]u8 {
) std.mem.Allocator.Error!BuildRequestResult {
const allocator = vm.allocator;
// Check for user overrides
@@ -1221,6 +1250,9 @@ fn buildRequestBody(
// Generate a new key if user key is invalid or not provided
break :blk std.base64.standard.Encoder.encode(&encoded_buf, &vm.rareData().nextUUID().bytes);
};
const expected_accept = computeExpectedAccept(key);
const protocol = if (user_protocol) |p| p.slice() else client_protocol.slice();
const pathname_ = pathname.toSlice(allocator);
@@ -1273,7 +1305,26 @@ fn buildRequestBody(
// Build request with user overrides
if (user_host) |h| {
return try std.fmt.allocPrint(
return .{
.body = try std.fmt.allocPrint(
allocator,
"GET {s} HTTP/1.1\r\n" ++
"Host: {f}\r\n" ++
"Connection: Upgrade\r\n" ++
"Upgrade: websocket\r\n" ++
"Sec-WebSocket-Version: 13\r\n" ++
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" ++
"{f}" ++
"{s}" ++
"\r\n",
.{ pathname_.slice(), h, pico_headers, extra_headers_buf.items },
),
.expected_accept = expected_accept,
};
}
return .{
.body = try std.fmt.allocPrint(
allocator,
"GET {s} HTTP/1.1\r\n" ++
"Host: {f}\r\n" ++
@@ -1284,23 +1335,10 @@ fn buildRequestBody(
"{f}" ++
"{s}" ++
"\r\n",
.{ pathname_.slice(), h, pico_headers, extra_headers_buf.items },
);
}
return try std.fmt.allocPrint(
allocator,
"GET {s} HTTP/1.1\r\n" ++
"Host: {f}\r\n" ++
"Connection: Upgrade\r\n" ++
"Upgrade: websocket\r\n" ++
"Sec-WebSocket-Version: 13\r\n" ++
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" ++
"{f}" ++
"{s}" ++
"\r\n",
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers_buf.items },
);
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers_buf.items },
),
.expected_accept = expected_accept,
};
}
const log = Output.scoped(.WebSocketUpgradeClient, .visible);

View File

@@ -1,87 +0,0 @@
import { describe, expect, test } from "bun:test";
import { once } from "node:events";
import { createServer } from "node:net";
describe("fetch with many headers", () => {
test("should not crash or corrupt memory with more than 256 headers", async () => {
// Use a raw TCP server to avoid uws header count limits on the server side.
// We just need to verify that the client sends the request without crashing.
await using server = createServer(socket => {
let data = "";
socket.on("data", (chunk: Buffer) => {
data += chunk.toString();
// Wait for the end of HTTP headers (double CRLF)
if (data.includes("\r\n\r\n")) {
// Count headers (lines between the request line and the blank line)
const headerSection = data.split("\r\n\r\n")[0];
const lines = headerSection.split("\r\n");
// First line is the request line (GET / HTTP/1.1), rest are headers
const headerCount = lines.length - 1;
const body = String(headerCount);
const response = ["HTTP/1.1 200 OK", `Content-Length: ${body.length}`, "Connection: close", "", body].join(
"\r\n",
);
socket.write(response);
socket.end();
}
});
}).listen(0);
await once(server, "listening");
const port = (server.address() as any).port;
// Build 300 unique custom headers (exceeds the 256-entry static buffer)
const headers = new Headers();
const headerCount = 300;
for (let i = 0; i < headerCount; i++) {
headers.set(`x-custom-${i}`, `value-${i}`);
}
const res = await fetch(`http://localhost:${port}/`, { headers });
const receivedCount = parseInt(await res.text(), 10);
expect(res.status).toBe(200);
// The server should receive our custom headers plus default ones
// (host, connection, user-agent, accept, accept-encoding = 5 extra)
expect(receivedCount).toBeGreaterThanOrEqual(headerCount);
});
test("should handle exactly 256 user headers without issues", async () => {
await using server = createServer(socket => {
let data = "";
socket.on("data", (chunk: Buffer) => {
data += chunk.toString();
if (data.includes("\r\n\r\n")) {
const headerSection = data.split("\r\n\r\n")[0];
const lines = headerSection.split("\r\n");
const headerCount = lines.length - 1;
const body = String(headerCount);
const response = ["HTTP/1.1 200 OK", `Content-Length: ${body.length}`, "Connection: close", "", body].join(
"\r\n",
);
socket.write(response);
socket.end();
}
});
}).listen(0);
await once(server, "listening");
const port = (server.address() as any).port;
const headers = new Headers();
const headerCount = 256;
for (let i = 0; i < headerCount; i++) {
headers.set(`x-custom-${i}`, `value-${i}`);
}
const res = await fetch(`http://localhost:${port}/`, { headers });
const receivedCount = parseInt(await res.text(), 10);
expect(res.status).toBe(200);
expect(receivedCount).toBeGreaterThanOrEqual(headerCount);
});
});

View File

@@ -0,0 +1,137 @@
import { describe, expect, it, mock } from "bun:test";
import crypto from "node:crypto";
import net from "node:net";
describe("WebSocket Sec-WebSocket-Accept validation (RFC 6455 Section 4.1)", () => {
function computeAcceptKey(websocketKey: string): string {
return crypto
.createHash("sha1")
.update(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
.digest("base64");
}
async function createFakeServer(
getAcceptKey: (clientKey: string) => string,
): Promise<{ port: number; [Symbol.asyncDispose]: () => Promise<void> }> {
const server = net.createServer();
let port: number;
await new Promise<void>(resolve => {
server.listen(0, () => {
port = (server.address() as any).port;
resolve();
});
});
server.on("connection", socket => {
let requestData = "";
socket.on("data", data => {
requestData += data.toString();
if (requestData.includes("\r\n\r\n")) {
const lines = requestData.split("\r\n");
let websocketKey = "";
for (const line of lines) {
if (line.startsWith("Sec-WebSocket-Key:")) {
websocketKey = line.split(":")[1].trim();
break;
}
}
const acceptKey = getAcceptKey(websocketKey);
const response = [
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
`Sec-WebSocket-Accept: ${acceptKey}`,
"\r\n",
].join("\r\n");
socket.write(response);
}
});
});
return {
port: port!,
[Symbol.asyncDispose]: async () => {
server.close();
},
};
}
it("should accept valid Sec-WebSocket-Accept header", async () => {
await using server = await createFakeServer(key => computeAcceptKey(key));
const { promise, resolve, reject } = Promise.withResolvers();
const ws = new WebSocket(`ws://localhost:${server.port}`);
ws.onopen = () => resolve(undefined);
ws.onerror = () => reject(new Error("connection failed"));
await promise;
ws.close();
});
it("should reject invalid Sec-WebSocket-Accept header", async () => {
// Server returns a completely wrong accept key
await using server = await createFakeServer(_key => "dGhlIHNhbXBsZSBub25jZQ==");
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
const onopenMock = mock(() => {});
const ws = new WebSocket(`ws://localhost:${server.port}`);
ws.onopen = onopenMock;
ws.onclose = event => {
resolve({ code: event.code, reason: event.reason });
};
const result = await promise;
expect(onopenMock).not.toHaveBeenCalled();
expect(result.code).toBe(1002);
expect(result.reason).toBe("Mismatch websocket accept header");
});
it("should reject empty Sec-WebSocket-Accept value", async () => {
// Server returns an empty accept key
await using server = await createFakeServer(_key => "");
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
const onopenMock = mock(() => {});
const ws = new WebSocket(`ws://localhost:${server.port}`);
ws.onopen = onopenMock;
ws.onclose = event => {
resolve({ code: event.code, reason: event.reason });
};
const result = await promise;
expect(onopenMock).not.toHaveBeenCalled();
// Empty value should be caught by either the missing header check or the accept validation
expect(result.code).toBe(1002);
});
it("should reject Sec-WebSocket-Accept with wrong key computation", async () => {
// Server computes accept from a different key (simulating MitM)
await using server = await createFakeServer(_key => {
// Compute valid accept but for a different (attacker-chosen) key
return computeAcceptKey("AAAAAAAAAAAAAAAAAAAAAA==");
});
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
const onopenMock = mock(() => {});
const ws = new WebSocket(`ws://localhost:${server.port}`);
ws.onopen = onopenMock;
ws.onclose = event => {
resolve({ code: event.code, reason: event.reason });
};
const result = await promise;
expect(onopenMock).not.toHaveBeenCalled();
expect(result.code).toBe(1002);
});
});