diff --git a/src/http.zig b/src/http.zig index df00a8038c..d383aa0f7d 100644 --- a/src/http.zig +++ b/src/http.zig @@ -23,6 +23,7 @@ var print_every_i: usize = 0; // we always rewrite the entire HTTP request when write() returns EAGAIN // so we can reuse this buffer var shared_request_headers_buf: [256]picohttp.Header = undefined; +var shared_request_headers_overflow: ?[]picohttp.Header = null; // this doesn't need to be stack memory because it is immediately cloned after use var shared_response_headers_buf: [256]picohttp.Header = undefined; @@ -605,7 +606,32 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { var header_entries = this.header_entries.slice(); const header_names = header_entries.items(.name); const header_values = header_entries.items(.value); - var request_headers_buf = &shared_request_headers_buf; + + // The maximum number of headers is the user-provided headers plus up to + // 6 extra headers that may be added below (Connection, User-Agent, + // Accept, Host, Accept-Encoding, Content-Length/Transfer-Encoding). + const max_headers = header_names.len + 6; + const static_buf_len = shared_request_headers_buf.len; + + // Use the static buffer for the common case, dynamically allocate for overflow. + // The overflow buffer is kept around for reuse to avoid repeated allocations. + var request_headers_buf: []picohttp.Header = if (max_headers <= static_buf_len) + &shared_request_headers_buf + else blk: { + if (shared_request_headers_overflow) |overflow| { + if (overflow.len >= max_headers) { + break :blk overflow; + } + bun.default_allocator.free(overflow); + shared_request_headers_overflow = null; + } + const buf = bun.default_allocator.alloc(picohttp.Header, max_headers) catch + // On allocation failure, fall back to the static buffer and + // truncate headers rather than writing out of bounds. + break :blk @as([]picohttp.Header, &shared_request_headers_buf); + shared_request_headers_overflow = buf; + break :blk buf; + }; var override_accept_encoding = false; var override_accept_header = false; @@ -667,43 +693,32 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { else => {}, } + if (header_count >= request_headers_buf.len) break; + request_headers_buf[header_count] = .{ .name = name, .value = this.headerStr(header_values[i]), }; - // header_name_hashes[header_count] = hash; - - // // ensure duplicate headers come after each other - // if (header_count > 2) { - // var head_i: usize = header_count - 1; - // while (head_i > 0) : (head_i -= 1) { - // if (header_name_hashes[head_i] == header_name_hashes[header_count]) { - // std.mem.swap(picohttp.Header, &header_name_hashes[header_count], &header_name_hashes[head_i + 1]); - // std.mem.swap(u64, &request_headers_buf[header_count], &request_headers_buf[head_i + 1]); - // break; - // } - // } - // } header_count += 1; } - if (!override_connection_header and !this.flags.disable_keepalive) { + if (!override_connection_header and !this.flags.disable_keepalive and header_count < request_headers_buf.len) { request_headers_buf[header_count] = connection_header; header_count += 1; } - if (!override_user_agent) { + if (!override_user_agent and header_count < request_headers_buf.len) { request_headers_buf[header_count] = getUserAgentHeader(); header_count += 1; } - if (!override_accept_header) { + if (!override_accept_header and header_count < request_headers_buf.len) { request_headers_buf[header_count] = accept_header; header_count += 1; } - if (!override_host_header) { + if (!override_host_header and header_count < request_headers_buf.len) { request_headers_buf[header_count] = .{ .name = host_header_name, .value = this.url.host, @@ -711,31 +726,33 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { header_count += 1; } - if (!override_accept_encoding and !this.flags.disable_decompression) { + if (!override_accept_encoding and !this.flags.disable_decompression and header_count < request_headers_buf.len) { request_headers_buf[header_count] = accept_encoding_header; header_count += 1; } - if (body_len > 0 or this.method.hasRequestBody()) { - if (this.flags.is_streaming_request_body) { - if (add_transfer_encoding and this.flags.upgrade_state == .none) { - request_headers_buf[header_count] = chunked_encoded_header; + if (header_count < request_headers_buf.len) { + if (body_len > 0 or this.method.hasRequestBody()) { + if (this.flags.is_streaming_request_body) { + if (add_transfer_encoding and this.flags.upgrade_state == .none) { + request_headers_buf[header_count] = chunked_encoded_header; + header_count += 1; + } + } else { + request_headers_buf[header_count] = .{ + .name = content_length_header_name, + .value = std.fmt.bufPrint(&this.request_content_len_buf, "{d}", .{body_len}) catch "0", + }; header_count += 1; } - } else { + } else if (original_content_length) |content_length| { request_headers_buf[header_count] = .{ .name = content_length_header_name, - .value = std.fmt.bufPrint(&this.request_content_len_buf, "{d}", .{body_len}) catch "0", + .value = content_length, }; header_count += 1; } - } else if (original_content_length) |content_length| { - request_headers_buf[header_count] = .{ - .name = content_length_header_name, - .value = content_length, - }; - header_count += 1; } return picohttp.Request{ diff --git a/test/js/web/fetch/fetch-header-overflow.test.ts b/test/js/web/fetch/fetch-header-overflow.test.ts new file mode 100644 index 0000000000..e5cef44164 --- /dev/null +++ b/test/js/web/fetch/fetch-header-overflow.test.ts @@ -0,0 +1,87 @@ +import { describe, expect, test } from "bun:test"; +import { once } from "node:events"; +import { createServer } from "node:net"; + +describe("fetch with many headers", () => { + test("should not crash or corrupt memory with more than 256 headers", async () => { + // Use a raw TCP server to avoid uws header count limits on the server side. + // We just need to verify that the client sends the request without crashing. + await using server = createServer(socket => { + let data = ""; + socket.on("data", (chunk: Buffer) => { + data += chunk.toString(); + // Wait for the end of HTTP headers (double CRLF) + if (data.includes("\r\n\r\n")) { + // Count headers (lines between the request line and the blank line) + const headerSection = data.split("\r\n\r\n")[0]; + const lines = headerSection.split("\r\n"); + // First line is the request line (GET / HTTP/1.1), rest are headers + const headerCount = lines.length - 1; + + const body = String(headerCount); + const response = ["HTTP/1.1 200 OK", `Content-Length: ${body.length}`, "Connection: close", "", body].join( + "\r\n", + ); + + socket.write(response); + socket.end(); + } + }); + }).listen(0); + await once(server, "listening"); + + const port = (server.address() as any).port; + + // Build 300 unique custom headers (exceeds the 256-entry static buffer) + const headers = new Headers(); + const headerCount = 300; + for (let i = 0; i < headerCount; i++) { + headers.set(`x-custom-${i}`, `value-${i}`); + } + + const res = await fetch(`http://localhost:${port}/`, { headers }); + const receivedCount = parseInt(await res.text(), 10); + + expect(res.status).toBe(200); + // The server should receive our custom headers plus default ones + // (host, connection, user-agent, accept, accept-encoding = 5 extra) + expect(receivedCount).toBeGreaterThanOrEqual(headerCount); + }); + + test("should handle exactly 256 user headers without issues", async () => { + await using server = createServer(socket => { + let data = ""; + socket.on("data", (chunk: Buffer) => { + data += chunk.toString(); + if (data.includes("\r\n\r\n")) { + const headerSection = data.split("\r\n\r\n")[0]; + const lines = headerSection.split("\r\n"); + const headerCount = lines.length - 1; + + const body = String(headerCount); + const response = ["HTTP/1.1 200 OK", `Content-Length: ${body.length}`, "Connection: close", "", body].join( + "\r\n", + ); + + socket.write(response); + socket.end(); + } + }); + }).listen(0); + await once(server, "listening"); + + const port = (server.address() as any).port; + + const headers = new Headers(); + const headerCount = 256; + for (let i = 0; i < headerCount; i++) { + headers.set(`x-custom-${i}`, `value-${i}`); + } + + const res = await fetch(`http://localhost:${port}/`, { headers }); + const receivedCount = parseInt(await res.text(), 10); + + expect(res.status).toBe(200); + expect(receivedCount).toBeGreaterThanOrEqual(headerCount); + }); +});