diff --git a/src/bun.js/webcore/fetch.zig b/src/bun.js/webcore/fetch.zig index c7d421b769..46b06fea9b 100644 --- a/src/bun.js/webcore/fetch.zig +++ b/src/bun.js/webcore/fetch.zig @@ -108,7 +108,7 @@ pub const FetchTasklet = struct { // custom checkServerIdentity check_server_identity: jsc.Strong.Optional = .empty, reject_unauthorized: bool = true, - is_websocket_upgrade: bool = false, + upgraded_connection: bool = false, // Custom Hostname hostname: ?[]u8 = null, is_waiting_body: bool = false, @@ -1070,7 +1070,7 @@ pub const FetchTasklet = struct { .memory_reporter = fetch_options.memory_reporter, .check_server_identity = fetch_options.check_server_identity, .reject_unauthorized = fetch_options.reject_unauthorized, - .is_websocket_upgrade = fetch_options.is_websocket_upgrade, + .upgraded_connection = fetch_options.upgraded_connection, }; fetch_tasklet.signals = fetch_tasklet.signal_store.to(); @@ -1203,7 +1203,7 @@ pub const FetchTasklet = struct { // dont have backpressure so we will schedule the data to be written // if we have backpressure the onWritable will drain the buffer needs_schedule = stream_buffer.isEmpty(); - if (this.is_websocket_upgrade) { + if (this.upgraded_connection) { bun.handleOom(stream_buffer.write(data)); } else { //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n @@ -1277,7 +1277,7 @@ pub const FetchTasklet = struct { check_server_identity: jsc.Strong.Optional = .empty, unix_socket_path: ZigString.Slice, ssl_config: ?*SSLConfig = null, - is_websocket_upgrade: bool = false, + upgraded_connection: bool = false, }; pub fn queue( @@ -1501,7 +1501,7 @@ pub fn Bun__fetch_( var memory_reporter = bun.handleOom(bun.default_allocator.create(bun.MemoryReportingAllocator)); // used to clean up dynamically allocated memory on error (a poor man's errdefer) var is_error = false; - var is_websocket_upgrade = false; + var upgraded_connection = false; var allocator = memory_reporter.wrap(bun.default_allocator); errdefer bun.default_allocator.destroy(memory_reporter); defer { @@ -2210,8 +2210,8 @@ pub fn Bun__fetch_( const upgrade = _upgrade.toSlice(bun.default_allocator); defer upgrade.deinit(); const slice = upgrade.slice(); - if (bun.strings.eqlComptime(slice, "websocket")) { - is_websocket_upgrade = true; + if (!bun.strings.eqlComptime(slice, "h2") and !bun.strings.eqlComptime(slice, "h2c")) { + upgraded_connection = true; } } @@ -2350,7 +2350,7 @@ pub fn Bun__fetch_( } } - if (!method.hasRequestBody() and body.hasBody() and !is_websocket_upgrade) { + if (!method.hasRequestBody() and body.hasBody() and !upgraded_connection) { const err = globalThis.toTypeError(.INVALID_ARG_VALUE, fetch_error_unexpected_body, .{}); is_error = true; return JSPromise.dangerouslyCreateRejectedPromiseValueWithoutNotifyingVM(globalThis, err); @@ -2668,7 +2668,7 @@ pub fn Bun__fetch_( .ssl_config = ssl_config, .hostname = hostname, .memory_reporter = memory_reporter, - .is_websocket_upgrade = is_websocket_upgrade, + .upgraded_connection = upgraded_connection, .check_server_identity = if (check_server_identity.isEmptyOrUndefinedOrNull()) .empty else .create(check_server_identity, globalThis), .unix_socket_path = unix_socket_path, }, diff --git a/src/http.zig b/src/http.zig index 1d8eef5037..5f1e55f770 100644 --- a/src/http.zig +++ b/src/http.zig @@ -393,6 +393,11 @@ pub const HTTPVerboseLevel = enum { curl, }; +const HTTPUpgradeState = enum(u2) { + none = 0, + pending = 1, + upgraded = 2, +}; pub const Flags = packed struct(u16) { disable_timeout: bool = false, disable_keepalive: bool = false, @@ -405,8 +410,7 @@ pub const Flags = packed struct(u16) { is_preconnect_only: bool = false, is_streaming_request_body: bool = false, defer_fail_until_connecting_is_complete: bool = false, - is_websockets: bool = false, - websocket_upgraded: bool = false, + upgrade_state: HTTPUpgradeState = .none, _padding: u3 = 0, }; @@ -595,8 +599,9 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request { override_accept_encoding = true; }, hashHeaderConst("Upgrade") => { - if (std.ascii.eqlIgnoreCase(this.headerStr(header_values[i]), "websocket")) { - this.flags.is_websockets = true; + const value = this.headerStr(header_values[i]); + if (!std.ascii.eqlIgnoreCase(value, "h2") and !std.ascii.eqlIgnoreCase(value, "h2c")) { + this.flags.upgrade_state = .pending; } }, hashHeaderConst(chunked_encoded_header.name) => { @@ -1030,8 +1035,8 @@ pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo log("flushStream", .{}); var stream = &this.state.original_request_body.stream; const stream_buffer = stream.buffer orelse return; - if (this.flags.is_websockets and !this.flags.websocket_upgraded) { - // cannot drain yet, websocket is waiting for upgrade + if (this.flags.upgrade_state == .pending) { + // cannot drain yet, upgrade is waiting for upgrade return; } const buffer = stream_buffer.acquire(); @@ -1378,14 +1383,13 @@ pub fn handleOnDataHeaders( to_read = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..]; if (response.status_code == 101) { - if (!this.flags.is_websockets) { + if (this.flags.upgrade_state == .none) { // we cannot upgrade to websocket because the client did not request it! this.closeAndFail(error.UnrequestedUpgrade, is_ssl, socket); return; } // special case for websocket upgrade - this.flags.is_websockets = true; - this.flags.websocket_upgraded = true; + this.flags.upgrade_state = .upgraded; if (this.signals.upgraded) |upgraded| { upgraded.store(true, .monotonic); } @@ -2448,7 +2452,7 @@ pub fn handleResponseMetadata( log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding}); } - if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.websocket_upgraded)) { + if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.upgrade_state == .upgraded)) { return ShouldContinue.continue_streaming; } else { return ShouldContinue.finished;