[fetch] Fix data corruption bug

Jarred Sumner
2022-06-27 05:29:25 -07:00
parent f70784a6d1
commit 07050901a6
2 changed files with 51 additions and 49 deletions

View File

@@ -15,6 +15,8 @@ const SOCKET_FLAGS: u32 = @import("../http_client_async.zig").SOCKET_FLAGS;
const getAllocator = @import("../http_client_async.zig").getAllocator;
const OPEN_SOCKET_FLAGS: u32 = @import("../http_client_async.zig").OPEN_SOCKET_FLAGS;
+ const log = Output.scoped(.AsyncSocket, true);
const SSLFeatureFlags = struct {
pub const early_data_enabled = true;
};
@@ -291,6 +293,8 @@ pub inline fn bufferedReadAmount(_: *AsyncSocket) usize {
pub fn read(
this: *AsyncSocket,
bytes: []u8,
+ /// offset is necessary here to be consistent with HTTPS
+ /// HTTPs must have the same buffer pointer for each read
offset: u64,
) RecvError!u64 {
this.read_context = bytes;
@@ -303,7 +307,7 @@ pub fn read(
Reader.on_read,
&this.read_completion,
this.socket,
- bytes,
+ bytes[original_read_offset..],
);
suspend {
@@ -315,6 +319,15 @@ pub fn read(
return @errSetCast(RecvError, err);
}
+ log(
+ \\recv(offset: {d}, len: {d}, read_offset: {d})
+ \\
+ , .{
+ offset,
+ bytes[original_read_offset..].len,
+ this.read_offset,
+ });
return this.read_offset - original_read_offset;
}
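The hunks above are the core of the corruption fix on the socket side: the recv call now targets `bytes[original_read_offset..]` instead of the start of `bytes`, so a continuation of a partially completed read lands after the data that has already arrived rather than overwriting it from index 0. A minimal sketch of the difference, with a hypothetical `recvInto` helper standing in for the real async socket machinery (same-era Zig as the surrounding code):

```zig
const std = @import("std");

// Hypothetical stand-in for the async recv: copies up to dest.len bytes
// from `src` into `dest` and returns how many bytes were written.
fn recvInto(dest: []u8, src: []const u8) usize {
    const n = @minimum(dest.len, src.len);
    std.mem.copy(u8, dest[0..n], src[0..n]);
    return n;
}

test "a continued read must start at the current offset, not at zero" {
    var buf: [8]u8 = undefined;

    // First chunk arrives at the front of the buffer.
    const first = recvInto(buf[0..], "HTTP");

    // Buggy pattern: reading into buf[0..] again would clobber "HTTP".
    // Fixed pattern: slice from the running offset so new bytes append.
    const second = recvInto(buf[first..], "/1.1");

    try std.testing.expectEqualStrings("HTTP/1.1", buf[0 .. first + second]);
}
```

The return value `this.read_offset - original_read_offset` follows the same idea: the caller is told only how many new bytes this call produced, not the running total.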

View File

@@ -35,6 +35,8 @@ pub const URLPath = @import("./http/url_path.zig");
pub var default_allocator: std.mem.Allocator = undefined;
pub var default_arena: Arena = undefined;
+ const log = Output.scoped(.fetch, true);
pub fn onThreadStart(_: ?*anyopaque) ?*anyopaque {
default_arena = Arena.init() catch unreachable;
default_allocator = default_arena.allocator();
@@ -494,6 +496,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
var override_accept_encoding = false;
var override_accept_header = false;
+ var override_host_header = false;
var override_user_agent = false;
for (header_names) |head, i| {
@@ -504,7 +507,6 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
// Skip host and connection header
// we manage those
switch (hash) {
- host_header_hash,
connection_header_hash,
content_length_header_hash,
=> continue,
@@ -513,6 +515,9 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
this.if_modified_since = this.headerStr(header_values[i]);
}
},
+ host_header_hash => {
+ override_host_header = true;
+ },
accept_header_hash => {
override_accept_header = true;
},
@@ -544,8 +549,8 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
header_count += 1;
}
- // request_headers_buf[header_count] = connection_header;
- // header_count += 1;
+ request_headers_buf[header_count] = connection_header;
+ header_count += 1;
if (!override_user_agent) {
request_headers_buf[header_count] = user_agent_header;
@@ -557,11 +562,13 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
header_count += 1;
}
- request_headers_buf[header_count] = picohttp.Header{
- .name = host_header_name,
- .value = this.url.hostname,
- };
- header_count += 1;
+ if (!override_host_header) {
+ request_headers_buf[header_count] = picohttp.Header{
+ .name = host_header_name,
+ .value = this.url.hostname,
+ };
+ header_count += 1;
+ }
if (!override_accept_encoding) {
request_headers_buf[header_count] = accept_encoding_header;
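The header hunks above move `Host` out of the unconditional skip list and into the same override pattern already used for `Accept`, `Accept-Encoding`, and `User-Agent`: scan the caller-supplied headers once, set a flag for each default the caller already provides, and append the built-in value only when that flag is unset. A reduced sketch of the pattern, with a hypothetical `appendHeaders` helper and plain case-insensitive comparison standing in for the client's header hashing:

```zig
const std = @import("std");

// Hypothetical header type mirroring the shape of picohttp.Header above.
const Header = struct { name: []const u8, value: []const u8 };

// Appends `default_host` only when the caller did not already supply a Host
// header, mirroring the override_host_header flag introduced in the diff.
fn appendHeaders(buf: []Header, user_headers: []const Header, default_host: Header) usize {
    var count: usize = 0;
    var override_host_header = false;

    for (user_headers) |header| {
        if (std.ascii.eqlIgnoreCase(header.name, "Host")) override_host_header = true;
        buf[count] = header;
        count += 1;
    }

    if (!override_host_header) {
        buf[count] = default_host;
        count += 1;
    }

    return count;
}

test "a caller-supplied Host header wins over the default" {
    var buf: [8]Header = undefined;
    const user = [_]Header{.{ .name = "Host", .value = "example.com" }};
    const fallback = Header{ .name = "Host", .value = "localhost" };

    const count = appendHeaders(&buf, &user, fallback);
    try std.testing.expectEqual(@as(usize, 1), count);
    try std.testing.expectEqualStrings("example.com", buf[0].value);
}
```

Keeping the defaults behind flags lets callers take full control of these headers without the client emitting duplicates.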
@@ -698,7 +705,12 @@ pub fn sendHTTP(this: *HTTPClient, body: []const u8, body_out_str: *MutableStrin
pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, comptime Client: type, client: Client, body_out_str: *MutableString) !picohttp.Response {
defer if (this.verbose) Output.flush();
- var response: picohttp.Response = undefined;
+ var response: picohttp.Response = .{
+ .minor_version = 1,
+ .status_code = 0,
+ .status = "",
+ .headers = &[_]picohttp.Header{},
+ };
var request_message = AsyncMessage.get(default_allocator);
defer request_message.release();
var request_buffer: []u8 = request_message.buf;
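Initializing `response` with an explicit empty value instead of `undefined` gives it well-defined contents before `parseParts` has run, so any error or logging path that inspects the struct sees an empty header slice rather than whatever happened to be on the stack. A reduced illustration (the struct here is a stand-in with the same field names as the initializer above, not the real `picohttp.Response`):

```zig
const std = @import("std");

// Stand-in with the same field names as the initializer above; not the real
// picohttp.Response type.
const Header = struct { name: []const u8, value: []const u8 };
const Response = struct {
    minor_version: usize = 1,
    status_code: usize = 0,
    status: []const u8 = "",
    headers: []const Header = &[_]Header{},
};

test "a defaulted response is safe to inspect before parsing succeeds" {
    // With `= undefined`, touching any field here would be illegal behavior.
    const response = Response{};
    try std.testing.expectEqual(@as(usize, 0), response.headers.len);
    try std.testing.expectEqual(@as(usize, 0), response.status.len);
}
```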
@@ -712,23 +724,21 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
restart: while (req_buf_read != 0) {
req_buf_read = try client.read(request_buffer, read_length);
read_length += req_buf_read;
+ var request_body = request_buffer[0..read_length];
+ log("request_body ({d}):\n{s}", .{ read_length, request_body });
if (comptime report_progress) {
this.progress_node.?.activate();
this.progress_node.?.setCompletedItems(read_length);
this.progress_node.?.context.maybeRefresh();
}
- var request_body = request_buffer[0..read_length];
- read_headers_up_to = if (read_headers_up_to > read_length) read_length else read_headers_up_to;
+ read_headers_up_to = @minimum(read_headers_up_to, read_length);
response = picohttp.Response.parseParts(request_body, &this.response_headers_buf, &read_headers_up_to) catch |err| {
+ log("read_headers_up_to: {d}", .{read_headers_up_to});
switch (err) {
- error.ShortRead => {
- continue :restart;
- },
- else => {
- return err;
- },
+ error.ShortRead => continue :restart,
+ else => return err,
}
};
break :restart;
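The read loop above keeps appending to `request_buffer` and re-parsing from the start until picohttp sees a complete header block, clamping `read_headers_up_to` with `@minimum` so the parser is never told more bytes exist than have actually been received. A stripped-down sketch of the same `restart:` / `error.ShortRead` retry shape, with a hypothetical `parseHeaders` in place of `picohttp.Response.parseParts`:

```zig
const std = @import("std");

// Hypothetical parser standing in for picohttp.Response.parseParts: it only
// succeeds once a blank line terminates the header block, and reports how
// many bytes that block occupies.
fn parseHeaders(buf: []const u8) error{ShortRead}!usize {
    if (std.mem.indexOf(u8, buf, "\r\n\r\n")) |end| return end + 4;
    return error.ShortRead;
}

test "keep appending and re-parsing until the header block is complete" {
    const wire = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
    var buf: [64]u8 = undefined;
    var read_length: usize = 0;
    var header_len: usize = 0;

    // Deliver the response a few bytes at a time, the way a socket would.
    restart: while (read_length < wire.len) {
        const chunk_len = @minimum(@as(usize, 7), wire.len - read_length);
        std.mem.copy(u8, buf[read_length..], wire[read_length .. read_length + chunk_len]);
        read_length += chunk_len;

        header_len = parseHeaders(buf[0..read_length]) catch |err| switch (err) {
            error.ShortRead => continue :restart, // not enough bytes yet; read more
        };
        break :restart;
    }

    try std.testing.expectEqual(@as(usize, wire.len), header_len);
}
```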
@@ -751,6 +761,7 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
maybe_keepalive = false;
}
var content_encoding_i = response.headers.len + 1;
for (response.headers) |header, header_i| {
switch (hashHeaderName(header.name)) {
content_length_header_hash => {
@@ -890,7 +901,7 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
}
var remainder = request_buffer[@intCast(usize, response.bytes_read)..read_length];
last_read = remainder.len;
- try buffer.inflate(std.math.max(remainder.len, 2048));
+ try buffer.inflate(@maximum(remainder.len, 2048));
buffer.list.expandToCapacity();
std.mem.copy(u8, buffer.list.items, remainder);
}
@@ -910,41 +921,17 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
while (pret == -2) {
var buffered_amount = client.bufferedReadAmount();
if (buffer.list.items.len < total_size + 512 or buffer.list.items[total_size..].len < @intCast(usize, @maximum(decoder.bytes_left_in_chunk, buffered_amount)) or buffer.list.items[total_size..].len < 512) {
- try buffer.inflate(std.math.max((buffered_amount + total_size) * 2, 1024));
- if (comptime Environment.isDebug) {
- var temp_buffer = buffer;
- temp_buffer.list.expandToCapacity();
- @memset(temp_buffer.list.items.ptr + buffer.list.items.len, 0, temp_buffer.list.items.len - buffer.list.items.len);
- buffer = temp_buffer;
- }
+ try buffer.inflate(@maximum((buffered_amount + total_size) * 2, 1024));
buffer.list.expandToCapacity();
}
- // while (true) {
- if (extremely_verbose) {
- Output.prettyErrorln(
- \\ Buffered: {d}
- \\ Chunk
- \\ {d} left / {d} bytes total (buffer: {d})
- \\ Read
- \\ {d} bytes / {d} total ({d} parsed)
- , .{
- client.bufferedReadAmount(),
- decoder.bytes_left_in_chunk,
- total_size,
- buffer.list.items.len,
- rret,
- total_size,
- total_size,
- });
- }
var remainder = buffer.list.items[total_size..];
const errorable_read = client.read(remainder, 0);
rret = errorable_read catch |err| {
- if (extremely_verbose) Output.prettyErrorln("Chunked transfoer encoding error: {s}", .{@errorName(err)});
+ if (extremely_verbose) Output.prettyErrorln("Chunked transfer encoding error: {s}", .{@errorName(err)});
return err;
};
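When the chunked-transfer decoder runs out of room, the buffer above is grown to at least `@maximum((buffered_amount + total_size) * 2, 1024)` bytes: at least double what is already held plus whatever the socket reports as buffered, with a 1 KiB floor. A small sketch of that growth policy, assuming a plain `std.ArrayList(u8)` in place of the `MutableString` wrapper, and assuming `inflate` grows capacity roughly the way `ensureTotalCapacity` does:

```zig
const std = @import("std");

// Grows `list` to at least @maximum((buffered + total) * 2, 1024) bytes of
// capacity, then exposes the whole capacity so the next read has room for a
// full chunk plus whatever the socket already has buffered.
fn growForChunk(list: *std.ArrayList(u8), total: usize, buffered: usize) !void {
    const target = @maximum((buffered + total) * 2, 1024);
    try list.ensureTotalCapacity(target);
    list.expandToCapacity();
}

test "growth policy always leaves room for pending bytes" {
    var list = std.ArrayList(u8).init(std.testing.allocator);
    defer list.deinit();

    try growForChunk(&list, 0, 0);
    try std.testing.expect(list.items.len >= 1024);

    try growForChunk(&list, 4096, 512);
    try std.testing.expect(list.items.len >= (4096 + 512) * 2);
}
```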
@@ -955,11 +942,12 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
remainder = buffer.list.items[total_size..];
remainder = remainder[rret..][0..buffered_amount];
rret += client.read(remainder, 0) catch |err| {
- if (extremely_verbose) Output.prettyErrorln("Chunked transfoer encoding error: {s}", .{@errorName(err)});
+ if (extremely_verbose) Output.prettyErrorln("Chunked transfer encoding error: {s}", .{@errorName(err)});
return err;
};
}
+ // socket hang up, there was a parsing error, etc
if (rret == 0) {
if (extremely_verbose) Output.prettyErrorln("Unexpected 0", .{});
@@ -986,9 +974,9 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
decoder,
});
return error.ChunkedEncodingParseError;
}
total_size += rsize;
if (comptime report_progress) {
@@ -1000,7 +988,6 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
buffer.list.shrinkRetainingCapacity(total_size);
buffer_.* = buffer;
switch (encoding) {
Encoding.gzip, Encoding.deflate => {
var gzip_timer: std.time.Timer = undefined;
@@ -1035,7 +1022,9 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
this.progress_node.?.context.maybeRefresh();
}
- this.body_size = @intCast(u32, body_out_str.list.items.len);
+ this.body_size = @truncate(u32, body_out_str.list.items.len);
+ std.debug.assert(body_out_str.list.items.len == buffer.list.items.len);
return response;
}
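The final hunk swaps `@intCast(u32, ...)` for `@truncate(u32, ...)` when recording the body size. In this era of Zig, `@intCast` is a checked cast that trips a safety panic in Debug/ReleaseSafe builds when the value does not fit the destination type, while `@truncate` simply keeps the low 32 bits. A tiny sketch of the difference between the two builtins:

```zig
const std = @import("std");

test "@truncate keeps the low bits instead of tripping a safety check" {
    const big: u64 = (1 << 32) + 5;

    // @truncate(u32, big) discards the high bits; the result is 5.
    try std.testing.expectEqual(@as(u32, 5), @truncate(u32, big));

    // @intCast(u32, big) is a checked cast: because `big` does not fit in a
    // u32, it would panic in Debug/ReleaseSafe instead of wrapping.
    // _ = @intCast(u32, big);
}
```

The added assert records the invariant the size is derived from: by this point `body_out_str` and `buffer` are expected to hold the same number of bytes.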