Files
bun.sh/src/http/websocket.zig
2023-01-03 18:53:40 -08:00

349 lines
11 KiB
Zig

// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig
// Thank you @frmdstryr.
const std = @import("std");
const native_endian = @import("builtin").target.cpu.arch.endian();
const os = std.os;
const bun = @import("bun");
const string = bun.string;
const Output = bun.Output;
const Global = bun.Global;
const Environment = bun.Environment;
const strings = bun.strings;
const MutableString = bun.MutableString;
const stringZ = bun.stringZ;
const default_allocator = bun.default_allocator;
const C = bun.C;
/// WebSocket frame opcodes. The numeric values are exactly the 4-bit opcode
/// stored in the low nibble of a frame's first header byte; the `Res*`
/// variants are reserved values that have no defined meaning here.
pub const Opcode = enum(u4) {
    Continue = 0x0,
    Text = 0x1,
    Binary = 0x2,
    Res3 = 0x3,
    Res4 = 0x4,
    Res5 = 0x5,
    Res6 = 0x6,
    Res7 = 0x7,
    Close = 0x8,
    Ping = 0x9,
    Pong = 0xA,
    ResB = 0xB,
    ResC = 0xC,
    ResD = 0xD,
    ResE = 0xE,
    ResF = 0xF,

    /// Reports whether `opcode` is a control opcode, i.e. its high bit (0x8)
    /// is set. For a 4-bit value that is the same as being 0x8 or larger,
    /// which covers Close, Ping, Pong and the reserved 0xB-0xF values.
    pub fn isControl(opcode: Opcode) bool {
        const value = @enumToInt(opcode);
        return value >= 0x8;
    }
};
/// The two fixed bytes that start every WebSocket frame.
///
/// Field order matters: packed-struct fields are laid out from the least
/// significant bit, so `@bitCast(u16, header)` written big-endian yields
/// byte 0 = final | rsv1 | rsv2 | rsv3 | opcode and byte 1 = mask | len.
/// Do not reorder these fields.
pub const WebsocketHeader = packed struct {
    len: u7,
    mask: bool,
    opcode: Opcode,
    rsv3: u1 = 0,
    rsv2: u1 = 0,
    compressed: bool = false, // rsv1
    final: bool = true,

    /// Writes only the 2 header bytes to `writer` (no extended length, mask
    /// or payload). `n` is the payload length; `header.len` is asserted to be
    /// its packed form.
    pub fn writeHeader(header: WebsocketHeader, writer: anytype, n: usize) anyerror!void {
        const raw = @bitCast(u16, header);

        // Packed structs have been buggy in the compiler before, so builds
        // with assertions enabled round-trip the bits through a tiny buffer
        // and back into a header to make sure nothing was lost.
        if (comptime Environment.allow_assert) {
            var scratch = [2]u8{ 0, 0 };
            var fbs = std.io.fixedBufferStream(&scratch);
            fbs.writer().writeIntBig(u16, raw) catch unreachable;
            fbs.pos = 0;
            const round_tripped = fbs.reader().readIntBig(u16) catch unreachable;
            std.debug.assert(round_tripped == raw);
            std.debug.assert(std.meta.eql(@bitCast(WebsocketHeader, round_tripped), header));
        }

        try writer.writeIntBig(u16, raw);
        std.debug.assert(header.len == packLength(n));
    }

    /// Maps a payload length to the 7-bit header length field:
    /// the length itself for 0-125, 126 for lengths that need a 16-bit
    /// extended length, and 127 for lengths that need a 64-bit one.
    pub fn packLength(length: usize) u7 {
        if (length <= 125) return @truncate(u7, length);
        if (length <= 0xFFFF) return 126;
        return 127;
    }

    const mask_length = 4;
    const header_length = 2;

    /// Number of extended-length bytes that follow the 2-byte header for a
    /// payload of `byte_length` bytes (0, 2 or 8).
    pub fn lengthByteCount(byte_length: usize) usize {
        if (byte_length <= 125) return 0;
        if (byte_length <= 0xFFFF) return @sizeOf(u16);
        return @sizeOf(u64);
    }

    /// Total size of an unmasked frame: header + extended length + payload.
    pub fn frameSize(byte_length: usize) usize {
        const extended = lengthByteCount(byte_length);
        return header_length + extended + byte_length;
    }

    /// Total size of a masked frame: `frameSize` plus the 4-byte mask key.
    pub fn frameSizeIncludingMask(byte_length: usize) usize {
        return mask_length + frameSize(byte_length);
    }
};
/// A single frame: its header, the mask key (meaningful only when
/// `header.mask` is set), and the payload bytes in unmasked form.
pub const WebsocketDataFrame = struct {
    header: WebsocketHeader,
    mask: [4]u8 = undefined,
    data: []const u8,

    /// Checks the frame against the rules this file enforces:
    /// - a control frame must be final and carry at most 125 payload bytes;
    /// - `header.len` must be the 7-bit length field that encodes `data.len`
    ///   (the length itself for 0-125, 126 up to 0xFFFF, 127 above that).
    pub fn isValid(dataframe: WebsocketDataFrame) bool {
        const header = dataframe.header;
        const payload_len = dataframe.data.len;

        // Control frames cannot be fragmented, and their payloads cannot
        // exceed 125 bytes.
        if (header.opcode.isControl() and (!header.final or payload_len > 125)) {
            return false;
        }

        const expected: usize = if (payload_len <= 125)
            payload_len
        else if (payload_len <= 0xFFFF)
            126
        else
            127;
        return header.len == expected;
    }
};
/// Returns a buffered writer type that collects up to `size` bytes and hands
/// them to the Text or Binary frame writer for `opcode`.
/// TODO: This will still split packets: every chunk the buffer forwards is
/// sent as its own final frame, so one logical message may become several.
pub fn Writer(comptime size: usize, comptime opcode: Opcode) type {
    return std.io.BufferedWriter(size, switch (opcode) {
        .Text => Websocket.TextFrameWriter,
        .Binary => Websocket.BinaryFrameWriter,
        else => @compileError("Unsupported writer opcode"),
    });
}
/// Bookkeeping stream for `Websocket`. Its `pos` is used as the number of
/// bytes of `Websocket.buf` that currently hold received data.
const ReadStream = std.io.FixedBufferStream([]u8);

/// A WebSocket connection over an already-connected socket.
/// Messages sent through `writeMessage` are unmasked; incoming frames are read
/// into the fixed-size `buf` and decoded in place.
pub const Websocket = struct {
    pub const WriteError = error{
        InvalidMessage,
        MessageTooLarge,
        EndOfStream,
    } || std.fs.File.WriteError;

    stream: std.net.Stream,
    err: ?anyerror = null,
    /// Receive buffer; an entire incoming frame (header, mask key and
    /// payload) must fit in it.
    buf: [8096]u8 = undefined,
    /// `read_stream.pos` counts the valid bytes at the start of `buf`.
    read_stream: ReadStream,
    /// NOTE(review): `create` initializes this with a pointer into its local
    /// `socket` before returning it by value, so the pointer dangles
    /// afterwards. Re-point it (`ws.reader = ws.read_stream.reader();`) once
    /// the Websocket sits at its final address before using it.
    reader: ReadStream.Reader,
    flags: u32 = 0,

    /// Wraps the socket `fd`; `flags` is stored as-is.
    pub fn create(
        fd: std.os.fd_t,
        comptime flags: u32,
    ) Websocket {
        var stream = ReadStream{
            .buffer = &[_]u8{},
            .pos = 0,
        };
        var socket = Websocket{
            .read_stream = undefined,
            .reader = undefined,
            .stream = std.net.Stream{ .handle = @intCast(std.os.socket_t, fd) },
            .flags = flags,
        };
        socket.read_stream = stream;
        // See the NOTE on `reader`: this points into the local `socket`.
        socket.reader = socket.read_stream.reader();
        return socket;
    }

    // ------------------------------------------------------------------------
    // Stream API
    // ------------------------------------------------------------------------

    // Both frame writers use `anyerror` because `writeText`/`writeBinary`
    // forward `writeMessage`'s `anyerror`; a narrower error set here would
    // make the writer type fail to instantiate.
    pub const TextFrameWriter = std.io.Writer(*Websocket, anyerror, Websocket.writeText);
    pub const BinaryFrameWriter = std.io.Writer(*Websocket, anyerror, Websocket.writeBinary);

    /// A buffered writer that buffers up to `size` bytes before sending them
    /// as a final `opcode` frame. Call its own `flush` to send what remains.
    pub fn newWriter(self: *Websocket, comptime size: usize, comptime opcode: Opcode) Writer(size, opcode) {
        const BufferedWriter = Writer(size, opcode);
        const frame_writer = switch (opcode) {
            .Text => TextFrameWriter{ .context = self },
            .Binary => BinaryFrameWriter{ .context = self },
            else => @compileError("Unsupported writer type"),
        };
        return BufferedWriter{ .unbuffered_writer = frame_writer };
    }

    /// Sends a Close frame whose 2-byte payload is `code` in big-endian order.
    /// This only writes the frame; it does not close the underlying socket.
    pub fn close(self: *Websocket, code: u16) !void {
        const c = if (native_endian == .Big) code else @byteSwap(code);
        const data = @bitCast([2]u8, c);
        _ = try self.writeMessage(.Close, &data);
    }

    // ------------------------------------------------------------------------
    // Low level API
    // ------------------------------------------------------------------------

    /// No-op. Frames are written directly to `self.stream`, which this type
    /// does not buffer, so there is never pending data here. (Writers from
    /// `newWriter` have their own buffer and `flush`.)
    pub fn flush(self: *Websocket) !void {
        _ = self;
    }

    /// Sends `data` as one final Text frame; returns the payload length.
    pub fn writeText(self: *Websocket, data: []const u8) anyerror!usize {
        return self.writeMessage(.Text, data);
    }

    /// Sends `data` as one final Binary frame; returns the payload length.
    pub fn writeBinary(self: *Websocket, data: []const u8) anyerror!usize {
        return self.writeMessage(.Binary, data);
    }

    /// Write a final message packet with the given opcode.
    pub fn writeMessage(self: *Websocket, opcode: Opcode, message: []const u8) anyerror!usize {
        return self.writeSplitMessage(opcode, true, message);
    }

    /// Write an unmasked message packet with the given opcode and final flag.
    pub fn writeSplitMessage(self: *Websocket, opcode: Opcode, final: bool, message: []const u8) anyerror!usize {
        return self.writeDataFrame(WebsocketDataFrame{
            .header = WebsocketHeader{
                .final = final,
                .opcode = opcode,
                .mask = false, // Server to client is not masked
                .len = WebsocketHeader.packLength(message.len),
            },
            .data = message,
        });
    }

    /// Writes a raw frame: the 2 header bytes, the extended length when
    /// needed, the mask key and masked payload (if `header.mask`), or the
    /// plain payload. Returns the payload length.
    /// Fails with `error.InvalidMessage` before writing anything if the frame
    /// is invalid or asks for compression (not supported yet).
    pub fn writeDataFrame(self: *Websocket, dataframe: WebsocketDataFrame) anyerror!usize {
        if (!dataframe.isValid()) return error.InvalidMessage;
        // TODO: Handle compression. Rejected up front so a refused frame
        // never leaves a half-written header on the wire.
        if (dataframe.header.compressed) return error.InvalidMessage;

        const stream = self.stream.writer();
        try stream.writeIntBig(u16, @bitCast(u16, dataframe.header));

        // Extended length; the ranges must match `WebsocketHeader.packLength`.
        // 0-125 fit in the header, 126-0xFFFF need 16 bits, larger need 64.
        const n = dataframe.data.len;
        switch (n) {
            0...125 => {},
            126...0xFFFF => try stream.writeIntBig(u16, @truncate(u16, n)),
            else => try stream.writeIntBig(u64, n),
        }

        if (dataframe.header.mask) {
            const mask = dataframe.mask;
            try stream.writeAll(&mask);
            // Mask into a stack chunk so the unbuffered stream gets one write
            // per chunk rather than one per byte.
            var chunk: [512]u8 = undefined;
            var offset: usize = 0;
            while (offset < n) {
                const end = std.math.min(n, offset + chunk.len);
                for (dataframe.data[offset..end]) |c, i| {
                    chunk[i] = c ^ mask[(offset + i) % 4];
                }
                try stream.writeAll(chunk[0 .. end - offset]);
                offset = end;
            }
        } else {
            try stream.writeAll(dataframe.data);
        }
        return n;
    }

    /// Clears `buf`, reads from the socket, and decodes one frame starting at
    /// `buf[0]`. The frame's `data` points into `buf` and is overwritten by
    /// the next `read`. Returns `error.ConnectionClosed` if the peer has
    /// closed the connection.
    /// TODO: bytes received past the end of this frame are dropped by the
    /// next call, which clears `buf` and starts over at offset 0.
    pub fn read(self: *Websocket) !WebsocketDataFrame {
        @memset(&self.buf, 0, self.buf.len);
        const start = try self.stream.read(&self.buf);
        if (start == 0) {
            return error.ConnectionClosed;
        }
        self.read_stream.pos = start;
        return try self.readDataFrameInBuffer();
    }

    /// Sets the fill position to `min(read_stream.buffer.len, _len)` and
    /// returns `read_stream.buffer[offset..that]`.
    pub fn eatAt(self: *Websocket, offset: usize, _len: usize) []u8 {
        const len = std.math.min(self.read_stream.buffer.len, _len);
        self.read_stream.pos = len;
        return self.read_stream.buffer[offset..len];
    }

    /// Reads from the socket until at least `needed` bytes of `buf` are
    /// filled (tracked by `read_stream.pos`). Returns `error.MessageTooLarge`
    /// if `needed` exceeds `buf`, or `error.ConnectionClosed` on EOF.
    fn fillBuffer(self: *Websocket, needed: usize) !void {
        if (needed > self.buf.len) return error.MessageTooLarge;
        while (self.read_stream.pos < needed) {
            const n = try self.stream.read(self.buf[self.read_stream.pos..]);
            if (n == 0) return error.ConnectionClosed;
            self.read_stream.pos += n;
        }
    }

    /// Decodes the frame at the start of `buf`, reading more from the socket
    /// until the whole frame is present. The entire frame must fit in its
    /// buffer, otherwise `error.MessageTooLarge` is returned. Masked payloads
    /// are unmasked in place.
    pub fn readDataFrameInBuffer(
        self: *Websocket,
    ) !WebsocketDataFrame {
        try self.fillBuffer(2);
        const header_bytes = self.buf[0..2];
        var header = std.mem.zeroes(WebsocketHeader);
        header.final = header_bytes[0] & 0x80 == 0x80;
        // RSV1-3 (bits 0x40, 0x20, 0x10 of byte 0) are not decoded and stay 0.
        header.opcode = @intToEnum(Opcode, @truncate(u4, header_bytes[0]));
        header.mask = header_bytes[1] & 0x80 == 0x80;
        header.len = @truncate(u7, header_bytes[1]);

        // Decode the payload length. `header_size` is the 2 fixed bytes plus
        // any extended-length bytes.
        var length: u64 = header.len;
        var header_size: usize = 2;
        switch (header.len) {
            126 => {
                header_size = 4;
                try self.fillBuffer(header_size);
                length = std.mem.readIntBig(u16, self.buf[2..4]);
            },
            127 => {
                header_size = 10;
                try self.fillBuffer(header_size);
                length = std.mem.readIntBig(u64, self.buf[2..10]);
                // Most significant bit must be 0
                if (length >> 63 == 1) {
                    return error.InvalidMessage;
                }
            },
            else => {},
        }

        const mask_size: usize = if (header.mask) 4 else 0;
        const payload_offset = header_size + mask_size;
        // `length` comes from the peer: bound it before any arithmetic or
        // slicing so an oversized frame is an error, not an out-of-bounds slice.
        if (length > self.buf.len - payload_offset) {
            return error.MessageTooLarge;
        }
        const frame_end = payload_offset + @intCast(usize, length);
        try self.fillBuffer(frame_end);

        const data = self.buf[payload_offset..frame_end];
        var mask: [4]u8 = undefined;
        if (header.mask) {
            mask = self.buf[header_size..][0..4].*;
            // Decode data in place
            for (data) |*byte, i| {
                byte.* ^= mask[i % 4];
            }
        }

        return WebsocketDataFrame{
            .header = header,
            .mask = mask,
            .data = data,
        };
    }
};