[WebSocket] Implement headers support

Fixes https://github.com/oven-sh/bun/issues/1676
This commit is contained in:
Jarred Sumner
2022-12-28 18:39:19 -08:00
parent 384a9cda5e
commit ba0b5baee4
8 changed files with 204 additions and 23 deletions

View File

@@ -1656,6 +1656,25 @@ interface WebSocket extends EventTarget {
declare var WebSocket: {
prototype: WebSocket;
new (url: string | URL, protocols?: string | string[]): WebSocket;
new (
url: string | URL,
options: {
/**
* An object specifying connection headers
*
* This is a Bun-specific extension.
*/
headers?: HeadersInit;
/**
* A string specifying the subprotocols the server is willing to accept.
*/
protocol?: string;
/**
* A string array specifying the subprotocols the server is willing to accept.
*/
protocols?: string[];
},
): WebSocket;
readonly CLOSED: number;
readonly CLOSING: number;
readonly CONNECTING: number;

View File

@@ -1,4 +1,4 @@
//-- AUTOGENERATED FILE -- 1672229965
//-- AUTOGENERATED FILE -- 1672280340
// clang-format off
#pragma once

View File

@@ -1,5 +1,5 @@
// clang-format off
//-- AUTOGENERATED FILE -- 1672229965
//-- AUTOGENERATED FILE -- 1672280340
#pragma once
#include <stddef.h>
@@ -648,7 +648,7 @@ ZIG_DECL JSC__JSValue FileSink__write(JSC__JSGlobalObject* arg0, JSC__CallFrame*
#ifdef __cplusplus
ZIG_DECL void Bun__WebSocketHTTPClient__cancel(WebSocketHTTPClient* arg0);
ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6);
ZIG_DECL WebSocketHTTPClient* Bun__WebSocketHTTPClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9);
ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);
#endif
@@ -656,7 +656,7 @@ ZIG_DECL void Bun__WebSocketHTTPClient__register(JSC__JSGlobalObject* arg0, void
#ifdef __cplusplus
ZIG_DECL void Bun__WebSocketHTTPSClient__cancel(WebSocketHTTPSClient* arg0);
ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6);
ZIG_DECL WebSocketHTTPSClient* Bun__WebSocketHTTPSClient__connect(JSC__JSGlobalObject* arg0, void* arg1, void* arg2, const ZigString* arg3, uint16_t arg4, const ZigString* arg5, const ZigString* arg6, ZigString* arg7, ZigString* arg8, size_t arg9);
ZIG_DECL void Bun__WebSocketHTTPSClient__register(JSC__JSGlobalObject* arg0, void* arg1, void* arg2);
#endif

View File

@@ -37,6 +37,13 @@
#include "JSDOMConvertNumbers.h"
#include "JSDOMConvertSequences.h"
#include "JSDOMConvertStrings.h"
#include "JSDOMConvertBoolean.h"
#include "JSDOMConvertRecord.h"
#include "JSDOMConvertUnion.h"
#include "JSDOMExceptionHandling.h"
#include "JSDOMGlobalObjectInlines.h"
#include "JSDOMIterator.h"
#include "JSDOMOperation.h"
#include "JSDOMExceptionHandling.h"
#include "JSDOMGlobalObjectInlines.h"
#include "JSDOMOperation.h"
@@ -54,6 +61,8 @@
#include <wtf/GetPtr.h>
#include <wtf/PointerPreparations.h>
#include <wtf/URL.h>
#include "IDLTypes.h"
#include "FetchHeaders.h"
namespace WebCore {
using namespace JSC;
@@ -185,6 +194,54 @@ static inline EncodedJSValue constructJSWebSocket2(JSGlobalObject* lexicalGlobal
return JSValue::encode(jsValue);
}
static inline EncodedJSValue constructJSWebSocket3(JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame, JSValue urlValue, JSValue optionsObjectValue)
{
VM& vm = lexicalGlobalObject->vm();
auto throwScope = DECLARE_THROW_SCOPE(vm);
auto* globalObject = jsCast<Zig::GlobalObject*>(lexicalGlobalObject);
auto* context = globalObject->scriptExecutionContext();
if (UNLIKELY(!context))
return throwConstructorScriptExecutionContextUnavailableError(*lexicalGlobalObject, throwScope, "WebSocket");
auto url = convert<IDLUSVString>(*lexicalGlobalObject, urlValue);
RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
Vector<String> protocols;
auto headersInit = std::optional<Converter<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>::ReturnType>();
if (JSC::JSObject* options = optionsObjectValue.getObject()) {
if (JSValue headersValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "headers"_s)))) {
if (!headersValue.isUndefinedOrNull()) {
headersInit = convert<IDLUnion<IDLSequence<IDLSequence<IDLByteString>>, IDLRecord<IDLByteString, IDLByteString>>>(*lexicalGlobalObject, headersValue);
RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
}
}
if (JSValue protocolsValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocols"_s)))) {
if (!protocolsValue.isUndefinedOrNull()) {
protocols = convert<IDLSequence<IDLDOMString>>(*lexicalGlobalObject, protocolsValue);
RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
}
} else if (JSValue protocolValue = options->getIfPropertyExists(globalObject, PropertyName(Identifier::fromString(vm, "protocol"_s)))) {
if (!protocolValue.isUndefinedOrNull()) {
protocols = Vector<String> { convert<IDLDOMString>(*lexicalGlobalObject, protocolValue) };
RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
}
}
}
RETURN_IF_EXCEPTION(throwScope, encodedJSValue());
auto object = WebSocket::create(*context, WTFMove(url), protocols, WTFMove(headersInit));
if constexpr (IsExceptionOr<decltype(object)>)
RETURN_IF_EXCEPTION(throwScope, {});
static_assert(TypeOrExceptionOrUnderlyingType<decltype(object)>::isRef);
auto jsValue = toJSNewlyCreated<IDLInterface<WebSocket>>(*lexicalGlobalObject, *globalObject, throwScope, WTFMove(object));
if constexpr (IsExceptionOr<decltype(object)>)
RETURN_IF_EXCEPTION(throwScope, {});
setSubclassStructureIfNeeded<WebSocket>(lexicalGlobalObject, callFrame, asObject(jsValue));
RETURN_IF_EXCEPTION(throwScope, {});
return JSValue::encode(jsValue);
}
template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::construct(JSGlobalObject* lexicalGlobalObject, CallFrame* callFrame)
{
VM& vm = lexicalGlobalObject->vm();
@@ -204,7 +261,12 @@ template<> EncodedJSValue JSC_HOST_CALL_ATTRIBUTES JSWebSocketDOMConstructor::co
if (success)
RELEASE_AND_RETURN(throwScope, (constructJSWebSocket1(lexicalGlobalObject, callFrame)));
}
RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame)));
if (distinguishingArg.isString()) {
RELEASE_AND_RETURN(throwScope, (constructJSWebSocket2(lexicalGlobalObject, callFrame)));
} else if (distinguishingArg.isObject()) {
RELEASE_AND_RETURN(throwScope, (constructJSWebSocket3(lexicalGlobalObject, callFrame, callFrame->uncheckedArgument(0), distinguishingArg)));
}
}
return argsCount < 1 ? throwVMError(lexicalGlobalObject, throwScope, createNotEnoughArgumentsError(lexicalGlobalObject)) : throwVMTypeError(lexicalGlobalObject, throwScope);
}

View File

@@ -197,10 +197,15 @@ WebSocket::~WebSocket()
ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url)
{
return create(context, url, Vector<String> {});
return create(context, url, Vector<String> {}, std::nullopt);
}
ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols)
{
return create(context, url, protocols, std::nullopt);
}
ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headers)
{
if (url.isNull())
return Exception { SyntaxError };
@@ -208,7 +213,7 @@ ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, c
auto socket = adoptRef(*new WebSocket(context));
// socket->suspendIfNeeded();
auto result = socket->connect(url, protocols);
auto result = socket->connect(url, protocols, WTFMove(headers));
// auto result = socket->connect(url, protocols);
if (result.hasException())
@@ -224,12 +229,12 @@ ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, c
ExceptionOr<void> WebSocket::connect(const String& url)
{
return connect(url, Vector<String> {});
return connect(url, Vector<String> {}, std::nullopt);
}
ExceptionOr<void> WebSocket::connect(const String& url, const String& protocol)
{
return connect(url, Vector<String> { 1, protocol });
return connect(url, Vector<String> { 1, protocol }, std::nullopt);
}
void WebSocket::failAsynchronously()
@@ -266,6 +271,11 @@ static String hostName(const URL& url, bool secure)
}
ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols)
{
return connect(url, protocols, std::nullopt);
}
ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&& headersInit)
{
LOG(Network, "WebSocket %p connect() url='%s'", this, url.utf8().data());
m_url = URL { url };
@@ -280,9 +290,9 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr
return Exception { SyntaxError, makeString("Invalid url for WebSocket "_s, m_url.stringCenterEllipsizedToLength()) };
}
bool is_secure = m_url.protocolIs("wss"_s);
bool is_secure = m_url.protocolIs("wss"_s) || m_url.protocolIs("https"_s);
if (!m_url.protocolIs("ws"_s) && !is_secure) {
if (!m_url.protocolIs("http"_s) && !m_url.protocolIs("ws"_s) && !is_secure) {
// context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );
m_state = CLOSED;
updateHasPendingActivity();
@@ -371,19 +381,41 @@ ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& pr
port = userPort.value();
}
Vector<ZigString, 8> headerNames;
Vector<ZigString, 8> headerValues;
auto headersOrException = FetchHeaders::create(WTFMove(headersInit));
if (UNLIKELY(headersOrException.hasException())) {
m_state = CLOSED;
updateHasPendingActivity();
return headersOrException.releaseException();
}
auto headers = headersOrException.releaseReturnValue();
headerNames.reserveInitialCapacity(headers.get().internalHeaders().size());
headerValues.reserveInitialCapacity(headers.get().internalHeaders().size());
auto iterator = headers.get().createIterator();
while (auto value = iterator.next()) {
headerNames.uncheckedAppend(Zig::toZigString(value->key));
headerValues.uncheckedAppend(Zig::toZigString(value->value));
}
m_isSecure = is_secure;
this->incPendingActivityCount();
if (is_secure) {
us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<true>();
RELEASE_ASSERT(ctx);
this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString);
this->m_upgradeClient = Bun__WebSocketHTTPSClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size());
} else {
us_socket_context_t* ctx = scriptExecutionContext()->webSocketContext<false>();
RELEASE_ASSERT(ctx);
this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString);
this->m_upgradeClient = Bun__WebSocketHTTPClient__connect(scriptExecutionContext()->jsGlobalObject(), ctx, this, &host, port, &path, &clientProtocolString, headerNames.data(), headerValues.data(), headerNames.size());
}
headerValues.clear();
headerNames.clear();
if (this->m_upgradeClient == nullptr) {
// context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, );
m_state = CLOSED;

View File

@@ -36,6 +36,7 @@
#include <wtf/URL.h>
#include <wtf/HashSet.h>
#include <wtf/Lock.h>
#include "FetchHeaders.h"
namespace uWS {
template<bool, bool, typename>
@@ -59,6 +60,7 @@ public:
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url);
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const String& protocol);
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols);
static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&);
~WebSocket();
enum State {
@@ -71,6 +73,7 @@ public:
ExceptionOr<void> connect(const String& url);
ExceptionOr<void> connect(const String& url, const String& protocol);
ExceptionOr<void> connect(const String& url, const Vector<String>& protocols);
ExceptionOr<void> connect(const String& url, const Vector<String>& protocols, std::optional<FetchHeaders::Init>&&);
ExceptionOr<void> send(const String& message);
ExceptionOr<void> send(JSC::ArrayBuffer&);

View File

@@ -24,14 +24,48 @@ const Opcode = @import("./websocket.zig").Opcode;
const log = Output.scoped(.WebSocketClient, false);
fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, host: *const JSC.ZigString, client_protocol: *const JSC.ZigString, client_protocol_hash: *u64) std.mem.Allocator.Error![]u8 {
const NonUTF8Headers = struct {
names: []const JSC.ZigString,
values: []const JSC.ZigString,
pub fn format(self: NonUTF8Headers, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
const count = self.names.len;
var i: usize = 0;
while (i < count) : (i += 1) {
try std.fmt.format(writer, "{any}: {any}\r\n", .{ self.names[i], self.values[i] });
}
}
pub fn init(names: ?[*]const JSC.ZigString, values: ?[*]const JSC.ZigString, len: usize) NonUTF8Headers {
if (len == 0) {
return .{
.names = &[_]JSC.ZigString{},
.values = &[_]JSC.ZigString{},
};
}
return .{
.names = names.?[0..len],
.values = values.?[0..len],
};
}
};
fn buildRequestBody(
vm: *JSC.VirtualMachine,
pathname: *const JSC.ZigString,
host: *const JSC.ZigString,
client_protocol: *const JSC.ZigString,
client_protocol_hash: *u64,
extra_headers: NonUTF8Headers,
) std.mem.Allocator.Error![]u8 {
const allocator = vm.allocator;
const input_rand_buf = vm.rareData().nextUUID();
const temp_buf_size = comptime std.base64.standard.Encoder.calcSize(16);
var encoded_buf: [temp_buf_size]u8 = undefined;
const accept_key = std.base64.standard.Encoder.encode(&encoded_buf, &input_rand_buf);
var headers = [_]PicoHTTP.Header{
var static_headers = [_]PicoHTTP.Header{
.{
.name = "Sec-WebSocket-Key",
.value = accept_key,
@@ -43,9 +77,10 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos
};
if (client_protocol.len > 0)
client_protocol_hash.* = std.hash.Wyhash.hash(0, headers[1].value);
client_protocol_hash.* = std.hash.Wyhash.hash(0, static_headers[1].value);
const headers_ = static_headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
var headers_: []PicoHTTP.Header = headers[0 .. 1 + @as(usize, @boolToInt(client_protocol.len > 0))];
const pathname_ = pathname.slice();
const host_ = host.slice();
const pico_headers = PicoHTTP.Headers{ .headers = headers_ };
@@ -59,12 +94,9 @@ fn buildRequestBody(vm: *JSC.VirtualMachine, pathname: *const JSC.ZigString, hos
"Upgrade: websocket\r\n" ++
"Sec-WebSocket-Version: 13\r\n" ++
"{any}" ++
"{any}" ++
"\r\n",
.{
pathname_,
host_,
pico_headers,
},
.{ pathname_, host_, pico_headers, extra_headers },
);
}
@@ -174,11 +206,21 @@ pub fn NewHTTPUpgradeClient(comptime ssl: bool) type {
port: u16,
pathname: *const JSC.ZigString,
client_protocol: *const JSC.ZigString,
header_names: ?[*]const JSC.ZigString,
header_values: ?[*]const JSC.ZigString,
header_count: usize,
) callconv(.C) ?*HTTPClient {
std.debug.assert(global.bunVM().uws_event_loop != null);
var client_protocol_hash: u64 = 0;
var body = buildRequestBody(global.bunVM(), pathname, host, client_protocol, &client_protocol_hash) catch return null;
var body = buildRequestBody(
global.bunVM(),
pathname,
host,
client_protocol,
&client_protocol_hash,
NonUTF8Headers.init(header_names, header_values, header_count),
) catch return null;
var client: HTTPClient = HTTPClient{
.tcp = undefined,
.outgoing_websocket = websocket,

View File

@@ -19,6 +19,29 @@ describe("WebSocket", () => {
await closed;
});
it("supports headers", (done) => {
const server = Bun.serve({
port: 8024,
fetch(req, server) {
expect(req.headers.get("X-Hello")).toBe("World");
expect(req.headers.get("content-type")).toBe("lolwut");
server.stop();
done();
return new Response();
},
websocket: {
open(ws) {
ws.close();
},
},
});
const ws = new WebSocket(`ws://${server.hostname}:${server.port}`, {
headers: {
"X-Hello": "World",
"content-type": "lolwut",
},
});
});
it("should send and receive messages", async () => {
const ws = new WebSocket(TEST_WEBSOCKET_HOST);
await new Promise((resolve, reject) => {