diff --git a/src/http.zig b/src/http.zig index 8ba538cbda..c5a1600eae 100644 --- a/src/http.zig +++ b/src/http.zig @@ -1,6 +1,11 @@ +const HTTPClient = @This(); + const bun = @import("root").bun; +const uws = bun.uws; const picohttp = bun.picohttp; const JSC = bun.JSC; +const URL = bun.URL; +const BoringSSL = bun.BoringSSL; const string = bun.string; const Output = bun.Output; const Global = bun.Global; @@ -8,45 +13,57 @@ const Environment = bun.Environment; const strings = bun.strings; const MutableString = bun.MutableString; const FeatureFlags = bun.FeatureFlags; -const stringZ = bun.stringZ; -const C = bun.C; -const Loc = bun.logger.Loc; -const Log = bun.logger.Log; -const DotEnv = @import("./env_loader.zig"); +// const stringZ = bun.stringZ; +// const C = bun.C; +// const Loc = bun.logger.Loc; +// const Log = bun.logger.Log; +// const DotEnv = @import("./env_loader.zig"); const std = @import("std"); -const URL = @import("./url.zig").URL; -const PercentEncoding = @import("./url.zig").PercentEncoding; +const posix = std.posix; +const SOCK = posix.SOCK; +pub const MimeType = @import("./http/mime_type.zig"); + +// const URL = @import("./url.zig").URL; pub const Method = @import("./http/method.zig").Method; -const Api = @import("./api/schema.zig").Api; -const Lock = bun.Mutex; -const HTTPClient = @This(); +// const Api = @import("./api/schema.zig").Api; +// const Lock = bun.Mutex; + const Zlib = @import("./zlib.zig"); const Brotli = bun.brotli; const StringBuilder = bun.StringBuilder; -const ThreadPool = bun.ThreadPool; const ObjectPool = @import("./pool.zig").ObjectPool; -const posix = std.posix; -const SOCK = posix.SOCK; -const Arena = @import("./allocators/mimalloc_arena.zig").Arena; -const ZlibPool = @import("./http/zlib.zig"); -const BoringSSL = bun.BoringSSL.c; -const Progress = bun.Progress; -const X509 = @import("./bun.js/api/bun/x509.zig"); -const SSLConfig = @import("./bun.js/api/server.zig").ServerConfig.SSLConfig; -const SSLWrapper = @import("./bun.js/api/bun/ssl_wrapper.zig").SSLWrapper; +const default_allocator = bun.default_allocator; +pub const AsyncHTTP = @import("./http/client/async_http.zig").AsyncHTTP; +const registerAsyncHTTPAbortTracker = @import("./http/client/async_http.zig").registerAbortTracker; +const unregisterAsyncHTTPAbortTracker = @import("./http/client/async_http.zig").unregisterAbortTracker; +const HTTPThread = @import("./http/client/thread.zig").HTTPThread; +const getHttpContext = @import("./http/client/thread.zig").getContext; +const Encoding = @import("./http/client/async_http.zig").Encoding; +const HTTPCertError = @import("./http/client/errors.zig").HTTPCertError; +const HTTPRequestBody = @import("./http/client/request_body.zig").HTTPRequestBody; +const CertificateInfo = @import("./http/client/certificate_info.zig").CertificateInfo; +const HTTPVerboseLevel = @import("./http/client/async_http.zig").HTTPVerboseLevel; +const HTTPClientResult = @import("./http/client/result.zig").HTTPClientResult; +const ProxyTunnel = @import("./http/client/proxy_tunnel.zig").ProxyTunnel; +const Signals = @import("./http/client/signals.zig").Signals; +// const Arena = @import("./allocators/mimalloc_arena.zig").Arena; +// const ZlibPool = @import("./http/zlib.zig"); +// const BoringSSL = bun.BoringSSL.c; +const Progress = bun.Progress; +// const X509 = @import("./bun.js/api/bun/x509.zig"); +const SSLConfig = bun.server.ServerConfig.SSLConfig; +// const SSLWrapper = @import("./bun.js/api/bun/ssl_wrapper.zig").SSLWrapper; +const NewHTTPContext = @import("./http/client/thread.zig").NewHTTPContext; +const http_thread = @import("./http/client/thread.zig").getHttpThread(); const URLBufferPool = ObjectPool([8192]u8, null, false, 10); -const uws = bun.uws; -pub const MimeType = @import("./http/mime_type.zig"); -pub const URLPath = @import("./http/url_path.zig"); + +pub const HTTPResponseMetadata = @import("./http/client/result.zig").HTTPResponseMetadata; // This becomes Arena.allocator -pub var default_allocator: std.mem.Allocator = undefined; -var default_arena: Arena = undefined; const TaggedPointerUnion = @import("./tagged_pointer.zig").TaggedPointerUnion; -const DeadSocket = opaque {}; -var dead_socket = @as(*DeadSocket, @ptrFromInt(1)); +pub const end_of_chunked_http1_1_encoding_response_body = @import("./http/client/async_http.zig").end_of_chunked_http1_1_encoding_response_body; + //TODO: this needs to be freed when Worker Threads are implemented -var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.InternalSocket).init(bun.default_allocator); var async_http_id_monotonic: std.atomic.Value(u32) = std.atomic.Value(u32).init(0); const MAX_REDIRECT_URL_LENGTH = 128 * 1024; @@ -71,8 +88,6 @@ var shared_response_headers_buf: [256]picohttp.Header = undefined; // never finishing sending the body const preallocate_max = 1024 * 1024 * 256; -pub const end_of_chunked_http1_1_encoding_response_body = "0\r\n\r\n"; - pub const FetchRedirect = enum(u8) { follow, manual, @@ -106,7 +121,7 @@ pub fn checkServerIdentity( if (client.signals.get(.cert_errors)) { // clone the relevant data const cert_size = BoringSSL.i2d_X509(x509, null); - const cert = bun.default_allocator.alloc(u8, @intCast(cert_size)) catch bun.outOfMemory(); + const cert = default_allocator.alloc(u8, @intCast(cert_size)) catch bun.outOfMemory(); var cert_ptr = cert.ptr; const result_size = BoringSSL.i2d_X509(x509, &cert_ptr); assert(result_size == cert_size); @@ -120,16 +135,16 @@ pub fn checkServerIdentity( client.state.certificate_info = .{ .cert = cert, - .hostname = bun.default_allocator.dupe(u8, hostname) catch bun.outOfMemory(), + .hostname = default_allocator.dupe(u8, hostname) catch bun.outOfMemory(), .cert_error = .{ .error_no = certError.error_no, - .code = bun.default_allocator.dupeZ(u8, certError.code) catch bun.outOfMemory(), - .reason = bun.default_allocator.dupeZ(u8, certError.reason) catch bun.outOfMemory(), + .code = default_allocator.dupeZ(u8, certError.code) catch bun.outOfMemory(), + .reason = default_allocator.dupeZ(u8, certError.reason) catch bun.outOfMemory(), }, }; // we inform the user that the cert is invalid - client.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + client.progressUpdate(is_ssl, getHttpContext(is_ssl), socket); // continue until we are aborted or not return true; } else { @@ -163,7 +178,7 @@ fn registerAbortTracker( socket: NewHTTPContext(is_ssl).HTTPSocket, ) void { if (client.signals.aborted != null) { - socket_async_http_abort_tracker.put(client.async_http_id, socket.socket) catch unreachable; + registerAsyncHTTPAbortTracker(client.async_http_id, socket.socket) catch unreachable; } } @@ -171,7 +186,7 @@ fn unregisterAbortTracker( client: *HTTPClient, ) void { if (client.signals.aborted != null) { - _ = socket_async_http_abort_tracker.swapRemove(client.async_http_id); + _ = unregisterAsyncHTTPAbortTracker(client.async_http_id); } } @@ -211,12 +226,12 @@ pub fn onOpen( temp_hostname[_hostname.len] = 0; hostname = temp_hostname[0.._hostname.len :0]; } else { - hostname = bun.default_allocator.dupeZ(u8, _hostname) catch unreachable; + hostname = default_allocator.dupeZ(u8, _hostname) catch unreachable; hostname_needs_free = true; } } - defer if (hostname_needs_free) bun.default_allocator.free(hostname); + defer if (hostname_needs_free) default_allocator.free(hostname); ssl_ptr.configureHTTPClient(hostname); } @@ -263,7 +278,7 @@ pub fn onClose( if (client.state.flags.is_redirect_pending) { // if the connection is closed and we are pending redirect just do the redirect // in this case we will re-connect or go to a different socket if needed - client.doRedirect(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + client.doRedirect(is_ssl, getHttpContext(is_ssl), socket); return; } if (in_progress) { @@ -275,14 +290,14 @@ pub fn onClose( const buf = client.state.getBodyBuffer(); if (buf.list.items.len > 0) { client.state.flags.received_last_chunk = true; - client.progressUpdate(comptime is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + client.progressUpdate(comptime is_ssl, getHttpContext(is_ssl), socket); return; } } } else if (client.state.content_length == null and client.state.response_stage == .body) { // no content length informed so we are done here client.state.flags.received_last_chunk = true; - client.progressUpdate(comptime is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + client.progressUpdate(comptime is_ssl, getHttpContext(is_ssl), socket); return; } } @@ -338,10 +353,6 @@ inline fn getRequestBodySendBuffer(this: *@This()) HTTPThread.RequestBodyBuffer return http_thread.getRequestBodySendBuffer(estimated_size); } -pub inline fn cleanup(force: bool) void { - default_arena.gc(force); -} - pub const Headers = JSC.WebCore.Headers; pub const SOCKET_FLAGS: u32 = if (Environment.isLinux) @@ -626,7 +637,7 @@ pub const InternalState = struct { // if exists we own this info if (this.certificate_info) |info| { this.certificate_info = null; - info.deinit(bun.default_allocator); + info.deinit(default_allocator); } this.original_request_body.deinit(); @@ -767,12 +778,6 @@ pub const InternalState = struct { const default_redirect_count = 127; -pub const HTTPVerboseLevel = enum { - none, - headers, - curl, -}; - pub const Flags = packed struct { disable_timeout: bool = false, disable_keepalive: bool = false, @@ -822,7 +827,7 @@ unix_socket_path: JSC.ZigString.Slice = JSC.ZigString.Slice.empty, pub fn deinit(this: *HTTPClient) void { if (this.redirect.len > 0) { - bun.default_allocator.free(this.redirect); + default_allocator.free(this.redirect); this.redirect = &.{}; } if (this.proxy_authorization) |auth| { @@ -891,28 +896,6 @@ pub fn hashHeaderConst(comptime name: string) u64 { return hasher.final(); } -pub const Encoding = enum { - identity, - gzip, - deflate, - brotli, - chunked, - - pub fn canUseLibDeflate(this: Encoding) bool { - return switch (this) { - .gzip, .deflate => true, - else => false, - }; - } - - pub fn isCompressed(this: Encoding) bool { - return switch (this) { - .brotli, .gzip, .deflate => true, - else => false, - }; - } -}; - const host_header_name = "Host"; const content_length_header_name = "Content-Length"; const chunked_encoded_header = picohttp.Header{ .name = "Transfer-Encoding", .value = "chunked" }; @@ -1164,12 +1147,6 @@ pub fn start(this: *HTTPClient, body: HTTPRequestBody, body_out_str: *MutableStr } fn start_(this: *HTTPClient, comptime is_ssl: bool) void { - if (comptime Environment.allow_assert) { - if (this.allocator.vtable == default_allocator.vtable and this.allocator.ptr != default_allocator.ptr) { - @panic("HTTPClient used with threadlocal allocator belonging to another thread. This will cause crashes."); - } - } - // Aborted before connecting if (this.signals.get(.aborted)) { this.fail(error.AbortedBeforeConnecting); @@ -1190,8 +1167,6 @@ fn start_(this: *HTTPClient, comptime is_ssl: bool) void { } } -const Task = ThreadPool.Task; - fn printRequest(request: picohttp.Request, url: string, ignore_insecure: bool, body: []const u8, curl: bool) void { @branchHint(.cold); var request_ = request; @@ -1215,7 +1190,7 @@ fn printResponse(response: picohttp.Response) void { pub fn onPreconnect(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void { log("onPreconnect({})", .{this.url}); this.unregisterAbortTracker(); - const ctx = if (comptime is_ssl) &http_thread.https_context else &http_thread.http_context; + const ctx = getHttpContext(is_ssl); ctx.releaseSocket( socket, this.flags.did_have_handshaking_error and !this.flags.reject_unauthorized, @@ -1374,7 +1349,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s this.state.request_stage = .body; if (this.flags.is_streaming_request_body) { // lets signal to start streaming the body - this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + this.progressUpdate(is_ssl, getHttpContext(is_ssl), socket); } } return; @@ -1387,7 +1362,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s this.state.request_stage = .body; if (this.flags.is_streaming_request_body) { // lets signal to start streaming the body - this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + this.progressUpdate(is_ssl, getHttpContext(is_ssl), socket); } } assert( @@ -1520,7 +1495,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s .proxy_headers => { if (this.proxy_tunnel) |proxy| { this.setTimeout(socket, 5); - var stack_buffer = std.heap.stackFallback(1024 * 16, bun.default_allocator); + var stack_buffer = std.heap.stackFallback(1024 * 16, default_allocator); const allocator = stack_buffer.get(); var temporary_send_buffer = std.ArrayList(u8).fromOwnedSlice(allocator, &stack_buffer.buffer); temporary_send_buffer.items.len = 0; @@ -1579,7 +1554,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s this.state.request_stage = .proxy_body; if (this.flags.is_streaming_request_body) { // lets signal to start streaming the body - this.progressUpdate(is_ssl, if (is_ssl) &http_thread.https_context else &http_thread.http_context, socket); + this.progressUpdate(is_ssl, getHttpContext(is_ssl), socket); } assert(this.state.request_body.len > 0); @@ -1953,7 +1928,6 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon if (print_every_i % print_every == 0) { Output.prettyln("Heap stats for HTTP thread\n", .{}); Output.flush(); - default_arena.dumpThreadStats(); print_every_i = 0; } } @@ -2385,7 +2359,7 @@ pub fn handleResponseMetadata( var is_same_origin = true; { - var url_arena = std.heap.ArenaAllocator.init(bun.default_allocator); + var url_arena = std.heap.ArenaAllocator.init(default_allocator); defer url_arena.deinit(); var fba = std.heap.stackFallback(4096, url_arena.allocator()); const url_allocator = fba.get(); @@ -2434,7 +2408,7 @@ pub fn handleResponseMetadata( // URL__getHref failed, dont pass dead tagged string to toOwnedSlice. return error.RedirectURLInvalid; } - const normalized_url_str = try normalized_url.toOwnedSlice(bun.default_allocator); + const normalized_url_str = try normalized_url.toOwnedSlice(default_allocator); const new_url = URL.parse(normalized_url_str); is_same_origin = strings.eqlCaseInsensitiveASCII(strings.withoutTrailingSlash(new_url.origin), strings.withoutTrailingSlash(this.url.origin), true); @@ -2474,7 +2448,7 @@ pub fn handleResponseMetadata( const normalized_url = JSC.URL.hrefFromString(bun.String.fromBytes(string_builder.allocatedSlice())); defer normalized_url.deref(); - const normalized_url_str = try normalized_url.toOwnedSlice(bun.default_allocator); + const normalized_url_str = try normalized_url.toOwnedSlice(default_allocator); const new_url = URL.parse(normalized_url_str); is_same_origin = strings.eqlCaseInsensitiveASCII(strings.withoutTrailingSlash(new_url.origin), strings.withoutTrailingSlash(this.url.origin), true); @@ -2493,7 +2467,7 @@ pub fn handleResponseMetadata( return error.InvalidRedirectURL; } - const new_url = new_url_.toOwnedSlice(bun.default_allocator) catch { + const new_url = new_url_.toOwnedSlice(default_allocator) catch { return error.RedirectURLTooLong; }; this.url = URL.parse(new_url); diff --git a/src/http/client/async_http.zig b/src/http/client/async_http.zig index 721ee195a2..04fd33527f 100644 --- a/src/http/client/async_http.zig +++ b/src/http/client/async_http.zig @@ -1,7 +1,7 @@ const bun = @import("root").bun; const std = @import("std"); const string = bun.string; - +const SSLConfig = bun.server.ServerConfig.SSLConfig; const picohttp = bun.picohttp; const JSC = bun.JSC; const MutableString = bun.MutableString; @@ -14,13 +14,64 @@ const HTTPRequestBody = @import("./request_body.zig").HTTPRequestBody; const Method = @import("../method.zig").Method; const URL = bun.URL; const ThreadPool = bun.ThreadPool; +const uws = bun.uws; +const PercentEncoding = @import("../../url.zig").PercentEncoding; +const http_thread = @import("./thread.zig").getHttpThread(); +const Signals = @import("./signals.zig").Signals; +var async_http_id_monotonic: std.atomic.Value(u32) = std.atomic.Value(u32).init(0); +var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.InternalSocket).init(bun.default_allocator); +pub const HTTPVerboseLevel = enum { + none, + headers, + curl, +}; + +pub fn registerAbortTracker( + async_http_id: u32, + socket: uws.InternalSocket, +) void { + socket_async_http_abort_tracker.put(async_http_id, socket) catch unreachable; +} + +pub fn unregisterAbortTracker( + async_http_id: u32, +) void { + _ = socket_async_http_abort_tracker.swapRemove(async_http_id); +} + +const HTTPClientResult = @import("./result.zig").HTTPClientResult; +pub fn getSocketAsyncHTTPAbortTracker() *std.AutoArrayHashMap(u32, uws.InternalSocket) { + return &socket_async_http_abort_tracker; +} // Exists for heap stats reasons. pub const ThreadlocalAsyncHTTP = struct { async_http: AsyncHTTP, pub usingnamespace bun.New(@This()); }; +pub const Encoding = enum { + identity, + gzip, + deflate, + brotli, + chunked, + + pub fn canUseLibDeflate(this: Encoding) bool { + return switch (this) { + .gzip, .deflate => true, + else => false, + }; + } + + pub fn isCompressed(this: Encoding) bool { + return switch (this) { + .brotli, .gzip, .deflate => true, + else => false, + }; + } +}; + pub const AsyncHTTP = struct { request: ?picohttp.Request = null, response: ?picohttp.Response = null, diff --git a/src/http/client/proxy_tunnel.zig b/src/http/client/proxy_tunnel.zig index 84acc8644d..b9dab958d8 100644 --- a/src/http/client/proxy_tunnel.zig +++ b/src/http/client/proxy_tunnel.zig @@ -1,5 +1,7 @@ const bun = @import("root").bun; const SSLWrapper = @import("../../bun.js/api/bun/ssl_wrapper.zig").SSLWrapper; +const getHttpContext = @import("./thread.zig").getContext; +const http_thread = @import("./http/client/thread.zig").getHttpThread(); const ProxyTunnel = struct { wrapper: ?ProxyTunnelWrapper = null, @@ -64,10 +66,10 @@ const ProxyTunnel = struct { if (report_progress) { switch (proxy.socket) { .ssl => |socket| { - this.progressUpdate(true, &http_thread.https_context, socket); + this.progressUpdate(true, getHttpContext(true), socket); }, .tcp => |socket| { - this.progressUpdate(false, &http_thread.http_context, socket); + this.progressUpdate(false, getHttpContext(false), socket); }, .none => {}, } @@ -84,10 +86,10 @@ const ProxyTunnel = struct { if (report_progress) { switch (proxy.socket) { .ssl => |socket| { - this.progressUpdate(true, &http_thread.https_context, socket); + this.progressUpdate(true, getHttpContext(true), socket); }, .tcp => |socket| { - this.progressUpdate(false, &http_thread.http_context, socket); + this.progressUpdate(false, getHttpContext(false), socket); }, .none => {}, } @@ -97,10 +99,10 @@ const ProxyTunnel = struct { .proxy_headers => { switch (proxy.socket) { .ssl => |socket| { - this.handleOnDataHeaders(true, decoded_data, &http_thread.https_context, socket); + this.handleOnDataHeaders(true, decoded_data, getHttpContext(true), socket); }, .tcp => |socket| { - this.handleOnDataHeaders(false, decoded_data, &http_thread.http_context, socket); + this.handleOnDataHeaders(false, decoded_data, getHttpContext(false), socket); }, .none => {}, } diff --git a/src/http/client/thread.zig b/src/http/client/thread.zig index 85681f6527..e25a05011f 100644 --- a/src/http/client/thread.zig +++ b/src/http/client/thread.zig @@ -1,5 +1,7 @@ const bun = @import("root").bun; const std = @import("std"); + +const Global = bun.Global; const picohttp = bun.picohttp; const BoringSSL = bun.BoringSSL; const JSC = bun.JSC; @@ -19,6 +21,7 @@ const assert = bun.assert; const strings = bun.strings; const Batch = bun.ThreadPool.Batch; +const HTTPAllocator = @import("./http_allocator.zig"); pub var http_thread: HTTPThread = undefined; var custom_ssl_context_map = std.AutoArrayHashMap(*SSLConfig, *NewHTTPContext(true)).init(bun.default_allocator); @@ -27,6 +30,18 @@ const HTTPCertError = @import("./errors.zig").HTTPCertError; const Queue = @import("./async_http.zig").Queue; const ThreadlocalAsyncHTTP = @import("./async_http.zig").ThreadlocalAsyncHTTP; const ProxyTunnel = @import("./proxy_tunnel.zig").ProxyTunnel; +const AsyncHTTP = @import("./async_http.zig").AsyncHTTP; +const socket_async_http_abort_tracker = AsyncHTTP.getSocketAsyncHTTPAbortTracker(); +const log = Output.scoped(.fetch, false); +const HTTPClient = @import("../../http.zig").HTTPClient; +pub const end_of_chunked_http1_1_encoding_response_body = "0\r\n\r\n"; + +pub fn getContext(comptime ssl: bool) *NewHTTPContext(ssl) { + return if (ssl) &http_thread.https_context else &http_thread.http_context; +} +pub fn getHttpThread() *HTTPThread { + return &http_thread; +} pub fn NewHTTPContext(comptime ssl: bool) type { return struct { @@ -74,11 +89,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type { pub const HTTPSocket = uws.NewSocketHandler(ssl); pub fn context() *@This() { - if (comptime ssl) { - return &http_thread.https_context; - } else { - return &http_thread.http_context; - } + return getContext(ssl); } const ActiveSocket = TaggedPointerUnion(.{ @@ -351,7 +362,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type { return client.onData( comptime ssl, buf, - if (comptime ssl) &http_thread.https_context else &http_thread.http_context, + getContext(ssl), socket, ); } else if (tagged.is(PooledSocket)) { @@ -675,8 +686,7 @@ pub const HTTPThread = struct { pub fn onStart(opts: InitOpts) void { Output.Source.configureNamedThread("HTTP Client"); - default_arena = Arena.init() catch unreachable; - default_allocator = default_arena.allocator(); + HTTPAllocator.init(); const loop = bun.JSC.MiniEventLoop.initGlobal(null);