Ciro Spaciari
2025-09-03 19:12:26 -07:00
parent ed9353f95e
commit 345666b194
5 changed files with 339 additions and 67 deletions

View File

@@ -108,6 +108,7 @@ pub const FetchTasklet = struct {
// custom checkServerIdentity
check_server_identity: jsc.Strong.Optional = .empty,
reject_unauthorized: bool = true,
is_websocket_upgrade: bool = false,
// Custom Hostname
hostname: ?[]u8 = null,
is_waiting_body: bool = false,
@@ -1069,6 +1070,7 @@ pub const FetchTasklet = struct {
.memory_reporter = fetch_options.memory_reporter,
.check_server_identity = fetch_options.check_server_identity,
.reject_unauthorized = fetch_options.reject_unauthorized,
.is_websocket_upgrade = fetch_options.is_websocket_upgrade,
};
fetch_tasklet.signals = fetch_tasklet.signal_store.to();
@@ -1201,19 +1203,23 @@ pub const FetchTasklet = struct {
// we don't have backpressure, so we will schedule the data to be written
// if we have backpressure the onWritable will drain the buffer
needs_schedule = stream_buffer.isEmpty();
// 16 is the max length of the hex representation of a 64-bit size, plus 2 for the \r\n
var formated_size_buffer: [18]u8 = undefined;
const formated_size = std.fmt.bufPrint(
formated_size_buffer[0..],
"{x}\r\n",
.{data.len},
) catch |err| switch (err) {
error.NoSpaceLeft => unreachable,
};
bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2));
stream_buffer.writeAssumeCapacity(formated_size);
stream_buffer.writeAssumeCapacity(data);
stream_buffer.writeAssumeCapacity("\r\n");
if (this.is_websocket_upgrade) {
bun.handleOom(stream_buffer.write(data));
} else {
// 16 is the max length of the hex representation of a 64-bit size, plus 2 for the \r\n
var formated_size_buffer: [18]u8 = undefined;
const formated_size = std.fmt.bufPrint(
formated_size_buffer[0..],
"{x}\r\n",
.{data.len},
) catch |err| switch (err) {
error.NoSpaceLeft => unreachable,
};
bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2));
stream_buffer.writeAssumeCapacity(formated_size);
stream_buffer.writeAssumeCapacity(data);
stream_buffer.writeAssumeCapacity("\r\n");
}
// pause the stream if we hit the high water mark
return stream_buffer.size() >= highWaterMark;
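The branch above is the heart of the change: a streaming fetch body is normally written with HTTP/1.1 chunked-transfer framing, but once the connection is a websocket tunnel the bytes must pass through verbatim. A minimal TypeScript sketch of the same decision (`frameForWire` is a hypothetical name, not Bun API):

```ts
// Sketch of the framing decision above; frameForWire is illustrative.
function frameForWire(data: Uint8Array, isWebSocketUpgrade: boolean): Buffer {
  if (isWebSocketUpgrade) {
    // Raw tunnel after the 101: forward the bytes untouched.
    return Buffer.from(data);
  }
  // HTTP/1.1 chunked framing: "<hex length>\r\n" + data + "\r\n".
  // A 64-bit length needs at most 16 hex digits, hence the [18]u8 buffer in the Zig code.
  const sizeLine = Buffer.from(data.length.toString(16) + "\r\n", "ascii");
  return Buffer.concat([sizeLine, Buffer.from(data), Buffer.from("\r\n")]);
}

// e.g. frameForWire(new TextEncoder().encode("hi"), false) yields "2\r\nhi\r\n"
```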
@@ -1271,6 +1277,7 @@ pub const FetchTasklet = struct {
check_server_identity: jsc.Strong.Optional = .empty,
unix_socket_path: ZigString.Slice,
ssl_config: ?*SSLConfig = null,
is_websocket_upgrade: bool = false,
};
pub fn queue(
@@ -1494,6 +1501,7 @@ pub fn Bun__fetch_(
var memory_reporter = bun.handleOom(bun.default_allocator.create(bun.MemoryReportingAllocator));
// used to clean up dynamically allocated memory on error (a poor man's errdefer)
var is_error = false;
var is_websocket_upgrade = false;
var allocator = memory_reporter.wrap(bun.default_allocator);
errdefer bun.default_allocator.destroy(memory_reporter);
defer {
@@ -2198,6 +2206,15 @@ pub fn Bun__fetch_(
}
}
if (headers_.fastGet(bun.webcore.FetchHeaders.HTTPHeaderName.Upgrade)) |_upgrade| {
const upgrade = _upgrade.toSlice(bun.default_allocator);
defer upgrade.deinit();
const slice = upgrade.slice();
if (bun.strings.eqlComptime(slice, "websocket")) {
is_websocket_upgrade = true;
}
}
break :extract_headers Headers.from(headers_, allocator, .{ .body = body.getAnyBlob() }) catch |err| bun.handleOom(err);
}
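Detection is driven purely by the request headers: setting `Upgrade: websocket` flags the tasklet before it is queued. Note that this fetch-side check (`eqlComptime`) is an exact match, while the HTTP client in the next file compares the header case-insensitively. A hypothetical TypeScript equivalent of the sniffing:

```ts
// Hypothetical equivalent of the Upgrade-header check above.
function isWebSocketUpgrade(headers: Headers): boolean {
  const upgrade = headers.get("Upgrade");
  return upgrade !== null && upgrade.toLowerCase() === "websocket";
}
```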
@@ -2333,7 +2350,7 @@ pub fn Bun__fetch_(
}
}
if (!method.hasRequestBody() and body.hasBody()) {
if (!method.hasRequestBody() and body.hasBody() and !is_websocket_upgrade) {
const err = globalThis.toTypeError(.INVALID_ARG_VALUE, fetch_error_unexpected_body, .{});
is_error = true;
return JSPromise.dangerouslyCreateRejectedPromiseValueWithoutNotifyingVM(globalThis, err);
@@ -2651,6 +2668,7 @@ pub fn Bun__fetch_(
.ssl_config = ssl_config,
.hostname = hostname,
.memory_reporter = memory_reporter,
.is_websocket_upgrade = is_websocket_upgrade,
.check_server_identity = if (check_server_identity.isEmptyOrUndefinedOrNull()) .empty else .create(check_server_identity, globalThis),
.unix_socket_path = unix_socket_path,
},

View File

@@ -405,7 +405,9 @@ pub const Flags = packed struct(u16) {
is_preconnect_only: bool = false,
is_streaming_request_body: bool = false,
defer_fail_until_connecting_is_complete: bool = false,
_padding: u5 = 0,
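// the two new flags below take two of the five padding bits, keeping the struct at 16 bits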
is_websockets: bool = false,
websocket_upgraded: bool = false,
_padding: u3 = 0,
};
// TODO: reduce the size of this struct
@@ -592,6 +594,11 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
hashHeaderConst("Accept-Encoding") => {
override_accept_encoding = true;
},
hashHeaderConst("Upgrade") => {
if (std.ascii.eqlIgnoreCase(this.headerStr(header_values[i]), "websocket")) {
this.flags.is_websockets = true;
}
},
hashHeaderConst(chunked_encoded_header.name) => {
// We don't want to override the chunked encoding header if it was set by the user
add_transfer_encoding = false;
@@ -1019,11 +1026,14 @@ fn writeToStreamUsingBuffer(this: *HTTPClient, comptime is_ssl: bool, socket: Ne
// no data to send so we are done
return false;
}
pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, data: []const u8) void {
log("flushStream", .{});
var stream = &this.state.original_request_body.stream;
const stream_buffer = stream.buffer orelse return;
if (this.flags.is_websockets and !this.flags.websocket_upgraded) {
// cannot drain yet; the websocket is still waiting for the upgrade
return;
}
const buffer = stream_buffer.acquire();
const wasEmpty = buffer.isEmpty() and data.len == 0;
if (wasEmpty and stream.ended) {
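`writeToStream` now refuses to drain the request-body buffer until the 101 arrives: frames the caller yields are queued and only flushed once `websocket_upgraded` flips. A rough sketch of that gate, with hypothetical names:

```ts
// Rough sketch of the drain gate above; names are illustrative.
class UpgradeGatedWriter {
  private pending: Uint8Array[] = [];
  private upgraded = false;

  constructor(private sink: (chunk: Uint8Array) => void) {}

  write(chunk: Uint8Array): void {
    this.pending.push(chunk);
    this.flush();
  }

  // Called once the 101 response is parsed (see handleOnDataHeaders below).
  markUpgraded(): void {
    this.upgraded = true;
    this.flush();
  }

  private flush(): void {
    if (!this.upgraded) return; // cannot drain yet: still waiting for the upgrade
    for (const chunk of this.pending) this.sink(chunk);
    this.pending.length = 0;
  }
}
```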
@@ -1324,56 +1334,78 @@ pub fn handleOnDataHeaders(
) void {
log("handleOnDataHeaders", .{});
var to_read = incoming_data;
var amount_read: usize = 0;
var needs_move = true;
if (this.state.response_message_buffer.list.items.len > 0) {
// this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating
bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data));
to_read = this.state.response_message_buffer.list.items;
needs_move = false;
}
// we reset the pending_response each time, which means that on a parse error it will always be empty
this.state.pending_response = picohttp.Response{};
// the minimal http/1.1 request size is 16 bytes without headers and 26 with a Host header
// if it is less than 16 bytes it will always be a ShortRead
if (to_read.len < 16) {
log("handleShortRead", .{});
this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
return;
}
var response = picohttp.Response.parseParts(
to_read,
&shared_response_headers_buf,
&amount_read,
) catch |err| {
switch (err) {
error.ShortRead => {
this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
},
else => {
this.closeAndFail(err, is_ssl, socket);
},
while (true) {
var amount_read: usize = 0;
var needs_move = true;
if (this.state.response_message_buffer.list.items.len > 0) {
// this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating
bun.handleOom(this.state.response_message_buffer.appendSliceExact(incoming_data));
to_read = this.state.response_message_buffer.list.items;
needs_move = false;
}
return;
};
// we save the successfully parsed response
this.state.pending_response = response;
// we reset the pending_response each time, which means that on a parse error it will always be empty
this.state.pending_response = picohttp.Response{};
const body_buf = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..];
// handle the case where we have a 100 Continue
if (response.status_code >= 100 and response.status_code < 200) {
log("information headers", .{});
// we can still have the 200 OK in the same buffer sometimes
if (body_buf.len > 0) {
log("information headers with body", .{});
this.onData(is_ssl, body_buf, ctx, socket);
// the minimal http/1.1 request size is 16 bytes without headers and 26 with a Host header
// if it is less than 16 bytes it will always be a ShortRead
if (to_read.len < 16) {
log("handleShortRead", .{});
this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
return;
}
return;
const response = picohttp.Response.parseParts(
to_read,
&shared_response_headers_buf,
&amount_read,
) catch |err| {
switch (err) {
error.ShortRead => {
this.handleShortRead(is_ssl, incoming_data, socket, needs_move);
},
else => {
this.closeAndFail(err, is_ssl, socket);
},
}
return;
};
// we save the successfully parsed response
this.state.pending_response = response;
to_read = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..];
if (response.status_code == 101) {
if (!this.flags.is_websockets) {
// we cannot upgrade to websocket because the client did not request it!
this.closeAndFail(error.UnrequestedUpgrade, is_ssl, socket);
return;
}
// special case for websocket upgrade
this.flags.is_websockets = true;
this.flags.websocket_upgraded = true;
if (this.signals.upgraded) |upgraded| {
upgraded.store(true, .monotonic);
}
// start draining the request body
this.flushStream(is_ssl, socket);
break;
}
// handle the case where we have a 100 Continue
if (response.status_code >= 100 and response.status_code < 200) {
log("information headers", .{});
// we can still have the 200 OK in the same buffer sometimes
// 1XX responses MUST NOT include a message-body, therefore we need to continue parsing
continue;
}
break;
}
var response = this.state.pending_response.?;
const should_continue = this.handleResponseMetadata(
&response,
) catch |err| {
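The parse loop now handles interim responses in place: a `101 Switching Protocols` marks the connection upgraded, flips the shared `upgraded` signal, and starts draining the queued request body; any other 1xx response carries no body, so parsing simply continues on the remaining bytes. Nothing in this diff verifies `Sec-WebSocket-Accept`; a strict RFC 6455 client would check it roughly like this:

```ts
import { createHash } from "node:crypto";

// RFC 6455 §4.2.2: accept = base64(SHA-1(key + magic GUID)).
const WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

function validateHandshake(sentKey: string, res: Response): void {
  if (res.status !== 101) throw new Error("server did not switch protocols");
  const expected = createHash("sha1").update(sentKey + WS_GUID).digest("base64");
  if (res.headers.get("Sec-WebSocket-Accept") !== expected) {
    throw new Error("Sec-WebSocket-Accept mismatch");
  }
}
```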
@@ -1409,14 +1441,14 @@ pub fn handleOnDataHeaders(
if (this.flags.proxy_tunneling and this.proxy_tunnel == null) {
// we are proxying, so we don't need to cloneMetadata yet
this.startProxyHandshake(is_ssl, socket, body_buf);
this.startProxyHandshake(is_ssl, socket, to_read);
return;
}
// we have body data incoming so we clone metadata and keep going
this.cloneMetadata();
if (body_buf.len == 0) {
if (to_read.len == 0) {
// no body data yet, but we can report the headers
if (this.signals.get(.header_progress)) {
this.progressUpdate(is_ssl, ctx, socket);
@@ -1426,7 +1458,7 @@ pub fn handleOnDataHeaders(
if (this.state.response_stage == .body) {
{
const report_progress = this.handleResponseBody(body_buf, true) catch |err| {
const report_progress = this.handleResponseBody(to_read, true) catch |err| {
this.closeAndFail(err, is_ssl, socket);
return;
};
@@ -1439,7 +1471,7 @@ pub fn handleOnDataHeaders(
} else if (this.state.response_stage == .body_chunk) {
this.setTimeout(socket, 5);
{
const report_progress = this.handleResponseBodyChunkedEncoding(body_buf) catch |err| {
const report_progress = this.handleResponseBodyChunkedEncoding(to_read) catch |err| {
this.closeAndFail(err, is_ssl, socket);
return;
};
@@ -2149,7 +2181,7 @@ pub fn handleResponseMetadata(
// [...] cannot contain a message body or trailer section.
// therefore in these cases set content-length to 0, so the response body is always ignored
// and is not waited for (which could cause a timeout)
if ((response.status_code >= 100 and response.status_code < 200) or response.status_code == 204 or response.status_code == 304) {
if ((response.status_code >= 100 and response.status_code < 200 and response.status_code != 101) or response.status_code == 204 or response.status_code == 304) {
this.state.content_length = 0;
}
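The 101 carve-out matters here: 1xx, 204, and 304 responses carry no message body, so `content_length` is forced to 0, but after a 101 the raw upgraded stream follows and must keep flowing. The rule, restated in TypeScript:

```ts
// Restating the status-code rule above: does payload follow the headers?
function hasPayload(status: number): boolean {
  if (status === 101) return true; // upgraded: the raw stream follows
  if (status >= 100 && status < 200) return false; // other 1xx: no body
  return status !== 204 && status !== 304; // 204/304: no body either
}
```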
@@ -2416,7 +2448,7 @@ pub fn handleResponseMetadata(
log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding});
}
if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events)) {
if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.websocket_upgraded)) {
return ShouldContinue.continue_streaming;
} else {
return ShouldContinue.finished;

View File

@@ -4,8 +4,9 @@ header_progress: ?*std.atomic.Value(bool) = null,
body_streaming: ?*std.atomic.Value(bool) = null,
aborted: ?*std.atomic.Value(bool) = null,
cert_errors: ?*std.atomic.Value(bool) = null,
upgraded: ?*std.atomic.Value(bool) = null,
pub fn isEmpty(this: *const Signals) bool {
return this.aborted == null and this.body_streaming == null and this.header_progress == null and this.cert_errors == null;
return this.aborted == null and this.body_streaming == null and this.header_progress == null and this.cert_errors == null and this.upgraded == null;
}
pub const Store = struct {
@@ -13,12 +14,14 @@ pub const Store = struct {
body_streaming: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
aborted: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
cert_errors: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
upgraded: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
pub fn to(this: *Store) Signals {
return .{
.header_progress = &this.header_progress,
.body_streaming = &this.body_streaming,
.aborted = &this.aborted,
.cert_errors = &this.cert_errors,
.upgraded = &this.upgraded,
};
}
};
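`Signals` gains an `upgraded` flag alongside the existing atomics so the upgrade can be observed across threads. A rough TypeScript analogue of one such flag, using `Atomics` over a `SharedArrayBuffer` (a sketch, not Bun's internals):

```ts
// Rough analogue of one std.atomic.Value(bool) shared across threads.
const upgradedFlag = new Int32Array(new SharedArrayBuffer(4));

function storeUpgraded(value: boolean): void {
  Atomics.store(upgradedFlag, 0, value ? 1 : 0);
}

function loadUpgraded(): boolean {
  return Atomics.load(upgradedFlag, 0) === 1;
}
```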

View File

@@ -0,0 +1,63 @@
import { describe, expect, test } from "bun:test";
import { encodeTextFrame, encodeCloseFrame, decodeFrames, upgradeHeaders } from "./websocket.helpers";
describe("fetch upgrade", () => {
test("should upgrade to websocket", async () => {
const serverMessages: string[] = [];
using server = Bun.serve({
port: 0,
fetch(req) {
if (server.upgrade(req)) return;
return new Response("Hello World");
},
websocket: {
open(ws) {
ws.send("Hello World");
},
message(ws, message) {
serverMessages.push(message as string);
},
close(ws) {
serverMessages.push("close");
},
},
});
const res = await fetch(server.url, {
method: "GET",
headers: upgradeHeaders(),
async *body() {
yield encodeTextFrame("hello");
yield encodeTextFrame("world");
yield encodeTextFrame("bye");
yield encodeCloseFrame();
},
});
expect(res.status).toBe(101);
expect(res.headers.get("upgrade")).toBe("websocket");
expect(res.headers.get("sec-websocket-accept")).toBeString();
expect(res.headers.get("connection")).toBe("Upgrade");
const clientMessages: string[] = [];
const { promise, resolve } = Promise.withResolvers<void>();
const reader = res.body!.getReader();
while (true) {
const { value, done } = await reader.read();
if (done) break;
for (const msg of decodeFrames(Buffer.from(value))) {
if (typeof msg === "string") {
clientMessages.push(msg);
} else {
clientMessages.push(msg.type);
}
if (msg.type === "close") {
resolve();
}
}
}
await promise;
expect(serverMessages).toEqual(["hello", "world", "bye", "close"]);
expect(clientMessages).toEqual(["Hello World", "close"]);
});
});

View File

@@ -0,0 +1,156 @@
import { createHash, randomBytes } from "node:crypto";
// RFC 6455 magic GUID
const WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
function makeKey() {
return randomBytes(16).toString("base64");
}
function acceptFor(key) {
return createHash("sha1")
.update(key + WS_GUID)
.digest("base64");
}
export function encodeCloseFrame(code = 1000, reason = "") {
const reasonBuf = Buffer.from(reason, "utf8");
const payloadLen = 2 + reasonBuf.length; // 2 bytes for code + reason
const header = [];
let headerLen = 2;
if (payloadLen < 126) {
// masked bit (0x80) + length
header.push(0x88, 0x80 | payloadLen);
} else if (payloadLen <= 0xffff) {
headerLen += 2;
header.push(0x88, 0x80 | 126, payloadLen >> 8, payloadLen & 0xff);
} else {
throw new Error("Close reason too long");
}
const mask = randomBytes(4);
const buf = Buffer.alloc(headerLen + 4 + payloadLen);
Buffer.from(header).copy(buf, 0);
mask.copy(buf, headerLen);
// write code + reason
const unmasked = Buffer.alloc(payloadLen);
unmasked.writeUInt16BE(code, 0);
reasonBuf.copy(unmasked, 2);
// apply mask
for (let i = 0; i < payloadLen; i++) {
buf[headerLen + 4 + i] = unmasked[i] ^ mask[i & 3];
}
return buf;
}
export function* decodeFrames(buffer) {
let i = 0;
while (i + 2 <= buffer.length) {
const b0 = buffer[i++];
const b1 = buffer[i++];
const fin = (b0 & 0x80) !== 0;
const opcode = b0 & 0x0f;
const masked = (b1 & 0x80) !== 0;
let len = b1 & 0x7f;
if (len === 126) {
if (i + 2 > buffer.length) break;
len = buffer.readUInt16BE(i);
i += 2;
} else if (len === 127) {
if (i + 8 > buffer.length) break;
const big = buffer.readBigUInt64BE(i);
i += 8;
if (big > BigInt(Number.MAX_SAFE_INTEGER)) throw new Error("frame too large");
len = Number(big);
}
let mask;
if (masked) {
if (i + 4 > buffer.length) break;
mask = buffer.subarray(i, i + 4);
i += 4;
}
if (i + len > buffer.length) break;
let payload = buffer.subarray(i, i + len);
i += len;
if (masked && mask) {
const unmasked = Buffer.alloc(len);
for (let j = 0; j < len; j++) unmasked[j] = payload[j] ^ mask[j & 3];
payload = unmasked;
}
if (!fin) throw new Error("fragmentation not supported in this demo");
if (opcode === 0x1) {
// text
yield payload.toString("utf8");
} else if (opcode === 0x8) {
// CLOSE
yield { type: "close" };
return;
} else if (opcode === 0x9) {
// PING -> respond with PONG if you implement writes here
yield { type: "ping", data: payload };
} else if (opcode === 0xa) {
// PONG
yield { type: "pong", data: payload };
} else {
// ignore other opcodes for brevity
}
}
}
// Encode a single unfragmented TEXT frame (client -> server must be masked)
export function encodeTextFrame(str) {
const payload = Buffer.from(str, "utf8");
const len = payload.length;
let headerLen = 2;
if (len >= 126 && len <= 0xffff) headerLen += 2;
else if (len > 0xffff) headerLen += 8;
const maskKeyLen = 4;
const buf = Buffer.alloc(headerLen + maskKeyLen + len);
// FIN=1, RSV=0, opcode=0x1 (text)
buf[0] = 0x80 | 0x1;
// Set masked bit and length field(s)
let offset = 1;
if (len < 126) {
buf[offset++] = 0x80 | len; // mask bit + length
} else if (len <= 0xffff) {
buf[offset++] = 0x80 | 126;
buf.writeUInt16BE(len, offset);
offset += 2;
} else {
buf[offset++] = 0x80 | 127;
buf.writeBigUInt64BE(BigInt(len), offset);
offset += 8;
}
// Mask key
const mask = randomBytes(4);
mask.copy(buf, offset);
offset += 4;
// Mask the payload
for (let i = 0; i < len; i++) {
buf[offset + i] = payload[i] ^ mask[i & 3];
}
return buf;
}
export function upgradeHeaders() {
const secWebSocketKey = makeKey();
return {
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": secWebSocketKey,
};
}
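`acceptFor` computes the server's expected `Sec-WebSocket-Accept` for a given key. As a sanity check, RFC 6455 §1.3 gives a worked example pair:

```ts
// RFC 6455 §1.3 worked example:
// acceptFor("dGhlIHNhbXBsZSBub25jZQ==") === "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
```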