Reliability bugfix for WebSocket (#3394)

* Rewrite elementLengthLatin1IntoUTF8

* Update SIMDUTF

* Make `elementLengthLatin1IntoUTF8` faster

---------

Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
This commit is contained in:
Jarred Sumner
2023-06-25 02:58:49 -07:00
committed by GitHub
parent ff63555143
commit bc7719fc28
4 changed files with 24923 additions and 20637 deletions

View File

@@ -1946,61 +1946,44 @@ pub fn replaceLatin1WithUTF8(buf_: []u8) void {
}
/// Returns the length `latin1_` would have once converted to UTF-8: the input
/// length plus one extra byte for every input byte greater than 127.
///
/// `Type` is the type of `latin1_`; the body only relies on it exposing `.ptr`
/// and `.len` (presumably a `[]const u8`-like slice — confirm against callers).
///
/// NOTE(review): this text appears to interleave lines from two revisions of the
/// function (a diff with its +/- markers stripped): there are two consecutive
/// `return` statements, two nested `if (comptime Environment.enableSIMD)` checks,
/// and both a slice-advancing and a pointer-advancing loop. Separate the two
/// revisions before treating this as compilable code.
pub fn elementLengthLatin1IntoUTF8(comptime Type: type, latin1_: Type) usize {
// https://zig.godbolt.org/z/zzYexPPs9
var latin1 = latin1_;
// Captured up front; used by one of the two `return` statements at the end.
const input_len = latin1.len;
// Running count of input bytes whose value is above 127.
var total_non_ascii_count: usize = 0;
// One-past-the-end pointer of the input.
const latin1_last = latin1.ptr + latin1.len;
// Skip all of the pointer-based work when the input is empty.
if (latin1.ptr != latin1_last) {
// This is about 30% faster on large input compared to auto-vectorization
if (comptime Environment.enableSIMD) {
// End of the prefix whose length is a whole multiple of `ascii_vector_size`.
const end = latin1.ptr + (latin1.len - (latin1.len % ascii_vector_size));
while (latin1.ptr != end) {
const vec: AsciiVector = latin1[0..ascii_vector_size].*;
// reference the pointer directly because it improves codegen
var ptr = latin1.ptr;
// Shifting an unsigned 8-bit integer to the right by 7 bits always produces a value of 0 or 1.
const cmp = vec >> @splat(
ascii_vector_size,
@as(u8, 7),
);
if (comptime Environment.enableSIMD) {
// Input length rounded down to a multiple of `ascii_vector_size`.
const wrapped_len = latin1.len - (latin1.len % ascii_vector_size);
const latin1_vec_end = ptr + wrapped_len;
while (ptr != latin1_vec_end) {
const vec: AsciiVector = ptr[0..ascii_vector_size].*;
// NOTE(review): `vec & 0x80` leaves each non-ASCII lane at 0x80, not 1, so the
// `@reduce(.Add, cmp)` below adds 0x80 per non-ASCII byte while every other path
// in this function adds 1. Assuming `AsciiVector` has u8 lanes, the reduction is
// also performed in u8 and can wrap. Confirm against the shift-and-mask variant.
const cmp = vec & @splat(ascii_vector_size, @as(u8, 0x80));
total_non_ascii_count += @reduce(.Add, cmp);
ptr += ascii_vector_size;
}
} else {
// Scalar path: inspect 8 bytes at a time. The strict `<` means a chunk that ends
// exactly at `latin1_last` is not taken here and is left for the later checks.
while (@intFromPtr(ptr + 8) < @intFromPtr(latin1_last)) {
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
// Keep only the high bit of each byte; the popcount is then the number of bytes > 127.
const bytes = @bitCast(u64, ptr[0..8].*) & 0x8080808080808080;
total_non_ascii_count += @popCount(bytes);
ptr += 8;
}
// Anding that value rather than converting it into a @Vector(16, u1) produces better code from LLVM.
const mask = cmp & @splat(
ascii_vector_size,
@as(u8, 1),
);
// Same high-bit popcount on a single 4-byte chunk, if one fits (strict `<` again).
if (@intFromPtr(ptr + 4) < @intFromPtr(latin1_last)) {
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
const bytes = @bitCast(u32, ptr[0..4].*) & 0x80808080;
total_non_ascii_count += @popCount(bytes);
ptr += 4;
}
// Same high-bit popcount on a single 2-byte chunk, if one fits (strict `<` again).
if (@intFromPtr(ptr + 2) < @intFromPtr(latin1_last)) {
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
const bytes = @bitCast(u16, ptr[0..2].*) & 0x8080;
total_non_ascii_count += @popCount(bytes);
ptr += 2;
}
// Each lane of `mask` is 0 or 1, so the sum is the number of non-ASCII bytes in this vector.
total_non_ascii_count += @as(usize, @reduce(.Add, mask));
latin1 = latin1[ascii_vector_size..];
}
// Byte-at-a-time loop over whatever the wider steps above did not consume.
while (ptr != latin1_last) {
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) < @intFromPtr(latin1_last));
// an important hint to the compiler to not auto-vectorize the loop below
if (latin1.len >= ascii_vector_size) unreachable;
}
total_non_ascii_count += @as(usize, @intFromBool(ptr[0] > 127));
ptr += 1;
}
// assert we never go out of bounds
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
// Slice-based tail: count non-ASCII bytes remaining in `latin1` one at a time.
for (latin1) |c| {
total_non_ascii_count += @as(usize, @intFromBool(c > 127));
}
// each non-ascii latin1 character becomes 2 UTF8 characters
// since latin1_.len is the original length, we only need to add up the number of non-ascii characters to get the final count
return latin1_.len + total_non_ascii_count;
return input_len + total_non_ascii_count;
}
const JSC = @import("root").bun.JSC;