mirror of
https://github.com/oven-sh/bun
synced 2026-02-09 18:38:55 +00:00
Reliability bugfix for WebSocket (#3394)
* Rewrite elementLengthLatin1IntoUTF8 * Update SIMDUTF * Make `elementLengthLatin1IntoUTF8` faster --------- Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
This commit is contained in:
@@ -1946,61 +1946,44 @@ pub fn replaceLatin1WithUTF8(buf_: []u8) void {
|
||||
}
|
||||
|
||||
pub fn elementLengthLatin1IntoUTF8(comptime Type: type, latin1_: Type) usize {
|
||||
// https://zig.godbolt.org/z/zzYexPPs9
|
||||
|
||||
var latin1 = latin1_;
|
||||
const input_len = latin1.len;
|
||||
var total_non_ascii_count: usize = 0;
|
||||
|
||||
const latin1_last = latin1.ptr + latin1.len;
|
||||
if (latin1.ptr != latin1_last) {
|
||||
// This is about 30% faster on large input compared to auto-vectorization
|
||||
if (comptime Environment.enableSIMD) {
|
||||
const end = latin1.ptr + (latin1.len - (latin1.len % ascii_vector_size));
|
||||
while (latin1.ptr != end) {
|
||||
const vec: AsciiVector = latin1[0..ascii_vector_size].*;
|
||||
|
||||
// reference the pointer directly because it improves codegen
|
||||
var ptr = latin1.ptr;
|
||||
// Shifting a unsigned 8 bit integer to the right by 7 bits always produces a value of 0 or 1.
|
||||
const cmp = vec >> @splat(
|
||||
ascii_vector_size,
|
||||
@as(u8, 7),
|
||||
);
|
||||
|
||||
if (comptime Environment.enableSIMD) {
|
||||
const wrapped_len = latin1.len - (latin1.len % ascii_vector_size);
|
||||
const latin1_vec_end = ptr + wrapped_len;
|
||||
while (ptr != latin1_vec_end) {
|
||||
const vec: AsciiVector = ptr[0..ascii_vector_size].*;
|
||||
const cmp = vec & @splat(ascii_vector_size, @as(u8, 0x80));
|
||||
total_non_ascii_count += @reduce(.Add, cmp);
|
||||
ptr += ascii_vector_size;
|
||||
}
|
||||
} else {
|
||||
while (@intFromPtr(ptr + 8) < @intFromPtr(latin1_last)) {
|
||||
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
|
||||
const bytes = @bitCast(u64, ptr[0..8].*) & 0x8080808080808080;
|
||||
total_non_ascii_count += @popCount(bytes);
|
||||
ptr += 8;
|
||||
}
|
||||
// Anding that value rather than converting it into a @Vector(16, u1) produces better code from LLVM.
|
||||
const mask = cmp & @splat(
|
||||
ascii_vector_size,
|
||||
@as(u8, 1),
|
||||
);
|
||||
|
||||
if (@intFromPtr(ptr + 4) < @intFromPtr(latin1_last)) {
|
||||
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
|
||||
const bytes = @bitCast(u32, ptr[0..4].*) & 0x80808080;
|
||||
total_non_ascii_count += @popCount(bytes);
|
||||
ptr += 4;
|
||||
}
|
||||
|
||||
if (@intFromPtr(ptr + 2) < @intFromPtr(latin1_last)) {
|
||||
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
|
||||
const bytes = @bitCast(u16, ptr[0..2].*) & 0x8080;
|
||||
total_non_ascii_count += @popCount(bytes);
|
||||
ptr += 2;
|
||||
}
|
||||
total_non_ascii_count += @as(usize, @reduce(.Add, mask));
|
||||
latin1 = latin1[ascii_vector_size..];
|
||||
}
|
||||
|
||||
while (ptr != latin1_last) {
|
||||
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) < @intFromPtr(latin1_last));
|
||||
// an important hint to the compiler to not auto-vectorize the loop below
|
||||
if (latin1.len >= ascii_vector_size) unreachable;
|
||||
}
|
||||
|
||||
total_non_ascii_count += @as(usize, @intFromBool(ptr[0] > 127));
|
||||
ptr += 1;
|
||||
}
|
||||
|
||||
// assert we never go out of bounds
|
||||
if (comptime Environment.allow_assert) std.debug.assert(@intFromPtr(ptr) <= @intFromPtr(latin1_last) and @intFromPtr(ptr) >= @intFromPtr(latin1_.ptr));
|
||||
for (latin1) |c| {
|
||||
total_non_ascii_count += @as(usize, @intFromBool(c > 127));
|
||||
}
|
||||
|
||||
// each non-ascii latin1 character becomes 2 UTF8 characters
|
||||
// since latin1_.len is the original length, we only need to add up the number of non-ascii characters to get the final count
|
||||
return latin1_.len + total_non_ascii_count;
|
||||
return input_len + total_non_ascii_count;
|
||||
}
|
||||
|
||||
const JSC = @import("root").bun.JSC;
|
||||
|
||||
Reference in New Issue
Block a user