From 345666b19413dc0b58cc511c87064fb658e1322a Mon Sep 17 00:00:00 2001
From: Ciro Spaciari
Date: Wed, 3 Sep 2025 19:12:26 -0700
Subject: [PATCH] fetch: support WebSocket upgrade (101 Switching Protocols)

---
 src/bun.js/webcore/fetch.zig            |  46 ++++---
 src/http.zig                            | 136 +++++++++++++--------
 src/http/Signals.zig                    |   5 +-
 test/js/web/fetch/fetch.upgrade.test.ts |  63 ++++++++++
 test/js/web/fetch/websocket.helpers.ts  | 156 ++++++++++++++++++++++++
 5 files changed, 339 insertions(+), 67 deletions(-)
 create mode 100644 test/js/web/fetch/fetch.upgrade.test.ts
 create mode 100644 test/js/web/fetch/websocket.helpers.ts

diff --git a/src/bun.js/webcore/fetch.zig b/src/bun.js/webcore/fetch.zig
index a0c8d7dc90..c7d421b769 100644
--- a/src/bun.js/webcore/fetch.zig
+++ b/src/bun.js/webcore/fetch.zig
@@ -108,6 +108,7 @@ pub const FetchTasklet = struct {
     // custom checkServerIdentity
     check_server_identity: jsc.Strong.Optional = .empty,
     reject_unauthorized: bool = true,
+    is_websocket_upgrade: bool = false,
     // Custom Hostname
     hostname: ?[]u8 = null,
     is_waiting_body: bool = false,
@@ -1069,6 +1070,7 @@ pub const FetchTasklet = struct {
             .memory_reporter = fetch_options.memory_reporter,
             .check_server_identity = fetch_options.check_server_identity,
             .reject_unauthorized = fetch_options.reject_unauthorized,
+            .is_websocket_upgrade = fetch_options.is_websocket_upgrade,
         };
 
         fetch_tasklet.signals = fetch_tasklet.signal_store.to();
@@ -1201,19 +1203,23 @@
             // dont have backpressure so we will schedule the data to be written
             // if we have backpressure the onWritable will drain the buffer
             needs_schedule = stream_buffer.isEmpty();
-            //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n
-            var formated_size_buffer: [18]u8 = undefined;
-            const formated_size = std.fmt.bufPrint(
-                formated_size_buffer[0..],
-                "{x}\r\n",
-                .{data.len},
-            ) catch |err| switch (err) {
-                error.NoSpaceLeft => unreachable,
-            };
-            bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2));
-            stream_buffer.writeAssumeCapacity(formated_size);
-            stream_buffer.writeAssumeCapacity(data);
-            stream_buffer.writeAssumeCapacity("\r\n");
+            if (this.is_websocket_upgrade) {
+                bun.handleOom(stream_buffer.write(data));
+            } else {
+                // 16 is the max length of a 64-bit size formatted as hex, + 2 for the \r\n
+                var formated_size_buffer: [18]u8 = undefined;
+                const formated_size = std.fmt.bufPrint(
+                    formated_size_buffer[0..],
+                    "{x}\r\n",
+                    .{data.len},
+                ) catch |err| switch (err) {
+                    error.NoSpaceLeft => unreachable,
+                };
+                bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2));
+                stream_buffer.writeAssumeCapacity(formated_size);
+                stream_buffer.writeAssumeCapacity(data);
+                stream_buffer.writeAssumeCapacity("\r\n");
+            }
 
             // pause the stream if we hit the high water mark
             return stream_buffer.size() >= highWaterMark;
@@ -1271,6 +1277,7 @@
         check_server_identity: jsc.Strong.Optional = .empty,
         unix_socket_path: ZigString.Slice,
         ssl_config: ?*SSLConfig = null,
+        is_websocket_upgrade: bool = false,
     };
 
     pub fn queue(
@@ -1494,6 +1501,7 @@ pub fn Bun__fetch_(
     var memory_reporter = bun.handleOom(bun.default_allocator.create(bun.MemoryReportingAllocator));
     // used to clean up dynamically allocated memory on error (a poor man's errdefer)
     var is_error = false;
+    var is_websocket_upgrade = false;
     var allocator = memory_reporter.wrap(bun.default_allocator);
     errdefer bun.default_allocator.destroy(memory_reporter);
     defer {
@@ -2198,6 +2206,15 @@
             }
         }
 
+        if (headers_.fastGet(bun.webcore.FetchHeaders.HTTPHeaderName.Upgrade)) |_upgrade| {
+            const upgrade = _upgrade.toSlice(bun.default_allocator);
+            defer upgrade.deinit();
+            const slice = upgrade.slice();
+            if (std.ascii.eqlIgnoreCase(slice, "websocket")) {
+                is_websocket_upgrade = true;
+            }
+        }
+
         break :extract_headers Headers.from(headers_, allocator, .{ .body = body.getAnyBlob() }) catch |err| bun.handleOom(err);
     }
@@ -2333,7 +2350,7 @@
         }
     }
 
-    if (!method.hasRequestBody() and body.hasBody()) {
+    if (!method.hasRequestBody() and body.hasBody() and !is_websocket_upgrade) {
         const err = globalThis.toTypeError(.INVALID_ARG_VALUE, fetch_error_unexpected_body, .{});
         is_error = true;
         return JSPromise.dangerouslyCreateRejectedPromiseValueWithoutNotifyingVM(globalThis, err);
@@ -2651,6 +2668,7 @@
         .ssl_config = ssl_config,
         .hostname = hostname,
         .memory_reporter = memory_reporter,
+        .is_websocket_upgrade = is_websocket_upgrade,
         .check_server_identity = if (check_server_identity.isEmptyOrUndefinedOrNull()) .empty else .create(check_server_identity, globalThis),
         .unix_socket_path = unix_socket_path,
     },
diff --git a/src/http.zig b/src/http.zig
index f4119f2a49..1d8eef5037 100644
--- a/src/http.zig
+++ b/src/http.zig
@@ -405,7 +405,9 @@ pub const Flags = packed struct(u16) {
     is_preconnect_only: bool = false,
     is_streaming_request_body: bool = false,
     defer_fail_until_connecting_is_complete: bool = false,
-    _padding: u5 = 0,
+    is_websockets: bool = false,
+    websocket_upgraded: bool = false,
+    _padding: u3 = 0,
 };
 
 // TODO: reduce the size of this struct
@@ -592,6 +594,11 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
             hashHeaderConst("Accept-Encoding") => {
                 override_accept_encoding = true;
             },
+            hashHeaderConst("Upgrade") => {
+                if (std.ascii.eqlIgnoreCase(this.headerStr(header_values[i]), "websocket")) {
+                    this.flags.is_websockets = true;
+                }
+            },
             hashHeaderConst(chunked_encoded_header.name) => {
                 // We don't want to override chunked encoding header if it was set by the user
                 add_transfer_encoding = false;
@@ -1019,11 +1026,14 @@ fn writeToStreamUsingBuffer(this: *HTTPClient, comptime is_ssl: bool, socket: Ne
     // no data to send so we are done
     return false;
 }
-
 pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, data: []const u8) void {
     log("flushStream", .{});
     var stream = &this.state.original_request_body.stream;
     const stream_buffer = stream.buffer orelse return;
+    if (this.flags.is_websockets and !this.flags.websocket_upgraded) {
+        // cannot drain yet; the websocket upgrade has not completed
+        return;
+    }
     const buffer = stream_buffer.acquire();
     const wasEmpty = buffer.isEmpty() and data.len == 0;
     if (wasEmpty and stream.ended) {
@@ -1324,56 +1334,78 @@ pub fn handleOnDataHeaders(
 ) void {
     log("handleOnDataHeaders", .{});
     var to_read = incoming_data;
-    var amount_read: usize = 0;
-    var needs_move = true;
-    if (this.state.response_message_buffer.list.items.len > 0) {
-        // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating
-        bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data));
-        to_read = this.state.response_message_buffer.list.items;
-        needs_move = false;
-    }
-    // we reset the pending_response each time wich means that on parse error this will be always be empty
-    this.state.pending_response = picohttp.Response{};
-
-    // minimal http/1.1 request size is 16 bytes without headers and 26 with Host header
-    // if is less than 16 will always be a ShortRead
-    if (to_read.len < 16) {
-        log("handleShortRead", .{});
-        this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
-        return;
-    }
-
-    var response = picohttp.Response.parseParts(
-        to_read,
-        &shared_response_headers_buf,
-        &amount_read,
-    ) catch |err| {
-        switch (err) {
-            error.ShortRead => {
-                this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
-            },
-            else => {
-                this.closeAndFail(err, is_ssl, socket);
-            },
+    while (true) {
+        var amount_read: usize = 0;
+        var needs_move = true;
+        if (this.state.response_message_buffer.list.items.len > 0) {
+            // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating
+            bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data));
+            to_read = this.state.response_message_buffer.list.items;
+            needs_move = false;
         }
-        return;
-    };
-    // we save the successful parsed response
-    this.state.pending_response = response;
+        // we reset the pending_response each time, which means that on a parse error it will always be empty
+        this.state.pending_response = picohttp.Response{};
 
-    const body_buf = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..];
-    // handle the case where we have a 100 Continue
-    if (response.status_code >= 100 and response.status_code < 200) {
-        log("information headers", .{});
-        // we still can have the 200 OK in the same buffer sometimes
-        if (body_buf.len > 0) {
-            log("information headers with body", .{});
-            this.onData(is_ssl, body_buf, ctx, socket);
+        // the minimal HTTP/1.1 response is 16 bytes ("HTTP/1.1 200\r\n\r\n")
+        // anything shorter than 16 bytes will always be a ShortRead
+        if (to_read.len < 16) {
+            log("handleShortRead", .{});
+            this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
+            return;
         }
-        return;
+
+        const response = picohttp.Response.parseParts(
+            to_read,
+            &shared_response_headers_buf,
+            &amount_read,
+        ) catch |err| {
+            switch (err) {
+                error.ShortRead => {
+                    this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
+                },
+                else => {
+                    this.closeAndFail(err, is_ssl, socket);
+                },
+            }
+            return;
+        };
+
+        // we save the successful parsed response
+        this.state.pending_response = response;
+
+        to_read = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..];
+
+        if (response.status_code == 101) {
+            if (!this.flags.is_websockets) {
+                // we cannot upgrade to websocket because the client did not request it!
+                this.closeAndFail(error.UnrequestedUpgrade, is_ssl, socket);
+                return;
+            }
+            // special case for websocket upgrade
+            this.flags.is_websockets = true;
+            this.flags.websocket_upgraded = true;
+            if (this.signals.upgraded) |upgraded| {
+                upgraded.store(true, .monotonic);
+            }
+            // start draining the request body
+            this.flushStream(is_ssl, socket);
+            break;
+        }
+
+        // handle the case where we have a 100 Continue
+        if (response.status_code >= 100 and response.status_code < 200) {
+            log("information headers", .{});
+            // we still can have the 200 OK in the same buffer sometimes
+            // 1XX responses MUST NOT include a message-body, therefore we need to continue parsing
+
+            continue;
+        }
+
+        break;
     }
+    var response = this.state.pending_response.?;
     const should_continue = this.handleResponseMetadata(
         &response,
     ) catch |err| {
@@ -1409,14 +1441,14 @@
 
     if (this.flags.proxy_tunneling and this.proxy_tunnel == null) {
         // we are proxing we dont need to cloneMetadata yet
-        this.startProxyHandshake(is_ssl, socket, body_buf);
+        this.startProxyHandshake(is_ssl, socket, to_read);
         return;
     }
 
     // we have body data incoming so we clone metadata and keep going
     this.cloneMetadata();
 
-    if (body_buf.len == 0) {
+    if (to_read.len == 0) {
         // no body data yet, but we can report the headers
         if (this.signals.get(.header_progress)) {
             this.progressUpdate(is_ssl, ctx, socket);
@@ -1426,7 +1458,7 @@
 
     if (this.state.response_stage == .body) {
         {
-            const report_progress = this.handleResponseBody(body_buf, true) catch |err| {
+            const report_progress = this.handleResponseBody(to_read, true) catch |err| {
                 this.closeAndFail(err, is_ssl, socket);
                 return;
             };
@@ -1439,7 +1471,7 @@
     } else if (this.state.response_stage == .body_chunk) {
         this.setTimeout(socket, 5);
         {
-            const report_progress = this.handleResponseBodyChunkedEncoding(body_buf) catch |err| {
+            const report_progress = this.handleResponseBodyChunkedEncoding(to_read) catch |err| {
                 this.closeAndFail(err, is_ssl, socket);
                 return;
             };
@@ -2149,7 +2181,7 @@ pub fn handleResponseMetadata(
     // [...] cannot contain a message body or trailer section.
     // therefore in these cases set content-length to 0, so the response body is always ignored
     // and is not waited for (which could cause a timeout)
-    if ((response.status_code >= 100 and response.status_code < 200) or response.status_code == 204 or response.status_code == 304) {
+    if ((response.status_code >= 100 and response.status_code < 200 and response.status_code != 101) or response.status_code == 204 or response.status_code == 304) {
         this.state.content_length = 0;
     }
@@ -2416,7 +2448,7 @@
         log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding});
     }
 
-    if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events)) {
+    if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.websocket_upgraded)) {
         return ShouldContinue.continue_streaming;
     } else {
         return ShouldContinue.finished;
diff --git a/src/http/Signals.zig b/src/http/Signals.zig
index 78531e7f41..bf8d1d8360 100644
--- a/src/http/Signals.zig
+++ b/src/http/Signals.zig
@@ -4,8 +4,9 @@ header_progress: ?*std.atomic.Value(bool) = null,
 body_streaming: ?*std.atomic.Value(bool) = null,
 aborted: ?*std.atomic.Value(bool) = null,
 cert_errors: ?*std.atomic.Value(bool) = null,
+upgraded: ?*std.atomic.Value(bool) = null,
 
 pub fn isEmpty(this: *const Signals) bool {
-    return this.aborted == null and this.body_streaming == null and this.header_progress == null and this.cert_errors == null;
+    return this.aborted == null and this.body_streaming == null and this.header_progress == null and this.cert_errors == null and this.upgraded == null;
 }
 
 pub const Store = struct {
@@ -13,12 +14,14 @@ pub const Store = struct {
     body_streaming: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
     aborted: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
     cert_errors: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
+    upgraded: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
 
     pub fn to(this: *Store) Signals {
         return .{
             .header_progress = &this.header_progress,
             .body_streaming = &this.body_streaming,
             .aborted = &this.aborted,
             .cert_errors = &this.cert_errors,
+            .upgraded = &this.upgraded,
         };
     }
 };
diff --git a/test/js/web/fetch/fetch.upgrade.test.ts b/test/js/web/fetch/fetch.upgrade.test.ts
new file mode 100644
index 0000000000..243bea7762
--- /dev/null
+++ b/test/js/web/fetch/fetch.upgrade.test.ts
@@ -0,0 +1,63 @@
+import { describe, expect, test } from "bun:test";
+import { encodeTextFrame, encodeCloseFrame, decodeFrames, upgradeHeaders } from "./websocket.helpers";
+
+describe("fetch upgrade", () => {
+  test("should upgrade to websocket", async () => {
+    const serverMessages: string[] = [];
+    using server = Bun.serve({
+      port: 0,
+      fetch(req) {
+        if (server.upgrade(req)) return;
+        return new Response("Hello World");
+      },
+      websocket: {
+        open(ws) {
+          ws.send("Hello World");
+        },
+        message(ws, message) {
+          serverMessages.push(message as string);
+        },
+        close(ws) {
+          serverMessages.push("close");
+        },
+      },
+    });
+    const res = await fetch(server.url, {
+      method: "GET",
+      headers: upgradeHeaders(),
+      async *body() {
+        yield encodeTextFrame("hello");
+        yield encodeTextFrame("world");
+        yield encodeTextFrame("bye");
+        yield encodeCloseFrame();
+      },
+    });
+    expect(res.status).toBe(101);
+    expect(res.headers.get("upgrade")).toBe("websocket");
+    expect(res.headers.get("sec-websocket-accept")).toBeString();
+    expect(res.headers.get("connection")).toBe("Upgrade");
+
+    const clientMessages: string[] = [];
+    const { promise, resolve } = Promise.withResolvers<void>();
+    const reader = res.body!.getReader();
+
+    while (true) {
+      const { value, done } = await reader.read();
+      if (done) break;
+      for (const msg of decodeFrames(Buffer.from(value))) {
+        if (typeof msg === "string") {
+          clientMessages.push(msg);
+        } else {
+          clientMessages.push(msg.type);
+
+          if (msg.type === "close") {
+            resolve();
+          }
+        }
+      }
+    }
+    await promise;
+    expect(serverMessages).toEqual(["hello", "world", "bye", "close"]);
+    expect(clientMessages).toEqual(["Hello World", "close"]);
+  });
+});
diff --git a/test/js/web/fetch/websocket.helpers.ts b/test/js/web/fetch/websocket.helpers.ts
new file mode 100644
index 0000000000..6425735039
--- /dev/null
+++ b/test/js/web/fetch/websocket.helpers.ts
@@ -0,0 +1,156 @@
+import { createHash, randomBytes } from "node:crypto";
+
+// RFC 6455 magic GUID
+const WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+function makeKey() {
+  return randomBytes(16).toString("base64");
+}
+
+export function acceptFor(key: string) {
+  return createHash("sha1")
+    .update(key + WS_GUID)
+    .digest("base64");
+}
+
+export function encodeCloseFrame(code = 1000, reason = "") {
+  const reasonBuf = Buffer.from(reason, "utf8");
+  const payloadLen = 2 + reasonBuf.length; // 2 bytes for code + reason
+  const header: number[] = [];
+  let headerLen = 2;
+  if (payloadLen < 126) {
+    // masked bit (0x80) + length
+    header.push(0x88, 0x80 | payloadLen);
+  } else if (payloadLen <= 0xffff) {
+    headerLen += 2;
+    header.push(0x88, 0x80 | 126, payloadLen >> 8, payloadLen & 0xff);
+  } else {
+    throw new Error("Close reason too long");
+  }
+
+  const mask = randomBytes(4);
+  const buf = Buffer.alloc(headerLen + 4 + payloadLen);
+  Buffer.from(header).copy(buf, 0);
+  mask.copy(buf, headerLen);
+
+  // write code + reason
+  const unmasked = Buffer.alloc(payloadLen);
+  unmasked.writeUInt16BE(code, 0);
+  reasonBuf.copy(unmasked, 2);
+
+  // apply mask
+  for (let i = 0; i < payloadLen; i++) {
+    buf[headerLen + 4 + i] = unmasked[i] ^ mask[i & 3];
+  }
+
+  return buf;
+}
+export function* decodeFrames(buffer: Buffer) {
+  let i = 0;
+  while (i + 2 <= buffer.length) {
+    const b0 = buffer[i++];
+    const b1 = buffer[i++];
+    const fin = (b0 & 0x80) !== 0;
+    const opcode = b0 & 0x0f;
+    const masked = (b1 & 0x80) !== 0;
+    let len = b1 & 0x7f;
+
+    if (len === 126) {
+      if (i + 2 > buffer.length) break;
+      len = buffer.readUInt16BE(i);
+      i += 2;
+    } else if (len === 127) {
+      if (i + 8 > buffer.length) break;
+      const big = buffer.readBigUInt64BE(i);
+      i += 8;
+      if (big > BigInt(Number.MAX_SAFE_INTEGER)) throw new Error("frame too large");
+      len = Number(big);
+    }
+
+    let mask: Buffer | undefined;
+    if (masked) {
+      if (i + 4 > buffer.length) break;
+      mask = buffer.subarray(i, i + 4);
+      i += 4;
+    }
+
+    if (i + len > buffer.length) break;
+    let payload = buffer.subarray(i, i + len);
+    i += len;
+
+    if (masked && mask) {
+      const unmasked = Buffer.alloc(len);
+      for (let j = 0; j < len; j++) unmasked[j] = payload[j] ^ mask[j & 3];
+      payload = unmasked;
+    }
+
+    if (!fin) throw new Error("fragmentation not supported in this demo");
+    if (opcode === 0x1) {
+      // text
+      yield payload.toString("utf8");
+    } else if (opcode === 0x8) {
+      // CLOSE
+      yield { type: "close" };
+      return;
+    } else if (opcode === 0x9) {
+      // PING -> respond with PONG if you implement writes here
+      yield { type: "ping", data: payload };
+    } else if (opcode === 0xa) {
+      // PONG
+      yield { type: "pong", data: payload };
+    } else {
+      // ignore other opcodes for brevity
+    }
+  }
+}
+
+// Encode a single unfragmented TEXT frame (client -> server must be masked)
+export function encodeTextFrame(str: string) {
+  const payload = Buffer.from(str, "utf8");
+  const len = payload.length;
+
+  let headerLen = 2;
+  if (len >= 126 && len <= 0xffff) headerLen += 2;
+  else if (len > 0xffff) headerLen += 8;
+  const maskKeyLen = 4;
+
+  const buf = Buffer.alloc(headerLen + maskKeyLen + len);
+  // FIN=1, RSV=0, opcode=0x1 (text)
+  buf[0] = 0x80 | 0x1;
+
+  // Set masked bit and length field(s)
+  let offset = 1;
+  if (len < 126) {
+    buf[offset++] = 0x80 | len; // mask bit + length
+  } else if (len <= 0xffff) {
+    buf[offset++] = 0x80 | 126;
+    buf.writeUInt16BE(len, offset);
+    offset += 2;
+  } else {
+    buf[offset++] = 0x80 | 127;
+    buf.writeBigUInt64BE(BigInt(len), offset);
+    offset += 8;
+  }
+
+  // Mask key
+  const mask = randomBytes(4);
+  mask.copy(buf, offset);
+  offset += 4;
+
+  // Mask the payload
+  for (let i = 0; i < len; i++) {
+    buf[offset + i] = payload[i] ^ mask[i & 3];
+  }
+
+  return buf;
+}
+
+export function upgradeHeaders() {
+  const secWebSocketKey = makeKey();
+  return {
+    "Connection": "Upgrade",
+    "Upgrade": "websocket",
+    "Sec-WebSocket-Version": "13",
+    "Sec-WebSocket-Key": secWebSocketKey,
+  };
+}
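
The test above only asserts that `sec-websocket-accept` is a string. RFC 6455 section 4.2.2 defines the exact expected value as base64(SHA-1(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")), which is precisely what `acceptFor` in websocket.helpers.ts computes, so the handshake can be verified end to end. A minimal sketch of a stricter assertion, assuming the `server` from fetch.upgrade.test.ts and the helpers above:

    import { expect } from "bun:test";
    import { acceptFor, upgradeHeaders } from "./websocket.helpers";

    const headers = upgradeHeaders();
    const res = await fetch(server.url, { headers });
    // the server must echo base64(sha1(Sec-WebSocket-Key + RFC 6455 GUID))
    expect(res.headers.get("sec-websocket-accept")).toBe(acceptFor(headers["Sec-WebSocket-Key"]));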