This commit is contained in:
Ciro Spaciari
2025-03-13 16:41:04 -07:00
parent 682aa09944
commit 15a89b98c6
4 changed files with 146 additions and 109 deletions

View File

@@ -1,6 +1,11 @@
const HTTPClient = @This();
const bun = @import("root").bun;
const uws = bun.uws;
const picohttp = bun.picohttp;
const JSC = bun.JSC;
const URL = bun.URL;
const BoringSSL = bun.BoringSSL;
const string = bun.string;
const Output = bun.Output;
const Global = bun.Global;
@@ -8,45 +13,57 @@ const Environment = bun.Environment;
const strings = bun.strings;
const MutableString = bun.MutableString;
const FeatureFlags = bun.FeatureFlags;
const stringZ = bun.stringZ;
const C = bun.C;
const Loc = bun.logger.Loc;
const Log = bun.logger.Log;
const DotEnv = @import("./env_loader.zig");
// const stringZ = bun.stringZ;
// const C = bun.C;
// const Loc = bun.logger.Loc;
// const Log = bun.logger.Log;
// const DotEnv = @import("./env_loader.zig");
const std = @import("std");
const URL = @import("./url.zig").URL;
const PercentEncoding = @import("./url.zig").PercentEncoding;
const posix = std.posix;
const SOCK = posix.SOCK;
pub const MimeType = @import("./http/mime_type.zig");
// const URL = @import("./url.zig").URL;
pub const Method = @import("./http/method.zig").Method;
const Api = @import("./api/schema.zig").Api;
const Lock = bun.Mutex;
const HTTPClient = @This();
// const Api = @import("./api/schema.zig").Api;
// const Lock = bun.Mutex;
const Zlib = @import("./zlib.zig");
const Brotli = bun.brotli;
const StringBuilder = bun.StringBuilder;
const ThreadPool = bun.ThreadPool;
const ObjectPool = @import("./pool.zig").ObjectPool;
const posix = std.posix;
const SOCK = posix.SOCK;
const Arena = @import("./allocators/mimalloc_arena.zig").Arena;
const ZlibPool = @import("./http/zlib.zig");
const BoringSSL = bun.BoringSSL.c;
const Progress = bun.Progress;
const X509 = @import("./bun.js/api/bun/x509.zig");
const SSLConfig = @import("./bun.js/api/server.zig").ServerConfig.SSLConfig;
const SSLWrapper = @import("./bun.js/api/bun/ssl_wrapper.zig").SSLWrapper;
const default_allocator = bun.default_allocator;
pub const AsyncHTTP = @import("./http/client/async_http.zig").AsyncHTTP;
const registerAsyncHTTPAbortTracker = @import("./http/client/async_http.zig").registerAbortTracker;
const unregisterAsyncHTTPAbortTracker = @import("./http/client/async_http.zig").unregisterAbortTracker;
const HTTPThread = @import("./http/client/thread.zig").HTTPThread;
const getHttpContext = @import("./http/client/thread.zig").getContext;
const Encoding = @import("./http/client/async_http.zig").Encoding;
const HTTPCertError = @import("./http/client/errors.zig").HTTPCertError;
const HTTPRequestBody = @import("./http/client/request_body.zig").HTTPRequestBody;
const CertificateInfo = @import("./http/client/certificate_info.zig").CertificateInfo;
const HTTPVerboseLevel = @import("./http/client/async_http.zig").HTTPVerboseLevel;
const HTTPClientResult = @import("./http/client/result.zig").HTTPClientResult;
const ProxyTunnel = @import("./http/client/proxy_tunnel.zig").ProxyTunnel;
const Signals = @import("./http/client/signals.zig").Signals;
// const Arena = @import("./allocators/mimalloc_arena.zig").Arena;
// const ZlibPool = @import("./http/zlib.zig");
// const BoringSSL = bun.BoringSSL.c;
const Progress = bun.Progress;
// const X509 = @import("./bun.js/api/bun/x509.zig");
const SSLConfig = bun.server.ServerConfig.SSLConfig;
// const SSLWrapper = @import("./bun.js/api/bun/ssl_wrapper.zig").SSLWrapper;
const NewHTTPContext = @import("./http/client/thread.zig").NewHTTPContext;
const http_thread = @import("./http/client/thread.zig").getHttpThread();
const URLBufferPool = ObjectPool([8192]u8, null, false, 10);
const uws = bun.uws;
pub const MimeType = @import("./http/mime_type.zig");
pub const URLPath = @import("./http/url_path.zig");
pub const HTTPResponseMetadata = @import("./http/client/result.zig").HTTPResponseMetadata;
// This becomes Arena.allocator
pub var default_allocator: std.mem.Allocator = undefined;
var default_arena: Arena = undefined;
const TaggedPointerUnion = @import("./tagged_pointer.zig").TaggedPointerUnion;
const DeadSocket = opaque {};
var dead_socket = @as(*DeadSocket, @ptrFromInt(1));
pub const end_of_chunked_http1_1_encoding_response_body = @import("./http/client/async_http.zig").end_of_chunked_http1_1_encoding_response_body;
//TODO: this needs to be freed when Worker Threads are implemented
var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.InternalSocket).init(bun.default_allocator);
var async_http_id_monotonic: std.atomic.Value(u32) = std.atomic.Value(u32).init(0);
const MAX_REDIRECT_URL_LENGTH = 128 * 1024;
@@ -71,8 +88,6 @@ var shared_response_headers_buf: [256]picohttp.Header = undefined;
// never finishing sending the body
const preallocate_max = 1024 * 1024 * 256;
pub const end_of_chunked_http1_1_encoding_response_body = "0\r\n\r\n";
pub const FetchRedirect = enum(u8) {
follow,
manual,
@@ -106,7 +121,7 @@ pub fn checkServerIdentity(
if (client.signals.get(.cert_errors)) {
// clone the relevant data
const cert_size = BoringSSL.i2d_X509(x509, null);
const cert = bun.default_allocator.alloc(u8, @intCast(cert_size)) catch bun.outOfMemory();
const cert = default_allocator.alloc(u8, @intCast(cert_size)) catch bun.outOfMemory();
var cert_ptr = cert.ptr;
const result_size = BoringSSL.i2d_X509(x509, &cert_ptr);
assert(result_size == cert_size);
@@ -120,16 +135,16 @@ pub fn checkServerIdentity(
client.state.certificate_info = .{
.cert = cert,
.hostname = bun.default_allocator.dupe(u8, hostname) catch bun.outOfMemory(),
.hostname = default_allocator.dupe(u8, hostname) catch bun.outOfMemory(),
.cert_error = .{
.error_no = certError.error_no,
.code = bun.default_allocator.dupeZ(u8, certError.code) catch bun.outOfMemory(),
.reason = bun.default_allocator.dupeZ(u8, certError.reason) catch bun.outOfMemory(),
.code = default_allocator.dupeZ(u8, certError.code) catch bun.outOfMemory(),
.reason = default_allocator.dupeZ(u8, certError.reason) catch bun.outOfMemory(),
},
};
// we inform the user that the cert is invalid
client.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
client.progressUpdate(is_ssl, getHttpContext(is_ssl), socket);
// continue until we are aborted or not
return true;
} else {
@@ -163,7 +178,7 @@ fn registerAbortTracker(
socket: NewHTTPContext(is_ssl).HTTPSocket,
) void {
if (client.signals.aborted != null) {
socket_async_http_abort_tracker.put(client.async_http_id, socket.socket) catch unreachable;
registerAsyncHTTPAbortTracker(client.async_http_id, socket.socket) catch unreachable;
}
}
@@ -171,7 +186,7 @@ fn unregisterAbortTracker(
client: *HTTPClient,
) void {
if (client.signals.aborted != null) {
_ = socket_async_http_abort_tracker.swapRemove(client.async_http_id);
_ = unregisterAsyncHTTPAbortTracker(client.async_http_id);
}
}
@@ -211,12 +226,12 @@ pub fn onOpen(
temp_hostname[_hostname.len] = 0;
hostname = temp_hostname[0.._hostname.len :0];
} else {
hostname = bun.default_allocator.dupeZ(u8, _hostname) catch unreachable;
hostname = default_allocator.dupeZ(u8, _hostname) catch unreachable;
hostname_needs_free = true;
}
}
defer if (hostname_needs_free) bun.default_allocator.free(hostname);
defer if (hostname_needs_free) default_allocator.free(hostname);
ssl_ptr.configureHTTPClient(hostname);
}
@@ -263,7 +278,7 @@ pub fn onClose(
if (client.state.flags.is_redirect_pending) {
// if the connection is closed and we are pending redirect just do the redirect
// in this case we will re-connect or go to a different socket if needed
client.doRedirect(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
client.doRedirect(is_ssl, getHttpContext(is_ssl), socket);
return;
}
if (in_progress) {
@@ -275,14 +290,14 @@ pub fn onClose(
const buf = client.state.getBodyBuffer();
if (buf.list.items.len > 0) {
client.state.flags.received_last_chunk = true;
client.progressUpdate(comptime is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
client.progressUpdate(comptime is_ssl, getHttpContext(is_ssl), socket);
return;
}
}
} else if (client.state.content_length == null and client.state.response_stage == .body) {
// no content length informed so we are done here
client.state.flags.received_last_chunk = true;
client.progressUpdate(comptime is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
client.progressUpdate(comptime is_ssl, getHttpContext(is_ssl), socket);
return;
}
}
@@ -338,10 +353,6 @@ inline fn getRequestBodySendBuffer(this: *@This()) HTTPThread.RequestBodyBuffer
return http_thread.getRequestBodySendBuffer(estimated_size);
}
pub inline fn cleanup(force: bool) void {
default_arena.gc(force);
}
pub const Headers = JSC.WebCore.Headers;
pub const SOCKET_FLAGS: u32 = if (Environment.isLinux)
@@ -626,7 +637,7 @@ pub const InternalState = struct {
// if exists we own this info
if (this.certificate_info) |info| {
this.certificate_info = null;
info.deinit(bun.default_allocator);
info.deinit(default_allocator);
}
this.original_request_body.deinit();
@@ -767,12 +778,6 @@ pub const InternalState = struct {
const default_redirect_count = 127;
pub const HTTPVerboseLevel = enum {
none,
headers,
curl,
};
pub const Flags = packed struct {
disable_timeout: bool = false,
disable_keepalive: bool = false,
@@ -822,7 +827,7 @@ unix_socket_path: JSC.ZigString.Slice = JSC.ZigString.Slice.empty,
pub fn deinit(this: *HTTPClient) void {
if (this.redirect.len > 0) {
bun.default_allocator.free(this.redirect);
default_allocator.free(this.redirect);
this.redirect = &.{};
}
if (this.proxy_authorization) |auth| {
@@ -891,28 +896,6 @@ pub fn hashHeaderConst(comptime name: string) u64 {
return hasher.final();
}
pub const Encoding = enum {
identity,
gzip,
deflate,
brotli,
chunked,
pub fn canUseLibDeflate(this: Encoding) bool {
return switch (this) {
.gzip, .deflate => true,
else => false,
};
}
pub fn isCompressed(this: Encoding) bool {
return switch (this) {
.brotli, .gzip, .deflate => true,
else => false,
};
}
};
const host_header_name = "Host";
const content_length_header_name = "Content-Length";
const chunked_encoded_header = picohttp.Header{ .name = "Transfer-Encoding", .value = "chunked" };
@@ -1164,12 +1147,6 @@ pub fn start(this: *HTTPClient, body: HTTPRequestBody, body_out_str: *MutableStr
}
fn start_(this: *HTTPClient, comptime is_ssl: bool) void {
if (comptime Environment.allow_assert) {
if (this.allocator.vtable == default_allocator.vtable and this.allocator.ptr != default_allocator.ptr) {
@panic("HTTPClient used with threadlocal allocator belonging to another thread. This will cause crashes.");
}
}
// Aborted before connecting
if (this.signals.get(.aborted)) {
this.fail(error.AbortedBeforeConnecting);
@@ -1190,8 +1167,6 @@ fn start_(this: *HTTPClient, comptime is_ssl: bool) void {
}
}
const Task = ThreadPool.Task;
fn printRequest(request: picohttp.Request, url: string, ignore_insecure: bool, body: []const u8, curl: bool) void {
@branchHint(.cold);
var request_ = request;
@@ -1215,7 +1190,7 @@ fn printResponse(response: picohttp.Response) void {
pub fn onPreconnect(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void {
log("onPreconnect({})", .{this.url});
this.unregisterAbortTracker();
const ctx = if (comptime is_ssl) &http_thread.https_context else &http_thread.http_context;
const ctx = getHttpContext(is_ssl);
ctx.releaseSocket(
socket,
this.flags.did_have_handshaking_error and !this.flags.reject_unauthorized,
@@ -1374,7 +1349,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s
this.state.request_stage = .body;
if (this.flags.is_streaming_request_body) {
// lets signal to start streaming the body
this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
this.progressUpdate(is_ssl, getHttpContext(is_ssl), socket);
}
}
return;
@@ -1387,7 +1362,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s
this.state.request_stage = .body;
if (this.flags.is_streaming_request_body) {
// lets signal to start streaming the body
this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
this.progressUpdate(is_ssl, getHttpContext(is_ssl), socket);
}
}
assert(
@@ -1520,7 +1495,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s
.proxy_headers => {
if (this.proxy_tunnel) |proxy| {
this.setTimeout(socket, 5);
var stack_buffer = std.heap.stackFallback(1024 * 16, bun.default_allocator);
var stack_buffer = std.heap.stackFallback(1024 * 16, default_allocator);
const allocator = stack_buffer.get();
var temporary_send_buffer = std.ArrayList(u8).fromOwnedSlice(allocator, &stack_buffer.buffer);
temporary_send_buffer.items.len = 0;
@@ -1579,7 +1554,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s
this.state.request_stage = .proxy_body;
if (this.flags.is_streaming_request_body) {
// lets signal to start streaming the body
this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket);
this.progressUpdate(is_ssl, getHttpContext(is_ssl), socket);
}
assert(this.state.request_body.len > 0);
@@ -1953,7 +1928,6 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon
if (print_every_i % print_every == 0) {
Output.prettyln("Heap stats for HTTP thread\n", .{});
Output.flush();
default_arena.dumpThreadStats();
print_every_i = 0;
}
}
@@ -2385,7 +2359,7 @@ pub fn handleResponseMetadata(
var is_same_origin = true;
{
var url_arena = std.heap.ArenaAllocator.init(bun.default_allocator);
var url_arena = std.heap.ArenaAllocator.init(default_allocator);
defer url_arena.deinit();
var fba = std.heap.stackFallback(4096, url_arena.allocator());
const url_allocator = fba.get();
@@ -2434,7 +2408,7 @@ pub fn handleResponseMetadata(
// URL__getHref failed, dont pass dead tagged string to toOwnedSlice.
return error.RedirectURLInvalid;
}
const normalized_url_str = try normalized_url.toOwnedSlice(bun.default_allocator);
const normalized_url_str = try normalized_url.toOwnedSlice(default_allocator);
const new_url = URL.parse(normalized_url_str);
is_same_origin = strings.eqlCaseInsensitiveASCII(strings.withoutTrailingSlash(new_url.origin), strings.withoutTrailingSlash(this.url.origin), true);
@@ -2474,7 +2448,7 @@ pub fn handleResponseMetadata(
const normalized_url = JSC.URL.hrefFromString(bun.String.fromBytes(string_builder.allocatedSlice()));
defer normalized_url.deref();
const normalized_url_str = try normalized_url.toOwnedSlice(bun.default_allocator);
const normalized_url_str = try normalized_url.toOwnedSlice(default_allocator);
const new_url = URL.parse(normalized_url_str);
is_same_origin = strings.eqlCaseInsensitiveASCII(strings.withoutTrailingSlash(new_url.origin), strings.withoutTrailingSlash(this.url.origin), true);
@@ -2493,7 +2467,7 @@ pub fn handleResponseMetadata(
return error.InvalidRedirectURL;
}
const new_url = new_url_.toOwnedSlice(bun.default_allocator) catch {
const new_url = new_url_.toOwnedSlice(default_allocator) catch {
return error.RedirectURLTooLong;
};
this.url = URL.parse(new_url);

View File

@@ -1,7 +1,7 @@
const bun = @import("root").bun;
const std = @import("std");
const string = bun.string;
const SSLConfig = bun.server.ServerConfig.SSLConfig;
const picohttp = bun.picohttp;
const JSC = bun.JSC;
const MutableString = bun.MutableString;
@@ -14,13 +14,64 @@ const HTTPRequestBody = @import("./request_body.zig").HTTPRequestBody;
const Method = @import("../method.zig").Method;
const URL = bun.URL;
const ThreadPool = bun.ThreadPool;
const uws = bun.uws;
const PercentEncoding = @import("../../url.zig").PercentEncoding;
const http_thread = @import("./thread.zig").getHttpThread();
const Signals = @import("./signals.zig").Signals;
var async_http_id_monotonic: std.atomic.Value(u32) = std.atomic.Value(u32).init(0);
var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.InternalSocket).init(bun.default_allocator);
pub const HTTPVerboseLevel = enum {
none,
headers,
curl,
};
pub fn registerAbortTracker(
async_http_id: u32,
socket: uws.InternalSocket,
) void {
socket_async_http_abort_tracker.put(async_http_id, socket) catch unreachable;
}
pub fn unregisterAbortTracker(
async_http_id: u32,
) void {
_ = socket_async_http_abort_tracker.swapRemove(async_http_id);
}
const HTTPClientResult = @import("./result.zig").HTTPClientResult;
pub fn getSocketAsyncHTTPAbortTracker() *std.AutoArrayHashMap(u32, uws.InternalSocket) {
return &socket_async_http_abort_tracker;
}
// Exists for heap stats reasons.
pub const ThreadlocalAsyncHTTP = struct {
async_http: AsyncHTTP,
pub usingnamespace bun.New(@This());
};
pub const Encoding = enum {
identity,
gzip,
deflate,
brotli,
chunked,
pub fn canUseLibDeflate(this: Encoding) bool {
return switch (this) {
.gzip, .deflate => true,
else => false,
};
}
pub fn isCompressed(this: Encoding) bool {
return switch (this) {
.brotli, .gzip, .deflate => true,
else => false,
};
}
};
pub const AsyncHTTP = struct {
request: ?picohttp.Request = null,
response: ?picohttp.Response = null,

View File

@@ -1,5 +1,7 @@
const bun = @import("root").bun;
const SSLWrapper = @import("../../bun.js/api/bun/ssl_wrapper.zig").SSLWrapper;
const getHttpContext = @import("./thread.zig").getContext;
const http_thread = @import("./http/client/thread.zig").getHttpThread();
const ProxyTunnel = struct {
wrapper: ?ProxyTunnelWrapper = null,
@@ -64,10 +66,10 @@ const ProxyTunnel = struct {
if (report_progress) {
switch (proxy.socket) {
.ssl => |socket| {
this.progressUpdate(true, &http_thread.https_context, socket);
this.progressUpdate(true, getHttpContext(true), socket);
},
.tcp => |socket| {
this.progressUpdate(false, &http_thread.http_context, socket);
this.progressUpdate(false, getHttpContext(false), socket);
},
.none => {},
}
@@ -84,10 +86,10 @@ const ProxyTunnel = struct {
if (report_progress) {
switch (proxy.socket) {
.ssl => |socket| {
this.progressUpdate(true, &http_thread.https_context, socket);
this.progressUpdate(true, getHttpContext(true), socket);
},
.tcp => |socket| {
this.progressUpdate(false, &http_thread.http_context, socket);
this.progressUpdate(false, getHttpContext(false), socket);
},
.none => {},
}
@@ -97,10 +99,10 @@ const ProxyTunnel = struct {
.proxy_headers => {
switch (proxy.socket) {
.ssl => |socket| {
this.handleOnDataHeaders(true, decoded_data, &http_thread.https_context, socket);
this.handleOnDataHeaders(true, decoded_data, getHttpContext(true), socket);
},
.tcp => |socket| {
this.handleOnDataHeaders(false, decoded_data, &http_thread.http_context, socket);
this.handleOnDataHeaders(false, decoded_data, getHttpContext(false), socket);
},
.none => {},
}

View File

@@ -1,5 +1,7 @@
const bun = @import("root").bun;
const std = @import("std");
const Global = bun.Global;
const picohttp = bun.picohttp;
const BoringSSL = bun.BoringSSL;
const JSC = bun.JSC;
@@ -19,6 +21,7 @@ const assert = bun.assert;
const strings = bun.strings;
const Batch = bun.ThreadPool.Batch;
const HTTPAllocator = @import("./http_allocator.zig");
pub var http_thread: HTTPThread = undefined;
var custom_ssl_context_map = std.AutoArrayHashMap(*SSLConfig, *NewHTTPContext(true)).init(bun.default_allocator);
@@ -27,6 +30,18 @@ const HTTPCertError = @import("./errors.zig").HTTPCertError;
const Queue = @import("./async_http.zig").Queue;
const ThreadlocalAsyncHTTP = @import("./async_http.zig").ThreadlocalAsyncHTTP;
const ProxyTunnel = @import("./proxy_tunnel.zig").ProxyTunnel;
const AsyncHTTP = @import("./async_http.zig").AsyncHTTP;
const socket_async_http_abort_tracker = AsyncHTTP.getSocketAsyncHTTPAbortTracker();
const log = Output.scoped(.fetch, false);
const HTTPClient = @import("../../http.zig").HTTPClient;
pub const end_of_chunked_http1_1_encoding_response_body = "0\r\n\r\n";
pub fn getContext(comptime ssl: bool) *NewHTTPContext(ssl) {
return if (ssl) &http_thread.https_context else &http_thread.http_context;
}
pub fn getHttpThread() *HTTPThread {
return &http_thread;
}
pub fn NewHTTPContext(comptime ssl: bool) type {
return struct {
@@ -74,11 +89,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type {
pub const HTTPSocket = uws.NewSocketHandler(ssl);
pub fn context() *@This() {
if (comptime ssl) {
return &http_thread.https_context;
} else {
return &http_thread.http_context;
}
return getContext(ssl);
}
const ActiveSocket = TaggedPointerUnion(.{
@@ -351,7 +362,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type {
return client.onData(
comptime ssl,
buf,
if (comptime ssl) &http_thread.https_context else &http_thread.http_context,
getContext(ssl),
socket,
);
} else if (tagged.is(PooledSocket)) {
@@ -675,8 +686,7 @@ pub const HTTPThread = struct {
pub fn onStart(opts: InitOpts) void {
Output.Source.configureNamedThread("HTTP Client");
default_arena = Arena.init() catch unreachable;
default_allocator = default_arena.allocator();
HTTPAllocator.init();
const loop = bun.JSC.MiniEventLoop.initGlobal(null);