more generic

This commit is contained in:
Ciro Spaciari
2025-09-03 19:34:31 -07:00
parent 4b61f99adf
commit b05964def4
2 changed files with 23 additions and 19 deletions

View File

@@ -108,7 +108,7 @@ pub const FetchTasklet = struct {
// custom checkServerIdentity
check_server_identity: jsc.Strong.Optional = .empty,
reject_unauthorized: bool = true,
is_websocket_upgrade: bool = false,
upgraded_connection: bool = false,
// Custom Hostname
hostname: ?[]u8 = null,
is_waiting_body: bool = false,
@@ -1070,7 +1070,7 @@ pub const FetchTasklet = struct {
.memory_reporter = fetch_options.memory_reporter,
.check_server_identity = fetch_options.check_server_identity,
.reject_unauthorized = fetch_options.reject_unauthorized,
.is_websocket_upgrade = fetch_options.is_websocket_upgrade,
.upgraded_connection = fetch_options.upgraded_connection,
};
fetch_tasklet.signals = fetch_tasklet.signal_store.to();
@@ -1203,7 +1203,7 @@ pub const FetchTasklet = struct {
// dont have backpressure so we will schedule the data to be written
// if we have backpressure the onWritable will drain the buffer
needs_schedule = stream_buffer.isEmpty();
if (this.is_websocket_upgrade) {
if (this.upgraded_connection) {
bun.handleOom(stream_buffer.write(data));
} else {
//16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n
@@ -1277,7 +1277,7 @@ pub const FetchTasklet = struct {
check_server_identity: jsc.Strong.Optional = .empty,
unix_socket_path: ZigString.Slice,
ssl_config: ?*SSLConfig = null,
is_websocket_upgrade: bool = false,
upgraded_connection: bool = false,
};
pub fn queue(
@@ -1501,7 +1501,7 @@ pub fn Bun__fetch_(
var memory_reporter = bun.handleOom(bun.default_allocator.create(bun.MemoryReportingAllocator));
// used to clean up dynamically allocated memory on error (a poor man's errdefer)
var is_error = false;
var is_websocket_upgrade = false;
var upgraded_connection = false;
var allocator = memory_reporter.wrap(bun.default_allocator);
errdefer bun.default_allocator.destroy(memory_reporter);
defer {
@@ -2210,8 +2210,8 @@ pub fn Bun__fetch_(
const upgrade = _upgrade.toSlice(bun.default_allocator);
defer upgrade.deinit();
const slice = upgrade.slice();
if (bun.strings.eqlComptime(slice, "websocket")) {
is_websocket_upgrade = true;
if (!bun.strings.eqlComptime(slice, "h2") and !bun.strings.eqlComptime(slice, "h2c")) {
upgraded_connection = true;
}
}
@@ -2350,7 +2350,7 @@ pub fn Bun__fetch_(
}
}
if (!method.hasRequestBody() and body.hasBody() and !is_websocket_upgrade) {
if (!method.hasRequestBody() and body.hasBody() and !upgraded_connection) {
const err = globalThis.toTypeError(.INVALID_ARG_VALUE, fetch_error_unexpected_body, .{});
is_error = true;
return JSPromise.dangerouslyCreateRejectedPromiseValueWithoutNotifyingVM(globalThis, err);
@@ -2668,7 +2668,7 @@ pub fn Bun__fetch_(
.ssl_config = ssl_config,
.hostname = hostname,
.memory_reporter = memory_reporter,
.is_websocket_upgrade = is_websocket_upgrade,
.upgraded_connection = upgraded_connection,
.check_server_identity = if (check_server_identity.isEmptyOrUndefinedOrNull()) .empty else .create(check_server_identity, globalThis),
.unix_socket_path = unix_socket_path,
},

View File

@@ -393,6 +393,11 @@ pub const HTTPVerboseLevel = enum {
curl,
};
const HTTPUpgradeState = enum(u2) {
none = 0,
pending = 1,
upgraded = 2,
};
pub const Flags = packed struct(u16) {
disable_timeout: bool = false,
disable_keepalive: bool = false,
@@ -405,8 +410,7 @@ pub const Flags = packed struct(u16) {
is_preconnect_only: bool = false,
is_streaming_request_body: bool = false,
defer_fail_until_connecting_is_complete: bool = false,
is_websockets: bool = false,
websocket_upgraded: bool = false,
upgrade_state: HTTPUpgradeState = .none,
_padding: u3 = 0,
};
@@ -595,8 +599,9 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
override_accept_encoding = true;
},
hashHeaderConst("Upgrade") => {
if (std.ascii.eqlIgnoreCase(this.headerStr(header_values[i]), "websocket")) {
this.flags.is_websockets = true;
const value = this.headerStr(header_values[i]);
if (!std.ascii.eqlIgnoreCase(value, "h2") and !std.ascii.eqlIgnoreCase(value, "h2c")) {
this.flags.upgrade_state = .pending;
}
},
hashHeaderConst(chunked_encoded_header.name) => {
@@ -1030,8 +1035,8 @@ pub fn writeToStream(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo
log("flushStream", .{});
var stream = &this.state.original_request_body.stream;
const stream_buffer = stream.buffer orelse return;
if (this.flags.is_websockets and !this.flags.websocket_upgraded) {
// cannot drain yet, websocket is waiting for upgrade
if (this.flags.upgrade_state == .pending) {
// cannot drain yet, upgrade is waiting for upgrade
return;
}
const buffer = stream_buffer.acquire();
@@ -1378,14 +1383,13 @@ pub fn handleOnDataHeaders(
to_read = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..];
if (response.status_code == 101) {
if (!this.flags.is_websockets) {
if (this.flags.upgrade_state == .none) {
// we cannot upgrade to websocket because the client did not request it!
this.closeAndFail(error.UnrequestedUpgrade, is_ssl, socket);
return;
}
// special case for websocket upgrade
this.flags.is_websockets = true;
this.flags.websocket_upgraded = true;
this.flags.upgrade_state = .upgraded;
if (this.signals.upgraded) |upgraded| {
upgraded.store(true, .monotonic);
}
@@ -2448,7 +2452,7 @@ pub fn handleResponseMetadata(
log("handleResponseMetadata: content_length is null and transfer_encoding {}", .{this.state.transfer_encoding});
}
if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.websocket_upgraded)) {
if (this.method.hasBody() and (content_length == null or content_length.? > 0 or !this.state.flags.allow_keepalive or this.state.transfer_encoding == .chunked or is_server_sent_events or this.flags.upgrade_state == .upgraded)) {
return ShouldContinue.continue_streaming;
} else {
return ShouldContinue.finished;