mirror of
https://github.com/oven-sh/bun
synced 2026-02-21 08:12:21 +00:00
Compare commits
3 Commits
claude/fix
...
claude/fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
172bd045d8 | ||
|
|
225a5cceab | ||
|
|
ca6b28b2ac |
79
src/http.zig
79
src/http.zig
@@ -23,7 +23,6 @@ var print_every_i: usize = 0;
|
||||
// we always rewrite the entire HTTP request when write() returns EAGAIN
|
||||
// so we can reuse this buffer
|
||||
var shared_request_headers_buf: [256]picohttp.Header = undefined;
|
||||
var shared_request_headers_overflow: ?[]picohttp.Header = null;
|
||||
|
||||
// this doesn't need to be stack memory because it is immediately cloned after use
|
||||
var shared_response_headers_buf: [256]picohttp.Header = undefined;
|
||||
@@ -606,32 +605,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
|
||||
var header_entries = this.header_entries.slice();
|
||||
const header_names = header_entries.items(.name);
|
||||
const header_values = header_entries.items(.value);
|
||||
|
||||
// The maximum number of headers is the user-provided headers plus up to
|
||||
// 6 extra headers that may be added below (Connection, User-Agent,
|
||||
// Accept, Host, Accept-Encoding, Content-Length/Transfer-Encoding).
|
||||
const max_headers = header_names.len + 6;
|
||||
const static_buf_len = shared_request_headers_buf.len;
|
||||
|
||||
// Use the static buffer for the common case, dynamically allocate for overflow.
|
||||
// The overflow buffer is kept around for reuse to avoid repeated allocations.
|
||||
var request_headers_buf: []picohttp.Header = if (max_headers <= static_buf_len)
|
||||
&shared_request_headers_buf
|
||||
else blk: {
|
||||
if (shared_request_headers_overflow) |overflow| {
|
||||
if (overflow.len >= max_headers) {
|
||||
break :blk overflow;
|
||||
}
|
||||
bun.default_allocator.free(overflow);
|
||||
shared_request_headers_overflow = null;
|
||||
}
|
||||
const buf = bun.default_allocator.alloc(picohttp.Header, max_headers) catch
|
||||
// On allocation failure, fall back to the static buffer and
|
||||
// truncate headers rather than writing out of bounds.
|
||||
break :blk @as([]picohttp.Header, &shared_request_headers_buf);
|
||||
shared_request_headers_overflow = buf;
|
||||
break :blk buf;
|
||||
};
|
||||
var request_headers_buf = &shared_request_headers_buf;
|
||||
|
||||
var override_accept_encoding = false;
|
||||
var override_accept_header = false;
|
||||
@@ -693,32 +667,43 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
|
||||
else => {},
|
||||
}
|
||||
|
||||
if (header_count >= request_headers_buf.len) break;
|
||||
|
||||
request_headers_buf[header_count] = .{
|
||||
.name = name,
|
||||
.value = this.headerStr(header_values[i]),
|
||||
};
|
||||
|
||||
// header_name_hashes[header_count] = hash;
|
||||
|
||||
// // ensure duplicate headers come after each other
|
||||
// if (header_count > 2) {
|
||||
// var head_i: usize = header_count - 1;
|
||||
// while (head_i > 0) : (head_i -= 1) {
|
||||
// if (header_name_hashes[head_i] == header_name_hashes[header_count]) {
|
||||
// std.mem.swap(picohttp.Header, &header_name_hashes[header_count], &header_name_hashes[head_i + 1]);
|
||||
// std.mem.swap(u64, &request_headers_buf[header_count], &request_headers_buf[head_i + 1]);
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
if (!override_connection_header and !this.flags.disable_keepalive and header_count < request_headers_buf.len) {
|
||||
if (!override_connection_header and !this.flags.disable_keepalive) {
|
||||
request_headers_buf[header_count] = connection_header;
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
if (!override_user_agent and header_count < request_headers_buf.len) {
|
||||
if (!override_user_agent) {
|
||||
request_headers_buf[header_count] = getUserAgentHeader();
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
if (!override_accept_header and header_count < request_headers_buf.len) {
|
||||
if (!override_accept_header) {
|
||||
request_headers_buf[header_count] = accept_header;
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
if (!override_host_header and header_count < request_headers_buf.len) {
|
||||
if (!override_host_header) {
|
||||
request_headers_buf[header_count] = .{
|
||||
.name = host_header_name,
|
||||
.value = this.url.host,
|
||||
@@ -726,33 +711,31 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
if (!override_accept_encoding and !this.flags.disable_decompression and header_count < request_headers_buf.len) {
|
||||
if (!override_accept_encoding and !this.flags.disable_decompression) {
|
||||
request_headers_buf[header_count] = accept_encoding_header;
|
||||
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
if (header_count < request_headers_buf.len) {
|
||||
if (body_len > 0 or this.method.hasRequestBody()) {
|
||||
if (this.flags.is_streaming_request_body) {
|
||||
if (add_transfer_encoding and this.flags.upgrade_state == .none) {
|
||||
request_headers_buf[header_count] = chunked_encoded_header;
|
||||
header_count += 1;
|
||||
}
|
||||
} else {
|
||||
request_headers_buf[header_count] = .{
|
||||
.name = content_length_header_name,
|
||||
.value = std.fmt.bufPrint(&this.request_content_len_buf, "{d}", .{body_len}) catch "0",
|
||||
};
|
||||
if (body_len > 0 or this.method.hasRequestBody()) {
|
||||
if (this.flags.is_streaming_request_body) {
|
||||
if (add_transfer_encoding and this.flags.upgrade_state == .none) {
|
||||
request_headers_buf[header_count] = chunked_encoded_header;
|
||||
header_count += 1;
|
||||
}
|
||||
} else if (original_content_length) |content_length| {
|
||||
} else {
|
||||
request_headers_buf[header_count] = .{
|
||||
.name = content_length_header_name,
|
||||
.value = content_length,
|
||||
.value = std.fmt.bufPrint(&this.request_content_len_buf, "{d}", .{body_len}) catch "0",
|
||||
};
|
||||
header_count += 1;
|
||||
}
|
||||
} else if (original_content_length) |content_length| {
|
||||
request_headers_buf[header_count] = .{
|
||||
.name = content_length_header_name,
|
||||
.value = content_length,
|
||||
};
|
||||
header_count += 1;
|
||||
}
|
||||
|
||||
return picohttp.Request{
|
||||
|
||||
@@ -43,6 +43,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
state: State = .initializing,
|
||||
subprotocols: bun.StringSet,
|
||||
|
||||
/// Expected Sec-WebSocket-Accept value for RFC 6455 handshake validation.
|
||||
/// This is SHA-1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") base64-encoded (always 28 bytes).
|
||||
expected_accept: [28]u8,
|
||||
|
||||
/// Proxy state (null when not using proxy)
|
||||
proxy: ?WebSocketProxy = null,
|
||||
|
||||
@@ -133,7 +137,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
}
|
||||
}
|
||||
|
||||
const body = buildRequestBody(
|
||||
const build_result = buildRequestBody(
|
||||
vm,
|
||||
pathname,
|
||||
ssl,
|
||||
@@ -143,6 +147,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
extra_headers,
|
||||
if (target_authorization) |auth| auth.slice() else null,
|
||||
) catch return null;
|
||||
const body = build_result.body;
|
||||
|
||||
// Build proxy state if using proxy
|
||||
// The CONNECT request is built using local variables for proxy_authorization and proxy_headers
|
||||
@@ -209,6 +214,7 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
.input_body_buf = if (using_proxy) connect_request else body,
|
||||
.state = .initializing,
|
||||
.proxy = proxy_state,
|
||||
.expected_accept = build_result.expected_accept,
|
||||
.subprotocols = brk: {
|
||||
var subprotocols = bun.StringSet.init(bun.default_allocator);
|
||||
var it = bun.http.HeaderValueIterator.init(protocol_for_subprotocols.slice());
|
||||
@@ -923,7 +929,10 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: check websocket_accept_header.value
|
||||
if (!strings.eql(websocket_accept_header.value, &this.expected_accept)) {
|
||||
this.terminate(ErrorCode.mismatch_websocket_accept_header);
|
||||
return;
|
||||
}
|
||||
|
||||
const overflow_len = remain_buf.len;
|
||||
var overflow: []u8 = &.{};
|
||||
@@ -1165,6 +1174,26 @@ fn buildConnectRequest(
|
||||
return buf.toOwnedSlice();
|
||||
}
|
||||
|
||||
const BuildRequestResult = struct {
|
||||
body: []u8,
|
||||
expected_accept: [28]u8,
|
||||
};
|
||||
|
||||
/// Compute the expected Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2:
|
||||
/// Base64(SHA-1(Sec-WebSocket-Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||||
fn computeExpectedAccept(key: []const u8) [28]u8 {
|
||||
const websocket_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
||||
var hasher = bun.sha.SHA1.init();
|
||||
defer hasher.deinit();
|
||||
hasher.update(key);
|
||||
hasher.update(websocket_guid);
|
||||
var sha1_digest: bun.sha.SHA1.Digest = .{0} ** bun.sha.SHA1.digest;
|
||||
hasher.final(&sha1_digest);
|
||||
var result: [28]u8 = .{0} ** 28;
|
||||
_ = bun.base64.encode(&result, &sha1_digest);
|
||||
return result;
|
||||
}
|
||||
|
||||
fn buildRequestBody(
|
||||
vm: *jsc.VirtualMachine,
|
||||
pathname: *const jsc.ZigString,
|
||||
@@ -1174,7 +1203,7 @@ fn buildRequestBody(
|
||||
client_protocol: *const jsc.ZigString,
|
||||
extra_headers: NonUTF8Headers,
|
||||
target_authorization: ?[]const u8,
|
||||
) std.mem.Allocator.Error![]u8 {
|
||||
) std.mem.Allocator.Error!BuildRequestResult {
|
||||
const allocator = vm.allocator;
|
||||
|
||||
// Check for user overrides
|
||||
@@ -1221,6 +1250,9 @@ fn buildRequestBody(
|
||||
// Generate a new key if user key is invalid or not provided
|
||||
break :blk std.base64.standard.Encoder.encode(&encoded_buf, &vm.rareData().nextUUID().bytes);
|
||||
};
|
||||
|
||||
const expected_accept = computeExpectedAccept(key);
|
||||
|
||||
const protocol = if (user_protocol) |p| p.slice() else client_protocol.slice();
|
||||
|
||||
const pathname_ = pathname.toSlice(allocator);
|
||||
@@ -1273,7 +1305,26 @@ fn buildRequestBody(
|
||||
|
||||
// Build request with user overrides
|
||||
if (user_host) |h| {
|
||||
return try std.fmt.allocPrint(
|
||||
return .{
|
||||
.body = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {f}\r\n" ++
|
||||
"Connection: Upgrade\r\n" ++
|
||||
"Upgrade: websocket\r\n" ++
|
||||
"Sec-WebSocket-Version: 13\r\n" ++
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" ++
|
||||
"{f}" ++
|
||||
"{s}" ++
|
||||
"\r\n",
|
||||
.{ pathname_.slice(), h, pico_headers, extra_headers_buf.items },
|
||||
),
|
||||
.expected_accept = expected_accept,
|
||||
};
|
||||
}
|
||||
|
||||
return .{
|
||||
.body = try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {f}\r\n" ++
|
||||
@@ -1284,23 +1335,10 @@ fn buildRequestBody(
|
||||
"{f}" ++
|
||||
"{s}" ++
|
||||
"\r\n",
|
||||
.{ pathname_.slice(), h, pico_headers, extra_headers_buf.items },
|
||||
);
|
||||
}
|
||||
|
||||
return try std.fmt.allocPrint(
|
||||
allocator,
|
||||
"GET {s} HTTP/1.1\r\n" ++
|
||||
"Host: {f}\r\n" ++
|
||||
"Connection: Upgrade\r\n" ++
|
||||
"Upgrade: websocket\r\n" ++
|
||||
"Sec-WebSocket-Version: 13\r\n" ++
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" ++
|
||||
"{f}" ++
|
||||
"{s}" ++
|
||||
"\r\n",
|
||||
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers_buf.items },
|
||||
);
|
||||
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers_buf.items },
|
||||
),
|
||||
.expected_accept = expected_accept,
|
||||
};
|
||||
}
|
||||
|
||||
const log = Output.scoped(.WebSocketUpgradeClient, .visible);
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import { describe, expect, test } from "bun:test";
|
||||
import { once } from "node:events";
|
||||
import { createServer } from "node:net";
|
||||
|
||||
describe("fetch with many headers", () => {
|
||||
test("should not crash or corrupt memory with more than 256 headers", async () => {
|
||||
// Use a raw TCP server to avoid uws header count limits on the server side.
|
||||
// We just need to verify that the client sends the request without crashing.
|
||||
await using server = createServer(socket => {
|
||||
let data = "";
|
||||
socket.on("data", (chunk: Buffer) => {
|
||||
data += chunk.toString();
|
||||
// Wait for the end of HTTP headers (double CRLF)
|
||||
if (data.includes("\r\n\r\n")) {
|
||||
// Count headers (lines between the request line and the blank line)
|
||||
const headerSection = data.split("\r\n\r\n")[0];
|
||||
const lines = headerSection.split("\r\n");
|
||||
// First line is the request line (GET / HTTP/1.1), rest are headers
|
||||
const headerCount = lines.length - 1;
|
||||
|
||||
const body = String(headerCount);
|
||||
const response = ["HTTP/1.1 200 OK", `Content-Length: ${body.length}`, "Connection: close", "", body].join(
|
||||
"\r\n",
|
||||
);
|
||||
|
||||
socket.write(response);
|
||||
socket.end();
|
||||
}
|
||||
});
|
||||
}).listen(0);
|
||||
await once(server, "listening");
|
||||
|
||||
const port = (server.address() as any).port;
|
||||
|
||||
// Build 300 unique custom headers (exceeds the 256-entry static buffer)
|
||||
const headers = new Headers();
|
||||
const headerCount = 300;
|
||||
for (let i = 0; i < headerCount; i++) {
|
||||
headers.set(`x-custom-${i}`, `value-${i}`);
|
||||
}
|
||||
|
||||
const res = await fetch(`http://localhost:${port}/`, { headers });
|
||||
const receivedCount = parseInt(await res.text(), 10);
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
// The server should receive our custom headers plus default ones
|
||||
// (host, connection, user-agent, accept, accept-encoding = 5 extra)
|
||||
expect(receivedCount).toBeGreaterThanOrEqual(headerCount);
|
||||
});
|
||||
|
||||
test("should handle exactly 256 user headers without issues", async () => {
|
||||
await using server = createServer(socket => {
|
||||
let data = "";
|
||||
socket.on("data", (chunk: Buffer) => {
|
||||
data += chunk.toString();
|
||||
if (data.includes("\r\n\r\n")) {
|
||||
const headerSection = data.split("\r\n\r\n")[0];
|
||||
const lines = headerSection.split("\r\n");
|
||||
const headerCount = lines.length - 1;
|
||||
|
||||
const body = String(headerCount);
|
||||
const response = ["HTTP/1.1 200 OK", `Content-Length: ${body.length}`, "Connection: close", "", body].join(
|
||||
"\r\n",
|
||||
);
|
||||
|
||||
socket.write(response);
|
||||
socket.end();
|
||||
}
|
||||
});
|
||||
}).listen(0);
|
||||
await once(server, "listening");
|
||||
|
||||
const port = (server.address() as any).port;
|
||||
|
||||
const headers = new Headers();
|
||||
const headerCount = 256;
|
||||
for (let i = 0; i < headerCount; i++) {
|
||||
headers.set(`x-custom-${i}`, `value-${i}`);
|
||||
}
|
||||
|
||||
const res = await fetch(`http://localhost:${port}/`, { headers });
|
||||
const receivedCount = parseInt(await res.text(), 10);
|
||||
|
||||
expect(res.status).toBe(200);
|
||||
expect(receivedCount).toBeGreaterThanOrEqual(headerCount);
|
||||
});
|
||||
});
|
||||
137
test/js/web/websocket/websocket-accept-validation.test.ts
Normal file
137
test/js/web/websocket/websocket-accept-validation.test.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
import { describe, expect, it, mock } from "bun:test";
|
||||
import crypto from "node:crypto";
|
||||
import net from "node:net";
|
||||
|
||||
describe("WebSocket Sec-WebSocket-Accept validation (RFC 6455 Section 4.1)", () => {
|
||||
function computeAcceptKey(websocketKey: string): string {
|
||||
return crypto
|
||||
.createHash("sha1")
|
||||
.update(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
.digest("base64");
|
||||
}
|
||||
|
||||
async function createFakeServer(
|
||||
getAcceptKey: (clientKey: string) => string,
|
||||
): Promise<{ port: number; [Symbol.asyncDispose]: () => Promise<void> }> {
|
||||
const server = net.createServer();
|
||||
let port: number;
|
||||
|
||||
await new Promise<void>(resolve => {
|
||||
server.listen(0, () => {
|
||||
port = (server.address() as any).port;
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
server.on("connection", socket => {
|
||||
let requestData = "";
|
||||
|
||||
socket.on("data", data => {
|
||||
requestData += data.toString();
|
||||
|
||||
if (requestData.includes("\r\n\r\n")) {
|
||||
const lines = requestData.split("\r\n");
|
||||
let websocketKey = "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("Sec-WebSocket-Key:")) {
|
||||
websocketKey = line.split(":")[1].trim();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const acceptKey = getAcceptKey(websocketKey);
|
||||
|
||||
const response = [
|
||||
"HTTP/1.1 101 Switching Protocols",
|
||||
"Upgrade: websocket",
|
||||
"Connection: Upgrade",
|
||||
`Sec-WebSocket-Accept: ${acceptKey}`,
|
||||
"\r\n",
|
||||
].join("\r\n");
|
||||
|
||||
socket.write(response);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return {
|
||||
port: port!,
|
||||
[Symbol.asyncDispose]: async () => {
|
||||
server.close();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
it("should accept valid Sec-WebSocket-Accept header", async () => {
|
||||
await using server = await createFakeServer(key => computeAcceptKey(key));
|
||||
|
||||
const { promise, resolve, reject } = Promise.withResolvers();
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
|
||||
ws.onopen = () => resolve(undefined);
|
||||
ws.onerror = () => reject(new Error("connection failed"));
|
||||
|
||||
await promise;
|
||||
ws.close();
|
||||
});
|
||||
|
||||
it("should reject invalid Sec-WebSocket-Accept header", async () => {
|
||||
// Server returns a completely wrong accept key
|
||||
await using server = await createFakeServer(_key => "dGhlIHNhbXBsZSBub25jZQ==");
|
||||
|
||||
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
|
||||
const onopenMock = mock(() => {});
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
ws.onopen = onopenMock;
|
||||
ws.onclose = event => {
|
||||
resolve({ code: event.code, reason: event.reason });
|
||||
};
|
||||
|
||||
const result = await promise;
|
||||
expect(onopenMock).not.toHaveBeenCalled();
|
||||
expect(result.code).toBe(1002);
|
||||
expect(result.reason).toBe("Mismatch websocket accept header");
|
||||
});
|
||||
|
||||
it("should reject empty Sec-WebSocket-Accept value", async () => {
|
||||
// Server returns an empty accept key
|
||||
await using server = await createFakeServer(_key => "");
|
||||
|
||||
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
|
||||
const onopenMock = mock(() => {});
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
ws.onopen = onopenMock;
|
||||
ws.onclose = event => {
|
||||
resolve({ code: event.code, reason: event.reason });
|
||||
};
|
||||
|
||||
const result = await promise;
|
||||
expect(onopenMock).not.toHaveBeenCalled();
|
||||
// Empty value should be caught by either the missing header check or the accept validation
|
||||
expect(result.code).toBe(1002);
|
||||
});
|
||||
|
||||
it("should reject Sec-WebSocket-Accept with wrong key computation", async () => {
|
||||
// Server computes accept from a different key (simulating MitM)
|
||||
await using server = await createFakeServer(_key => {
|
||||
// Compute valid accept but for a different (attacker-chosen) key
|
||||
return computeAcceptKey("AAAAAAAAAAAAAAAAAAAAAA==");
|
||||
});
|
||||
|
||||
const { promise, resolve } = Promise.withResolvers<{ code: number; reason: string }>();
|
||||
const onopenMock = mock(() => {});
|
||||
|
||||
const ws = new WebSocket(`ws://localhost:${server.port}`);
|
||||
ws.onopen = onopenMock;
|
||||
ws.onclose = event => {
|
||||
resolve({ code: event.code, reason: event.reason });
|
||||
};
|
||||
|
||||
const result = await promise;
|
||||
expect(onopenMock).not.toHaveBeenCalled();
|
||||
expect(result.code).toBe(1002);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user