Implement keep-alive but disable it

This commit is contained in:
Jarred Sumner
2022-02-05 00:29:41 -08:00
parent 860d7e93c0
commit 2b45c8dffe
2 changed files with 126 additions and 17 deletions

View File

@@ -67,7 +67,7 @@ else
pub const OPEN_SOCKET_FLAGS = SOCK.CLOEXEC;
pub const extremely_verbose = false;
pub const extremely_verbose = Environment.isDebug;
fn writeRequest(
comptime Writer: type,
@@ -113,6 +113,7 @@ socket: AsyncSocket.SSL = undefined,
socket_loaded: bool = false,
gzip_elapsed: u64 = 0,
stage: Stage = Stage.pending,
received_keep_alive: bool = false,
/// Some HTTP servers (such as npm) report Last-Modified times but ignore If-Modified-Since.
/// This is a workaround for that.
@@ -135,7 +136,9 @@ pub fn init(
.url = url,
.header_entries = header_entries,
.header_buf = header_buf,
.socket = undefined,
.socket = AsyncSocket.SSL{
.socket = undefined,
},
};
}
@@ -246,6 +249,62 @@ pub const HTTPChannelContext = struct {
}
};
// This causes segfaults when resume connect()
pub const KeepAlive = struct {
const limit = 2;
pub const disabled = true;
fds: [limit]u32 = undefined,
hosts: [limit]u64 = undefined,
ports: [limit]u16 = undefined,
used: u8 = 0,
pub var instance = KeepAlive{};
pub fn append(this: *KeepAlive, host: []const u8, port: u16, fd: os.socket_t) bool {
if (disabled) return false;
if (this.used >= limit or fd > std.math.maxInt(u32)) return false;
const i = this.used;
const hash = std.hash.Wyhash.hash(0, host);
this.fds[i] = @truncate(u32, @intCast(u64, fd));
this.hosts[i] = hash;
this.ports[i] = port;
this.used += 1;
return true;
}
pub fn find(this: *KeepAlive, host: []const u8, port: u16) ?os.socket_t {
if (disabled) return null;
if (this.used == 0) {
return null;
}
const hash = std.hash.Wyhash.hash(0, host);
const list = this.hosts[0..this.used];
for (list) |host_hash, i| {
if (host_hash == hash and this.ports[i] == port) {
const fd = this.fds[i];
const last = this.used - 1;
if (i > last) {
const end_host = this.hosts[last];
const end_fd = this.fds[last];
const end_port = this.ports[last];
this.hosts[i] = end_host;
this.fds[i] = end_fd;
this.ports[i] = end_port;
}
this.used -= 1;
return @intCast(os.socket_t, fd);
}
}
return null;
}
};
pub const AsyncHTTP = struct {
request: ?picohttp.Request = null,
response: ?picohttp.Response = null,
@@ -319,6 +378,13 @@ pub const AsyncHTTP = struct {
return this;
}
fn reset(this: *AsyncHTTP) !void {
const timeout = this.timeout;
this.client = try HTTPClient.init(this.allocator, this.method, this.client.url, this.client.header_entries, this.client.header_buf);
this.client.timeout = timeout;
this.timeout = timeout;
}
pub fn schedule(this: *AsyncHTTP, _: std.mem.Allocator, batch: *ThreadPool.Batch) void {
std.debug.assert(NetworkThread.global_loaded.load(.Monotonic) == 1);
this.state.store(.scheduled, .Monotonic);
@@ -381,6 +447,10 @@ pub const AsyncHTTP = struct {
};
pub fn do(sender: *HTTPSender, this: *AsyncHTTP) void {
defer {
NetworkThread.global.pool.schedule(.{ .head = &sender.finisher, .tail = &sender.finisher, .len = 1 });
}
outer: {
this.err = null;
this.state.store(.sending, .Monotonic);
@@ -394,6 +464,7 @@ pub const AsyncHTTP = struct {
if (this.max_retry_count > this.retries_count) {
this.retries_count += 1;
this.response_buffer.reset();
NetworkThread.global.pool.schedule(ThreadPool.Batch.from(&this.task));
return;
}
@@ -408,7 +479,6 @@ pub const AsyncHTTP = struct {
if (this.callback) |callback| {
callback(this);
}
NetworkThread.global.pool.schedule(.{ .head = &sender.finisher, .tail = &sender.finisher, .len = 1 });
}
};
@@ -534,20 +604,20 @@ pub fn sendAsync(this: *HTTPClient, body: []const u8, body_out_str: *MutableStri
return async this.send(body, body_out_str);
}
pub fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response {
defer {
if (this.socket_loaded) {
this.socket_loaded = false;
this.socket.deinit();
}
fn maybeClearSocket(this: *HTTPClient) void {
if (this.socket_loaded) {
this.socket_loaded = false;
this.socket.deinit();
}
}
pub fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response {
defer this.maybeClearSocket();
// this prevents stack overflow
redirect: while (this.remaining_redirect_count >= -1) {
if (this.socket_loaded) {
this.socket_loaded = false;
this.socket.deinit();
}
this.maybeClearSocket();
_ = AsyncHTTP.active_requests_count.fetchAdd(1, .Monotonic);
defer {
@@ -596,7 +666,7 @@ pub fn sendHTTP(this: *HTTPClient, body: []const u8, body_out_str: *MutableStrin
var socket = &this.socket.socket;
try this.connect(*AsyncSocket, socket);
this.stage = Stage.request;
defer this.socket.close();
defer this.closeSocket();
var request = buildRequest(this, body.len);
if (this.verbose) {
@@ -673,6 +743,10 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
var location: string = "";
var pretend_its_304 = false;
var maybe_keepalive = false;
errdefer {
maybe_keepalive = false;
}
for (response.headers) |header| {
switch (hashHeaderName(header.name)) {
@@ -707,6 +781,13 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
location_header_hash => {
location = header.value;
},
hashHeaderName("Connection") => {
if (response.status_code >= 200 and response.status_code <= 299 and !KeepAlive.disabled) {
if (strings.eqlComptime(header.value, "keep-alive")) {
maybe_keepalive = true;
}
}
},
hashHeaderName("Last-Modified") => {
if (this.force_last_modified and response.status_code > 199 and response.status_code < 300 and this.if_modified_since.len > 0) {
if (strings.eql(this.if_modified_since, header.value)) {
@@ -774,6 +855,7 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
if (response.status_code == 304) break :body_getter;
if (transfer_encoding == Encoding.chunked) {
maybe_keepalive = false;
var decoder = std.mem.zeroes(picohttp.phr_chunked_decoder);
var buffer_: *MutableString = body_out_str;
@@ -1020,9 +1102,30 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
this.progress_node.?.context.maybeRefresh();
}
if (maybe_keepalive and response.status_code >= 200 and response.status_code < 300) {
this.received_keep_alive = true;
}
return response;
}
pub fn closeSocket(this: *HTTPClient) void {
if (this.received_keep_alive) {
this.received_keep_alive = false;
if (this.url.hostname.len > 0 and this.socket.socket.socket > 0) {
if (!this.socket.connect_frame.wait and
(!this.socket.ssl_bio_loaded or
(this.socket.ssl_bio.pending_sends == 0 and this.socket.ssl_bio.pending_reads == 0)))
{
if (KeepAlive.instance.append(this.url.hostname, this.url.getPortAuto(), this.socket.socket.socket)) {
this.socket.socket.socket = 0;
}
}
}
}
this.socket.close();
}
pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *MutableString) !picohttp.Response {
this.socket = try AsyncSocket.SSL.init(default_allocator, &AsyncIO.global);
this.socket_loaded = true;
@@ -1031,7 +1134,7 @@ pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *Mutable
this.stage = Stage.connect;
try this.connect(*AsyncSocket.SSL, socket);
this.stage = Stage.request;
defer this.socket.close();
defer this.closeSocket();
var request = buildRequest(this, body_str.len);
if (this.verbose) {

View File

@@ -269,8 +269,14 @@ fn _wait(self: *ThreadPool, _is_waking: bool, comptime sleep_on_idle: bool) erro
const end_count = HTTP.AsyncHTTP.active_requests_count.loadUnchecked();
if (end_count > 0) {
while (HTTP.AsyncHTTP.active_requests_count.loadUnchecked() > HTTP.AsyncHTTP.max_simultaneous_requests) {
io.run_for_ns(std.time.ns_per_ms) catch {};
if (comptime sleep_on_idle) {
idle_network_ticks = 0;
}
var remaining_ticks: i32 = 5;
while (remaining_ticks > 0 and HTTP.AsyncHTTP.active_requests_count.loadUnchecked() > HTTP.AsyncHTTP.max_simultaneous_requests) : (remaining_ticks -= 1) {
io.run_for_ns(std.time.ns_per_ms * 2) catch {};
io.tick() catch {};
}
}