Compare commits

...

6 Commits

Author SHA1 Message Date
Jarred-Sumner
c8a3a84388 bun run zig-format 2025-03-15 23:35:24 +00:00
Jarred Sumner
acc69ecbe9 feat(WebSocket): implement client-side permessage-deflate compression
This adds support for the permessage-deflate WebSocket extension (RFC7692) to
Bun's WebSocket client. The implementation:

- Negotiates compression parameters during the WebSocket handshake
- Compresses outgoing messages using libdeflate
- Decompresses incoming messages
- Handles context takeover settings appropriately
- Implements window size parameters
- Includes proper RSV1 bit handling
- Follows RFC7692 specification for frame processing

This allows WebSocket connections to use significantly less bandwidth for
repetitive or compressible data like JSON or text content.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-03-15 16:32:21 -07:00
Jarred Sumner
71cd154579 wip 2025-03-15 15:16:17 -07:00
Jarred Sumner
0bba14daf9 Update websocket_http_client.zig 2025-03-15 14:59:11 -07:00
Jarred Sumner
8db9162be3 move 2025-03-15 14:59:05 -07:00
Jarred Sumner
7a7576a1d9 Update websocket_http_client.zig 2025-03-15 14:56:54 -07:00
5 changed files with 2496 additions and 2022 deletions

1199
src/http/WebSocket.zig Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,633 @@
const std = @import("std");
const bun = @import("root").bun;
const string = bun.string;
const Output = bun.Output;
const Global = bun.Global;
const Environment = bun.Environment;
const strings = bun.strings;
const MutableString = bun.MutableString;
const stringZ = bun.stringZ;
const default_allocator = bun.default_allocator;
const C = bun.C;
const BoringSSL = bun.BoringSSL;
const uws = bun.uws;
const JSC = bun.JSC;
const PicoHTTP = bun.picohttp;
const websocket_client = @import("websocket_client.zig");
const CppWebSocket = websocket_client.CppWebSocket;
const ErrorCode = websocket_client.ErrorCode;
const NonUTF8Headers = @import("websocket_client.zig").NonUTF8Headers;
const Async = bun.Async;
const log = Output.scoped(.WebSocketClient, false);
fn buildRequestBody(
vm: *JSC.VirtualMachine,
pathname: *const JSC.ZigString,
is_https: bool,
host: *const JSC.ZigString,
port: u16,
client_protocol: *const JSC.ZigString,
client_protocol_hash: *u64,
extra_headers: NonUTF8Headers,
) std.mem.Allocator.Error![]u8 {
const allocator = vm.allocator;
const input_rand_buf = vm.rareData().nextUUID().bytes;
const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);
var encoded_buf: [temp_buf_size]u8 = undefined;
const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf);
var static_headers = [_]PicoHTTP.Header{
.{
.name = "Sec-WebSocket-Key",
.value = accept_key,
},
.{
.name = "Sec-WebSocket-Protocol",
.value = client_protocol.slice(),
},
.{
.name = "Sec-WebSocket-Extensions",
.value = "permessage-deflate; client_max_window_bits",
},
};
if (client_protocol.len > 0)
client_protocol_hash.* = bun.hash(static_headers[1].value);
const pathname_ = pathname.toSlice(allocator);
const host_ = host.toSlice(allocator);
defer {
pathname_.deinit();
host_.deinit();
}
const host_fmt = bun.fmt.HostFormatter{
.is_https = is_https,
.host = host_.slice(),
.port = port,
};
// Include the extension header
const headers_ = static_headers[0 .. 2 + @as(usize, @intFromBool(client_protocol.len > 0))];
const pico_headers = PicoHTTP.Headers{ .headers = headers_ };
return try std.fmt.allocPrint(
allocator,
"GET {s} HTTP/1.1\r\n" ++
"Host: {any}\r\n" ++
"Connection: Upgrade\r\n" ++
"Upgrade: websocket\r\n" ++
"Sec-WebSocket-Version: 13\r\n" ++
"{s}" ++
"{s}" ++
"\r\n",
.{ pathname_.slice(), host_fmt, pico_headers, extra_headers },
);
}
pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
return struct {
pub const Socket = uws.NewSocketHandler(ssl);
tcp: Socket,
outgoing_websocket: ?*CppWebSocket,
input_body_buf: []u8 = &[_]u8{},
client_protocol: []const u8 = "",
to_send: []const u8 = "",
read_length: usize = 0,
headers_buf: [128]PicoHTTP.Header = undefined,
body: std.ArrayListUnmanaged(u8) = .{},
websocket_protocol: u64 = 0,
hostname: [:0]const u8 = "",
poll_ref: Async.KeepAlive = Async.KeepAlive.init(),
state: State = .initializing,
ref_count: u32 = 1,
const State = enum { initializing, reading, failed };
pub const name = if (ssl) "WebSocketHTTPSClient" else "WebSocketHTTPClient";
pub const shim = JSC.Shimmer("Bun", name, @This());
pub usingnamespace bun.NewRefCounted(@This(), deinit, null);
const HTTPClient = @This();
pub fn register(_: *JSC.JSGlobalObject, _: *anyopaque, ctx: *uws.SocketContext) callconv(.C) void {
Socket.configure(
ctx,
true,
*HTTPClient,
struct {
pub const onOpen = handleOpen;
pub const onClose = handleClose;
pub const onData = handleData;
pub const onWritable = handleWritable;
pub const onTimeout = handleTimeout;
pub const onLongTimeout = handleTimeout;
pub const onConnectError = handleConnectError;
pub const onEnd = handleEnd;
pub const onHandshake = handleHandshake;
},
);
}
pub fn deinit(this: *HTTPClient) void {
this.clearData();
bun.debugAssert(this.tcp.isDetached());
this.destroy();
}
/// On error, this returns null.
/// Returning null signals to the parent function that the connection failed.
pub fn connect(
global: *JSC.JSGlobalObject,
socket_ctx: *anyopaque,
websocket: *CppWebSocket,
host: *const JSC.ZigString,
port: u16,
pathname: *const JSC.ZigString,
client_protocol: *const JSC.ZigString,
header_names: ?[*]const JSC.ZigString,
header_values: ?[*]const JSC.ZigString,
header_count: usize,
) callconv(.C) ?*HTTPClient {
const vm = global.bunVM();
bun.assert(vm.event_loop_handle != null);
var client_protocol_hash: u64 = 0;
const body = buildRequestBody(
vm,
pathname,
ssl,
host,
port,
client_protocol,
&client_protocol_hash,
NonUTF8Headers.init(header_names, header_values, header_count),
) catch return null;
var client = HTTPClient.new(.{
.tcp = .{ .socket = .{ .detached = {} } },
.outgoing_websocket = websocket,
.input_body_buf = body,
.websocket_protocol = client_protocol_hash,
.state = .initializing,
});
var host_ = host.toSlice(bun.default_allocator);
defer host_.deinit();
client.poll_ref.ref(vm);
const display_host_ = host_.slice();
const display_host = if (bun.FeatureFlags.hardcode_localhost_to_127_0_0_1 and strings.eqlComptime(display_host_, "localhost"))
"127.0.0.1"
else
display_host_;
if (Socket.connectPtr(
display_host,
port,
@as(*uws.SocketContext, @ptrCast(socket_ctx)),
HTTPClient,
client,
"tcp",
false,
)) |out| {
// I don't think this case gets reached.
if (out.state == .failed) {
client.deref();
return null;
}
bun.Analytics.Features.WebSocket += 1;
if (comptime ssl) {
if (!strings.isIPAddress(host_.slice())) {
out.hostname = bun.default_allocator.dupeZ(u8, host_.slice()) catch "";
}
}
out.tcp.timeout(120);
out.state = .reading;
// +1 for cpp_websocket
out.ref();
return out;
} else |_| {
client.deref();
}
return null;
}
pub fn clearInput(this: *HTTPClient) void {
if (this.input_body_buf.len > 0) bun.default_allocator.free(this.input_body_buf);
this.input_body_buf.len = 0;
}
pub fn clearData(this: *HTTPClient) void {
this.poll_ref.unref(JSC.VirtualMachine.get());
this.clearInput();
this.body.clearAndFree(bun.default_allocator);
}
pub fn cancel(this: *HTTPClient) callconv(.C) void {
this.clearData();
// Either of the below two operations - closing the TCP socket or clearing the C++ reference could trigger a deref
// Therefore, we need to make sure the `this` pointer is valid until the end of the function.
this.ref();
defer this.deref();
// The C++ end of the socket is no longer holding a reference to this, sowe must clear it.
if (this.outgoing_websocket != null) {
this.outgoing_websocket = null;
this.deref();
}
// no need to be .failure we still wanna to send pending SSL buffer + close_notify
if (comptime ssl) {
this.tcp.close(.normal);
} else {
this.tcp.close(.failure);
}
}
pub fn fail(this: *HTTPClient, code: ErrorCode) void {
log("onFail: {s}", .{@tagName(code)});
JSC.markBinding(@src());
this.ref();
defer this.deref();
this.dispatchAbruptClose(code);
if (comptime ssl) {
this.tcp.close(.normal);
} else {
this.tcp.close(.failure);
}
}
fn dispatchAbruptClose(this: *HTTPClient, code: ErrorCode) void {
if (this.outgoing_websocket) |ws| {
this.outgoing_websocket = null;
ws.didAbruptClose(code);
this.deref();
}
}
pub fn handleClose(this: *HTTPClient, _: Socket, _: c_int, _: ?*anyopaque) void {
log("onClose", .{});
JSC.markBinding(@src());
this.clearData();
this.tcp.detach();
this.dispatchAbruptClose(ErrorCode.ended);
this.deref();
}
pub fn terminate(this: *HTTPClient, code: ErrorCode) void {
this.fail(code);
// We cannot access the pointer after fail is called.
}
pub fn handleHandshake(this: *HTTPClient, socket: Socket, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
log("onHandshake({d})", .{success});
const handshake_success = if (success == 1) true else false;
var reject_unauthorized = false;
if (this.outgoing_websocket) |ws| {
reject_unauthorized = ws.rejectUnauthorized();
}
if (handshake_success) {
// handshake completed but we may have ssl errors
if (reject_unauthorized) {
// only reject the connection if reject_unauthorized == true
if (ssl_error.error_no != 0) {
this.fail(ErrorCode.tls_handshake_failed);
return;
}
const ssl_ptr = @as(*BoringSSL.c.SSL, @ptrCast(socket.getNativeHandle()));
if (BoringSSL.c.SSL_get_servername(ssl_ptr, 0)) |servername| {
const hostname = servername[0..bun.len(servername)];
if (!BoringSSL.checkServerIdentity(ssl_ptr, hostname)) {
this.fail(ErrorCode.tls_handshake_failed);
}
}
}
} else {
// if we are here is because server rejected us, and the error_no is the cause of this
// if we set reject_unauthorized == false this means the server requires custom CA aka NODE_EXTRA_CA_CERTS
this.fail(ErrorCode.tls_handshake_failed);
}
}
pub fn handleOpen(this: *HTTPClient, socket: Socket) void {
log("onOpen", .{});
this.tcp = socket;
bun.assert(this.input_body_buf.len > 0);
bun.assert(this.to_send.len == 0);
if (comptime ssl) {
if (this.hostname.len > 0) {
socket.getNativeHandle().?.configureHTTPClient(this.hostname);
bun.default_allocator.free(this.hostname);
this.hostname = "";
}
}
// Do not set MSG_MORE, see https://github.com/oven-sh/bun/issues/4010
const wrote = socket.write(this.input_body_buf, false);
if (wrote < 0) {
this.terminate(ErrorCode.failed_to_write);
return;
}
this.to_send = this.input_body_buf[@as(usize, @intCast(wrote))..];
}
pub fn isSameSocket(this: *HTTPClient, socket: Socket) bool {
return socket.socket.eq(this.tcp.socket);
}
pub fn handleData(this: *HTTPClient, socket: Socket, data: []const u8) void {
log("onData", .{});
if (this.outgoing_websocket == null) {
this.clearData();
socket.close(.failure);
return;
}
this.ref();
defer this.deref();
bun.assert(this.isSameSocket(socket));
if (comptime Environment.allow_assert)
bun.assert(!socket.isShutdown());
var body = data;
if (this.body.items.len > 0) {
this.body.appendSlice(bun.default_allocator, data) catch bun.outOfMemory();
body = this.body.items;
}
const is_first = this.body.items.len == 0;
const http_101 = "HTTP/1.1 101 ";
if (is_first and body.len > http_101.len) {
// fail early if we receive a non-101 status code
if (!strings.hasPrefixComptime(body, http_101)) {
this.terminate(ErrorCode.expected_101_status_code);
return;
}
}
const response = PicoHTTP.Response.parse(body, &this.headers_buf) catch |err| {
switch (err) {
error.Malformed_HTTP_Response => {
this.terminate(ErrorCode.invalid_response);
return;
},
error.ShortRead => {
if (this.body.items.len == 0) {
this.body.appendSlice(bun.default_allocator, data) catch bun.outOfMemory();
}
return;
},
}
};
this.processResponse(response, body[@as(usize, @intCast(response.bytes_read))..]);
}
pub fn handleEnd(this: *HTTPClient, _: Socket) void {
log("onEnd", .{});
this.terminate(ErrorCode.ended);
}
pub fn processResponse(this: *HTTPClient, response: PicoHTTP.Response, remain_buf: []const u8) void {
var upgrade_header = PicoHTTP.Header{ .name = "", .value = "" };
var connection_header = PicoHTTP.Header{ .name = "", .value = "" };
var websocket_accept_header = PicoHTTP.Header{ .name = "", .value = "" };
var websocket_extensions_header = PicoHTTP.Header{ .name = "", .value = "" };
var visited_protocol = this.websocket_protocol == 0;
// var visited_version = false;
if (response.status_code != 101) {
this.terminate(ErrorCode.expected_101_status_code);
return;
}
for (response.headers.list) |header| {
switch (header.name.len) {
"Connection".len => {
if (connection_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Connection", false)) {
connection_header = header;
if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
break;
}
}
},
"Upgrade".len => {
if (upgrade_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Upgrade", false)) {
upgrade_header = header;
if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
break;
}
}
},
"Sec-WebSocket-Version".len => {
if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Version", false)) {
if (!strings.eqlComptimeIgnoreLen(header.value, "13")) {
this.terminate(ErrorCode.invalid_websocket_version);
return;
}
}
},
"Sec-WebSocket-Accept".len => {
if (websocket_accept_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Accept", false)) {
websocket_accept_header = header;
if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
break;
}
}
},
"Sec-WebSocket-Extensions".len => {
if (websocket_extensions_header.name.len == 0 and strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Extensions", false)) {
websocket_extensions_header = header;
}
},
"Sec-WebSocket-Protocol".len => {
if (strings.eqlCaseInsensitiveASCII(header.name, "Sec-WebSocket-Protocol", false)) {
if (this.websocket_protocol == 0 or bun.hash(header.value) != this.websocket_protocol) {
this.terminate(ErrorCode.mismatch_client_protocol);
return;
}
visited_protocol = true;
if (visited_protocol and upgrade_header.name.len > 0 and connection_header.name.len > 0 and websocket_accept_header.name.len > 0) {
break;
}
}
},
else => {},
}
}
// if (!visited_version) {
// this.terminate(ErrorCode.invalid_websocket_version);
// return;
// }
if (@min(upgrade_header.name.len, upgrade_header.value.len) == 0) {
this.terminate(ErrorCode.missing_upgrade_header);
return;
}
if (@min(connection_header.name.len, connection_header.value.len) == 0) {
this.terminate(ErrorCode.missing_connection_header);
return;
}
if (@min(websocket_accept_header.name.len, websocket_accept_header.value.len) == 0) {
this.terminate(ErrorCode.missing_websocket_accept_header);
return;
}
if (!visited_protocol) {
this.terminate(ErrorCode.mismatch_client_protocol);
return;
}
if (!strings.eqlCaseInsensitiveASCII(connection_header.value, "Upgrade", true)) {
this.terminate(ErrorCode.invalid_connection_header);
return;
}
if (!strings.eqlCaseInsensitiveASCII(upgrade_header.value, "websocket", true)) {
this.terminate(ErrorCode.invalid_upgrade_header);
return;
}
// TODO: check websocket_accept_header.value
const overflow_len = remain_buf.len;
var overflow: []u8 = &.{};
if (overflow_len > 0) {
overflow = bun.default_allocator.alloc(u8, overflow_len) catch {
this.terminate(ErrorCode.invalid_response);
return;
};
@memcpy(overflow, remain_buf);
}
this.clearData();
JSC.markBinding(@src());
if (!this.tcp.isClosed() and this.outgoing_websocket != null) {
this.tcp.timeout(0);
log("onDidConnect", .{});
// Once for the outgoing_websocket.
defer this.deref();
const ws = bun.take(&this.outgoing_websocket).?;
const socket = this.tcp;
// Setup compression if the server accepted our request
if (websocket_extensions_header.name.len > 0 and
websocket_extensions_header.value.len > 0 and
std.mem.indexOf(u8, websocket_extensions_header.value, "permessage-deflate") != null)
{
log("Setting up compression with extension: {s}", .{websocket_extensions_header.value});
// After connection is established, we'll need to initialize compression
// Note: Initializing compression after didConnect since that's when we get WebSocket instance
}
this.tcp.detach();
// Once again for the TCP socket.
defer this.deref();
ws.didConnect(socket.socket.get().?, overflow.ptr, overflow.len);
} else if (this.tcp.isClosed()) {
this.terminate(ErrorCode.cancel);
} else if (this.outgoing_websocket == null) {
this.tcp.close(.failure);
}
}
pub fn memoryCost(this: *HTTPClient) callconv(.C) usize {
var cost: usize = @sizeOf(HTTPClient);
cost += this.body.capacity;
cost += this.to_send.len;
return cost;
}
pub fn handleWritable(
this: *HTTPClient,
socket: Socket,
) void {
bun.assert(this.isSameSocket(socket));
if (this.to_send.len == 0)
return;
this.ref();
defer this.deref();
// Do not set MSG_MORE, see https://github.com/oven-sh/bun/issues/4010
const wrote = socket.write(this.to_send, false);
if (wrote < 0) {
this.terminate(ErrorCode.failed_to_write);
return;
}
this.to_send = this.to_send[@min(@as(usize, @intCast(wrote)), this.to_send.len)..];
}
pub fn handleTimeout(
this: *HTTPClient,
_: Socket,
) void {
this.terminate(ErrorCode.timeout);
}
// In theory, this could be called immediately
// In that case, we set `state` to `failed` and return, expecting the parent to call `destroy`.
pub fn handleConnectError(this: *HTTPClient, _: Socket, _: c_int) void {
this.tcp.detach();
// For the TCP socket.
defer this.deref();
if (this.state == .reading) {
this.terminate(ErrorCode.failed_to_connect);
} else {
this.state = .failed;
}
}
pub const Export = shim.exportFunctions(.{
.connect = connect,
.cancel = cancel,
.register = register,
.memoryCost = memoryCost,
});
comptime {
@export(&connect, .{
.name = Export[0].symbol_name,
});
@export(&cancel, .{
.name = Export[1].symbol_name,
});
@export(&register, .{
.name = Export[2].symbol_name,
});
@export(&memoryCost, .{
.name = Export[3].symbol_name,
});
}
};
}
pub const WebSocketHTTPClient = NewHTTPUpgradeClient(false);
pub const WebSocketHTTPSClient = NewHTTPUpgradeClient(true);

View File

@@ -0,0 +1,643 @@
// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig
// Thank you @frmdstryr.
const std = @import("std");
const bun = @import("root").bun;
const string = bun.string;
const Output = bun.Output;
const Global = bun.Global;
const Environment = bun.Environment;
const strings = bun.strings;
const MutableString = bun.MutableString;
const stringZ = bun.stringZ;
const default_allocator = bun.default_allocator;
const C = bun.C;
const BoringSSL = bun.BoringSSL;
const uws = bun.uws;
const JSC = bun.JSC;
const PicoHTTP = bun.picohttp;
const ObjectPool = @import("../pool.zig").ObjectPool;
const protocol = @import("./websocket_protocol.zig");
const WebsocketHeader = protocol.WebsocketHeader;
const WebsocketDataFrame = protocol.WebsocketDataFrame;
const Opcode = protocol.Opcode;
const ZigURL = @import("../url.zig").URL;
const libdeflate = @import("../deps/libdeflate.zig");
const Async = bun.Async;
pub const WebSocketCompression = struct {
// Compression state
enabled: bool = false,
client_no_context_takeover: bool = false,
server_no_context_takeover: bool = false,
client_max_window_bits: u8 = 15, // Default is 15 (32KB window)
server_max_window_bits: u8 = 15, // Default is 15 (32KB window)
compressor: ?*libdeflate.Compressor = null,
decompressor: ?*libdeflate.Decompressor = null,
compression_buffer: []u8 = &[_]u8{},
pub fn init() WebSocketCompression {
return .{};
}
pub fn deinit(self: *WebSocketCompression) void {
if (self.compressor) |compressor| {
compressor.deinit();
self.compressor = null;
}
if (self.decompressor) |decompressor| {
decompressor.deinit();
self.decompressor = null;
}
if (self.compression_buffer.len > 0) {
default_allocator.free(self.compression_buffer);
self.compression_buffer = &[_]u8{};
}
}
pub fn setup(self: *WebSocketCompression, extensions_header: []const u8) bool {
// Parse the extensions header to set up compression parameters
// Example: "permessage-deflate; client_max_window_bits=15; server_max_window_bits=15"
if (extensions_header.len == 0) return false;
// Simple check for permessage-deflate extension
if (std.mem.indexOf(u8, extensions_header, "permessage-deflate") == null) return false;
self.enabled = true;
// Parse parameters
if (std.mem.indexOf(u8, extensions_header, "client_no_context_takeover") != null) {
self.client_no_context_takeover = true;
}
if (std.mem.indexOf(u8, extensions_header, "server_no_context_takeover") != null) {
self.server_no_context_takeover = true;
}
// Parse window bits parameters
if (std.mem.indexOf(u8, extensions_header, "client_max_window_bits=")) |client_pos| {
// Find the actual value after the equals sign
const start_pos = client_pos + "client_max_window_bits=".len;
var end_pos = start_pos;
while (end_pos < extensions_header.len and
std.ascii.isDigit(extensions_header[end_pos]))
{
end_pos += 1;
}
if (end_pos > start_pos) {
const value_str = extensions_header[start_pos..end_pos];
const value = std.fmt.parseInt(u8, value_str, 10) catch 15;
// Valid window bits values are 8-15, with 15 being the default
if (value >= 8 and value <= 15) {
self.client_max_window_bits = value;
}
}
}
if (std.mem.indexOf(u8, extensions_header, "server_max_window_bits=")) |server_pos| {
// Find the actual value after the equals sign
const start_pos = server_pos + "server_max_window_bits=".len;
var end_pos = start_pos;
while (end_pos < extensions_header.len and
std.ascii.isDigit(extensions_header[end_pos]))
{
end_pos += 1;
}
if (end_pos > start_pos) {
const value_str = extensions_header[start_pos..end_pos];
const value = std.fmt.parseInt(u8, value_str, 10) catch 15;
// Valid window bits values are 8-15, with 15 being the default
if (value >= 8 and value <= 15) {
self.server_max_window_bits = value;
}
}
}
// Initialize compression/decompression
if (self.enabled) {
libdeflate.load();
// Level 6 compression is a good balance of speed vs compression ratio
// Choose compression level based on window bits
var compression_level: c_int = 6; // Default
if (self.server_max_window_bits <= 9) {
compression_level = 1; // For very small windows, use fast compression
} else if (self.server_max_window_bits <= 11) {
compression_level = 3; // For small windows
} else if (self.server_max_window_bits <= 13) {
compression_level = 5; // For medium windows
}
self.compressor = libdeflate.Compressor.alloc(compression_level);
self.decompressor = libdeflate.Decompressor.alloc();
// Log the negotiated parameters
log("Initialized compression with client_max_window_bits={d}, server_max_window_bits={d}", .{ self.client_max_window_bits, self.server_max_window_bits });
// Allocate a shared buffer for compression/decompression
// Start with a reasonable size that will be grown if needed
self.compression_buffer = default_allocator.alloc(u8, 8192) catch &[_]u8{};
return self.compressor != null and self.decompressor != null and self.compression_buffer.len > 0;
}
return false;
}
pub fn compress(self: *WebSocketCompression, data: []const u8) ?[]const u8 {
if (!self.enabled or self.compressor == null) return null;
// Make sure our buffer is large enough
const max_size = self.compressor.?.maxBytesNeeded(data, .deflate);
if (max_size > self.compression_buffer.len) {
if (self.compression_buffer.len > 0) {
default_allocator.free(self.compression_buffer);
}
self.compression_buffer = default_allocator.alloc(u8, max_size) catch return null;
}
// Compress the data
const result = self.compressor.?.compress(data, self.compression_buffer, .deflate);
if (result.status == .success and result.written > 0) {
// Remove the last 4 bytes (0x00 0x00 0xff 0xff) as per RFC7692
if (result.written >= 4 and
self.compression_buffer[result.written - 4] == 0x00 and
self.compression_buffer[result.written - 3] == 0x00 and
self.compression_buffer[result.written - 2] == 0xff and
self.compression_buffer[result.written - 1] == 0xff)
{
// If server_no_context_takeover is true, we should reset the compressor
// However, libdeflate doesn't have a direct way to reset the compressor
// So we would need to free and recreate it if needed
if (self.server_no_context_takeover) {
if (self.compressor) |compressor| {
compressor.deinit();
// Create a new compressor with the same settings
var compression_level: c_int = 6; // Default
if (self.server_max_window_bits <= 9) {
compression_level = 1; // For very small windows, use fast compression
} else if (self.server_max_window_bits <= 11) {
compression_level = 3; // For small windows
} else if (self.server_max_window_bits <= 13) {
compression_level = 5; // For medium windows
}
self.compressor = libdeflate.Compressor.alloc(compression_level);
}
}
return self.compression_buffer[0 .. result.written - 4];
}
return self.compression_buffer[0..result.written];
}
return null;
}
pub fn decompress(self: *WebSocketCompression, data: []const u8, estimated_size: usize) ?[]const u8 {
if (!self.enabled or self.decompressor == null) return null;
// Make sure our buffer is large enough for decompression
// We might need to grow the buffer if the uncompressed data is large
if (estimated_size > self.compression_buffer.len) {
if (self.compression_buffer.len > 0) {
default_allocator.free(self.compression_buffer);
}
self.compression_buffer = default_allocator.alloc(u8, estimated_size) catch return null;
}
// Append 0x00 0x00 0xff 0xff to the data as required by RFC7692
var input_buffer = default_allocator.alloc(u8, data.len + 4) catch return null;
defer default_allocator.free(input_buffer);
@memcpy(input_buffer[0..data.len], data);
input_buffer[data.len] = 0x00;
input_buffer[data.len + 1] = 0x00;
input_buffer[data.len + 2] = 0xff;
input_buffer[data.len + 3] = 0xff;
// Decompress
const result = self.decompressor.?.decompress(input_buffer, self.compression_buffer, .deflate);
if (result.status == .success and result.written > 0) {
// If client_no_context_takeover is true, we should reset the decompressor
// Similar to the compressor, libdeflate doesn't have a direct way to reset
// So we would need to free and recreate it
if (self.client_no_context_takeover) {
if (self.decompressor) |decompressor| {
decompressor.deinit();
self.decompressor = libdeflate.Decompressor.alloc();
}
}
return self.compression_buffer[0..result.written];
}
return null;
}
};
const log = Output.scoped(.WebSocketClient, false);
pub const NonUTF8Headers = struct {
names: []const JSC.ZigString,
values: []const JSC.ZigString,
pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
const count = self.names.len;
var i: usize = 0;
while (i < count) : (i += 1) {
try std.fmt.format(writer, "{any}: {any}\r\n", .{ self.names[i], self.values[i] });
}
}
pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers {
if (len == 0) {
return .{
.names = &[_]JSC.ZigString{},
.values = &[_]JSC.ZigString{},
};
}
return .{
.names = names.?[0..len],
.values = values.?[0..len],
};
}
};
pub const ErrorCode = enum(i32) {
cancel,
invalid_response,
expected_101_status_code,
missing_upgrade_header,
missing_connection_header,
missing_websocket_accept_header,
invalid_upgrade_header,
invalid_connection_header,
invalid_websocket_version,
mismatch_websocket_accept_header,
missing_client_protocol,
mismatch_client_protocol,
timeout,
closed,
failed_to_write,
failed_to_connect,
headers_too_large,
ended,
failed_to_allocate_memory,
control_frame_is_fragmented,
invalid_control_frame,
compression_unsupported,
unexpected_mask_from_server,
expected_control_frame,
unsupported_control_frame,
unexpected_opcode,
invalid_utf8,
tls_handshake_failed,
};
pub const CppWebSocket = opaque {
extern fn WebSocket__didConnect(
websocket_context: *CppWebSocket,
socket: *uws.Socket,
buffered_data: ?[*]u8,
buffered_len: usize,
) void;
extern fn WebSocket__didAbruptClose(websocket_context: *CppWebSocket, reason: ErrorCode) void;
extern fn WebSocket__didClose(websocket_context: *CppWebSocket, code: u16, reason: *const bun.String) void;
extern fn WebSocket__didReceiveText(websocket_context: *CppWebSocket, clone: bool, text: *const JSC.ZigString) void;
extern fn WebSocket__didReceiveBytes(websocket_context: *CppWebSocket, bytes: [*]const u8, byte_len: usize, opcode: u8) void;
extern fn WebSocket__rejectUnauthorized(websocket_context: *CppWebSocket) bool;
extern fn WebSocket__setupCompression(websocket_context: *CppWebSocket, extensions: [*]const u8, extensions_len: usize) void;
pub fn didAbruptClose(this: *CppWebSocket, reason: ErrorCode) void {
const loop = JSC.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didAbruptClose(this, reason);
}
pub fn didClose(this: *CppWebSocket, code: u16, reason: *bun.String) void {
const loop = JSC.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didClose(this, code, reason);
}
pub fn didReceiveText(this: *CppWebSocket, clone: bool, text: *const JSC.ZigString) void {
const loop = JSC.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didReceiveText(this, clone, text);
}
pub fn didReceiveBytes(this: *CppWebSocket, bytes: [*]const u8, byte_len: usize, opcode: u8) void {
const loop = JSC.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didReceiveBytes(this, bytes, byte_len, opcode);
}
pub fn rejectUnauthorized(this: *CppWebSocket) bool {
const loop = JSC.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
return WebSocket__rejectUnauthorized(this);
}
pub fn didConnect(this: *CppWebSocket, socket: *uws.Socket, buffered_data: ?[*]u8, buffered_len: usize) void {
const loop = JSC.VirtualMachine.get().eventLoop();
loop.enter();
defer loop.exit();
WebSocket__didConnect(this, socket, buffered_data, buffered_len);
}
pub fn setupCompression(_: *CppWebSocket, _: []const u8) void {
// Skip for now since setupCompression is not yet implemented in C++ side
// When implementing this in C++, uncomment the following:
// const loop = JSC.VirtualMachine.get().eventLoop();
// loop.enter();
// defer loop.exit();
// WebSocket__setupCompression(this, extensions.ptr, extensions.len);
}
extern fn WebSocket__incrementPendingActivity(websocket_context: *CppWebSocket) void;
extern fn WebSocket__decrementPendingActivity(websocket_context: *CppWebSocket) void;
pub fn ref(this: *CppWebSocket) void {
JSC.markBinding(@src());
WebSocket__incrementPendingActivity(this);
}
pub fn unref(this: *CppWebSocket) void {
JSC.markBinding(@src());
WebSocket__decrementPendingActivity(this);
}
};
pub const Mask = struct {
pub fn fill(globalThis: *JSC.JSGlobalObject, mask_buf: *[4]u8, output_: []u8, input_: []const u8) void {
mask_buf.* = globalThis.bunVM().rareData().entropySlice(4)[0..4].*;
const mask = mask_buf.*;
const skip_mask = @as(u32, @bitCast(mask)) == 0;
if (!skip_mask) {
fillWithSkipMask(mask, output_, input_, false);
} else {
fillWithSkipMask(mask, output_, input_, true);
}
}
fn fillWithSkipMask(mask: [4]u8, output_: []u8, input_: []const u8, comptime skip_mask: bool) void {
var input = input_;
var output = output_;
if (comptime Environment.enableSIMD) {
if (input.len >= strings.ascii_vector_size) {
const vec: strings.AsciiVector = brk: {
var in: [strings.ascii_vector_size]u8 = undefined;
comptime var i: usize = 0;
inline while (i < strings.ascii_vector_size) : (i += 4) {
in[i..][0..4].* = mask;
}
break :brk @as(strings.AsciiVector, in);
};
const end_ptr_wrapped_to_last_16 = input.ptr + input.len - (input.len % strings.ascii_vector_size);
if (comptime skip_mask) {
while (input.ptr != end_ptr_wrapped_to_last_16) {
const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
output.ptr[0..strings.ascii_vector_size].* = input_vec;
output = output[strings.ascii_vector_size..];
input = input[strings.ascii_vector_size..];
}
} else {
while (input.ptr != end_ptr_wrapped_to_last_16) {
const input_vec: strings.AsciiVector = @as(strings.AsciiVector, input[0..strings.ascii_vector_size].*);
output.ptr[0..strings.ascii_vector_size].* = input_vec ^ vec;
output = output[strings.ascii_vector_size..];
input = input[strings.ascii_vector_size..];
}
}
}
// hint to the compiler not to vectorize the next loop
bun.assert(input.len < strings.ascii_vector_size);
}
if (comptime !skip_mask) {
while (input.len >= 4) {
const input_vec: [4]u8 = input[0..4].*;
output.ptr[0..4].* = [4]u8{
input_vec[0] ^ mask[0],
input_vec[1] ^ mask[1],
input_vec[2] ^ mask[2],
input_vec[3] ^ mask[3],
};
output = output[4..];
input = input[4..];
}
} else {
while (input.len >= 4) {
const input_vec: [4]u8 = input[0..4].*;
output.ptr[0..4].* = input_vec;
output = output[4..];
input = input[4..];
}
}
if (comptime !skip_mask) {
for (input, 0..) |c, i| {
output[i] = c ^ mask[i % 4];
}
} else {
for (input, 0..) |c, i| {
output[i] = c;
}
}
}
};
pub const ReceiveState = enum {
need_header,
need_mask,
need_body,
extended_payload_length_16,
extended_payload_length_64,
ping,
pong,
close,
fail,
pub fn needControlFrame(this: ReceiveState) bool {
return this != .need_body;
}
pub fn parseWebSocketHeader(
bytes: [2]u8,
receiving_type: *Opcode,
payload_length: *usize,
is_fragmented: *bool,
is_final: *bool,
need_compression: *bool,
) ReceiveState {
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------- - - - - - - - - - - - - - - - +
// : Payload Data continued ... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | Payload Data continued ... |
// +---------------------------------------------------------------+
const header = WebsocketHeader.fromSlice(bytes);
const payload = @as(usize, header.len);
payload_length.* = payload;
receiving_type.* = header.opcode;
is_fragmented.* = switch (header.opcode) {
.Continue => true,
else => false,
} or !header.final;
is_final.* = header.final;
need_compression.* = header.compressed;
if (header.mask and (header.opcode == .Text or header.opcode == .Binary)) {
return .need_mask;
}
// reserved bits must be 0
if (header.rsv != 0) {
return .fail;
}
return switch (header.opcode) {
.Text, .Continue, .Binary => if (payload <= 125)
return .need_body
else if (payload == 126)
return .extended_payload_length_16
else if (payload == 127)
return .extended_payload_length_64
else
return .fail,
.Close => .close,
.Ping => .ping,
.Pong => .pong,
else => .fail,
};
}
};
pub const DataType = enum {
none,
text,
binary,
};
pub const Copy = union(enum) {
utf16: []const u16,
latin1: []const u8,
bytes: []const u8,
raw: []const u8,
pub fn len(this: @This(), byte_len: *usize) usize {
switch (this) {
.utf16 => {
byte_len.* = strings.elementLengthUTF16IntoUTF8([]const u16, this.utf16);
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.latin1 => {
byte_len.* = strings.elementLengthLatin1IntoUTF8([]const u8, this.latin1);
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.bytes => {
byte_len.* = this.bytes.len;
return WebsocketHeader.frameSizeIncludingMask(byte_len.*);
},
.raw => {
byte_len.* = this.raw.len;
return this.raw.len;
},
}
}
pub fn copy(this: @This(), globalThis: *JSC.JSGlobalObject, buf: []u8, content_byte_len: usize, opcode: Opcode, compressed: bool) void {
if (this == .raw) {
bun.assert(buf.len >= this.raw.len);
bun.assert(buf.ptr != this.raw.ptr);
@memcpy(buf[0..this.raw.len], this.raw);
return;
}
const how_big_is_the_length_integer = WebsocketHeader.lengthByteCount(content_byte_len);
const how_big_is_the_mask = 4;
const mask_offset = 2 + how_big_is_the_length_integer;
const content_offset = mask_offset + how_big_is_the_mask;
// 2 byte header
// 4 byte mask
// 0, 2, 8 byte length
var to_mask = buf[content_offset..];
var header = @as(WebsocketHeader, @bitCast(@as(u16, 0)));
// Write extended length if needed
switch (how_big_is_the_length_integer) {
0 => {},
2 => std.mem.writeInt(u16, buf[2..][0..2], @as(u16, @truncate(content_byte_len)), .big),
8 => std.mem.writeInt(u64, buf[2..][0..8], @as(u64, @truncate(content_byte_len)), .big),
else => unreachable,
}
header.mask = true;
header.compressed = compressed;
header.final = true;
header.opcode = opcode;
bun.assert(WebsocketHeader.frameSizeIncludingMask(content_byte_len) == buf.len);
switch (this) {
.utf16 => |utf16| {
header.len = WebsocketHeader.packLength(content_byte_len);
const encode_into_result = strings.copyUTF16IntoUTF8(to_mask, []const u16, utf16, true);
bun.assert(@as(usize, encode_into_result.written) == content_byte_len);
bun.assert(@as(usize, encode_into_result.read) == utf16.len);
header.len = WebsocketHeader.packLength(encode_into_result.written);
var fib = std.io.fixedBufferStream(buf);
header.writeHeader(fib.writer(), encode_into_result.written) catch unreachable;
Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]);
},
.latin1 => |latin1| {
const encode_into_result = strings.copyLatin1IntoUTF8(to_mask, []const u8, latin1);
bun.assert(@as(usize, encode_into_result.written) == content_byte_len);
// latin1 can contain non-ascii
bun.assert(@as(usize, encode_into_result.read) == latin1.len);
header.len = WebsocketHeader.packLength(encode_into_result.written);
var fib = std.io.fixedBufferStream(buf);
header.writeHeader(fib.writer(), encode_into_result.written) catch unreachable;
Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], to_mask[0..content_byte_len]);
},
.bytes => |bytes| {
header.len = WebsocketHeader.packLength(bytes.len);
var fib = std.io.fixedBufferStream(buf);
header.writeHeader(fib.writer(), bytes.len) catch unreachable;
Mask.fill(globalThis, buf[mask_offset..][0..4], to_mask[0..content_byte_len], bytes);
},
.raw => unreachable,
}
}
};
// Add public exports at the end
pub const WebSocketHTTPClient = @import("WebsocketHTTPUpgradeClient.zig").WebSocketHTTPClient;
pub const WebSocketHTTPSClient = @import("WebsocketHTTPUpgradeClient.zig").WebSocketHTTPSClient;
pub const WebSocketClient = @import("WebSocket.zig").WebSocketClient;
pub const WebSocketClientTLS = @import("WebSocket.zig").WebSocketClientTLS;

File diff suppressed because it is too large Load Diff