From ebaeafbc89b8d86d4303a468d373792f8859bde3 Mon Sep 17 00:00:00 2001 From: Zack Radisic <56137411+zackradisic@users.noreply.github.com> Date: Fri, 16 Feb 2024 04:09:34 -0800 Subject: [PATCH] feat: More robust and faster shell escaping (#8904) * wip * Proper escaping algorithm * Don't use `$` for js obj/string referencs * [autofix.ci] apply automated fixes * Changes * Changes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jarred Sumner --- src/bun.js/api/BunObject.zig | 42 ++- src/shell/interpreter.zig | 38 ++- src/shell/shell.zig | 475 ++++++++++++++++++++++++----- src/string_immutable.zig | 2 +- test/js/bun/shell/bunshell.test.ts | 31 ++ 5 files changed, 489 insertions(+), 99 deletions(-) diff --git a/src/bun.js/api/BunObject.zig b/src/bun.js/api/BunObject.zig index 76700e1cb4..a0266f8505 100644 --- a/src/bun.js/api/BunObject.zig +++ b/src/bun.js/api/BunObject.zig @@ -318,6 +318,17 @@ pub fn shellLex( defer arena.deinit(); const template_args = callframe.argumentsPtr()[1..callframe.argumentsCount()]; + var stack_alloc = std.heap.stackFallback(@sizeOf(bun.String) * 4, arena.allocator()); + var jsstrings = std.ArrayList(bun.String).initCapacity(stack_alloc.get(), 4) catch { + globalThis.throwOutOfMemory(); + return .undefined; + }; + defer { + for (jsstrings.items[0..]) |bunstr| { + bunstr.deref(); + } + jsstrings.deinit(); + } var jsobjs = std.ArrayList(JSValue).init(arena.allocator()); defer { for (jsobjs.items) |jsval| { @@ -326,7 +337,7 @@ pub fn shellLex( } var script = std.ArrayList(u8).init(arena.allocator()); - if (!(bun.shell.shellCmdFromJS(globalThis, string_args, template_args, &jsobjs, &script) catch { + if (!(bun.shell.shellCmdFromJS(globalThis, string_args, template_args, &jsobjs, &jsstrings, &script) catch { globalThis.throwOutOfMemory(); return JSValue.undefined; })) { @@ -335,14 +346,14 @@ pub fn shellLex( const lex_result = brk: { if (bun.strings.isAllASCII(script.items[0..])) { - var lexer = Shell.LexerAscii.new(arena.allocator(), script.items[0..]); + var lexer = Shell.LexerAscii.new(arena.allocator(), script.items[0..], jsstrings.items[0..]); lexer.lex() catch |err| { globalThis.throwError(err, "failed to lex shell"); return JSValue.undefined; }; break :brk lexer.get_result(); } - var lexer = Shell.LexerUnicode.new(arena.allocator(), script.items[0..]); + var lexer = Shell.LexerUnicode.new(arena.allocator(), script.items[0..], jsstrings.items[0..]); lexer.lex() catch |err| { globalThis.throwError(err, "failed to lex shell"); return JSValue.undefined; @@ -393,6 +404,17 @@ pub fn shellParse( defer arena.deinit(); const template_args = callframe.argumentsPtr()[1..callframe.argumentsCount()]; + var stack_alloc = std.heap.stackFallback(@sizeOf(bun.String) * 4, arena.allocator()); + var jsstrings = std.ArrayList(bun.String).initCapacity(stack_alloc.get(), 4) catch { + globalThis.throwOutOfMemory(); + return .undefined; + }; + defer { + for (jsstrings.items[0..]) |bunstr| { + bunstr.deref(); + } + jsstrings.deinit(); + } var jsobjs = std.ArrayList(JSValue).init(arena.allocator()); defer { for (jsobjs.items) |jsval| { @@ -400,7 +422,7 @@ pub fn shellParse( } } var script = std.ArrayList(u8).init(arena.allocator()); - if (!(bun.shell.shellCmdFromJS(globalThis, string_args, template_args, &jsobjs, &script) catch { + if (!(bun.shell.shellCmdFromJS(globalThis, string_args, template_args, &jsobjs, &jsstrings, &script) catch { globalThis.throwOutOfMemory(); return JSValue.undefined; })) { @@ -410,7 +432,7 @@ pub fn shellParse( var out_parser: ?bun.shell.Parser = null; var out_lex_result: ?bun.shell.LexResult = null; - const script_ast = bun.shell.Interpreter.parse(&arena, script.items[0..], jsobjs.items[0..], &out_parser, &out_lex_result) catch |err| { + const script_ast = bun.shell.Interpreter.parse(&arena, script.items[0..], jsobjs.items[0..], jsstrings.items[0..], &out_parser, &out_lex_result) catch |err| { if (err == bun.shell.ParseError.Lex) { std.debug.assert(out_lex_result != null); const str = out_lex_result.?.combineErrors(arena.allocator()); @@ -545,17 +567,21 @@ pub fn shellEscape( if (bunstr.isUTF16()) { if (bun.shell.needsEscapeUTF16(bunstr.utf16())) { - bun.shell.escapeUnicode(bunstr.byteSlice(), &outbuf) catch { + const has_invalid_utf16 = bun.shell.escapeUtf16(bunstr.utf16(), &outbuf, true) catch { globalThis.throwOutOfMemory(); return .undefined; }; + if (has_invalid_utf16) { + globalThis.throw("String has invalid utf-16: {s}", .{bunstr.byteSlice()}); + return .undefined; + } return bun.String.createUTF8(outbuf.items[0..]).toJS(globalThis); } return jsval; } - if (bun.shell.needsEscape(bunstr.latin1())) { - bun.shell.escape(bunstr.byteSlice(), &outbuf) catch { + if (bun.shell.needsEscapeUtf8AsciiLatin1(bunstr.latin1())) { + bun.shell.escape8Bit(bunstr.byteSlice(), &outbuf, true) catch { globalThis.throwOutOfMemory(); return .undefined; }; diff --git a/src/shell/interpreter.zig b/src/shell/interpreter.zig index 37f8bd31b5..63c10ef418 100644 --- a/src/shell/interpreter.zig +++ b/src/shell/interpreter.zig @@ -842,9 +842,20 @@ pub fn NewInterpreter(comptime EventLoopKind: JSC.EventLoopKind) type { }; const template_args = callframe.argumentsPtr()[1..callframe.argumentsCount()]; + var stack_alloc = std.heap.stackFallback(@sizeOf(bun.String) * 4, arena.allocator()); + var jsstrings = std.ArrayList(bun.String).initCapacity(stack_alloc.get(), 4) catch { + globalThis.throwOutOfMemory(); + return null; + }; + defer { + for (jsstrings.items[0..]) |bunstr| { + bunstr.deref(); + } + jsstrings.deinit(); + } var jsobjs = std.ArrayList(JSValue).init(arena.allocator()); var script = std.ArrayList(u8).init(arena.allocator()); - if (!(bun.shell.shellCmdFromJS(globalThis, string_args, template_args, &jsobjs, &script) catch { + if (!(bun.shell.shellCmdFromJS(globalThis, string_args, template_args, &jsobjs, &jsstrings, &script) catch { globalThis.throwOutOfMemory(); return null; })) { @@ -857,6 +868,7 @@ pub fn NewInterpreter(comptime EventLoopKind: JSC.EventLoopKind) type { &arena, script.items[0..], jsobjs.items[0..], + jsstrings.items[0..], &parser, &lex_result, ) catch |err| { @@ -902,14 +914,21 @@ pub fn NewInterpreter(comptime EventLoopKind: JSC.EventLoopKind) type { return interpreter; } - pub fn parse(arena: *bun.ArenaAllocator, script: []const u8, jsobjs: []JSValue, out_parser: *?bun.shell.Parser, out_lex_result: *?shell.LexResult) !ast.Script { + pub fn parse( + arena: *bun.ArenaAllocator, + script: []const u8, + jsobjs: []JSValue, + jsstrings_to_escape: []bun.String, + out_parser: *?bun.shell.Parser, + out_lex_result: *?shell.LexResult, + ) !ast.Script { const lex_result = brk: { if (bun.strings.isAllASCII(script)) { - var lexer = bun.shell.LexerAscii.new(arena.allocator(), script); + var lexer = bun.shell.LexerAscii.new(arena.allocator(), script, jsstrings_to_escape); try lexer.lex(); break :brk lexer.get_result(); } - var lexer = bun.shell.LexerUnicode.new(arena.allocator(), script); + var lexer = bun.shell.LexerUnicode.new(arena.allocator(), script, jsstrings_to_escape); try lexer.lex(); break :brk lexer.get_result(); }; @@ -1029,7 +1048,14 @@ pub fn NewInterpreter(comptime EventLoopKind: JSC.EventLoopKind) type { const jsobjs: []JSValue = &[_]JSValue{}; var out_parser: ?bun.shell.Parser = null; var out_lex_result: ?bun.shell.LexResult = null; - const script = ThisInterpreter.parse(&arena, src, jsobjs, &out_parser, &out_lex_result) catch |err| { + const script = ThisInterpreter.parse( + &arena, + src, + jsobjs, + &[_]bun.String{}, + &out_parser, + &out_lex_result, + ) catch |err| { if (err == bun.shell.ParseError.Lex) { std.debug.assert(out_lex_result != null); const str = out_lex_result.?.combineErrors(arena.allocator()); @@ -1075,7 +1101,7 @@ pub fn NewInterpreter(comptime EventLoopKind: JSC.EventLoopKind) type { const jsobjs: []JSValue = &[_]JSValue{}; var out_parser: ?bun.shell.Parser = null; var out_lex_result: ?bun.shell.LexResult = null; - const script = ThisInterpreter.parse(&arena, src, jsobjs, &out_parser, &out_lex_result) catch |err| { + const script = ThisInterpreter.parse(&arena, src, jsobjs, &[_]bun.String{}, &out_parser, &out_lex_result) catch |err| { if (err == bun.shell.ParseError.Lex) { std.debug.assert(out_lex_result != null); const str = out_lex_result.?.combineErrors(arena.allocator()); diff --git a/src/shell/shell.zig b/src/shell/shell.zig index 60c2b8ca9b..c6cd685544 100644 --- a/src/shell/shell.zig +++ b/src/shell/shell.zig @@ -1279,7 +1279,8 @@ pub const LexError = struct { /// Allocated with lexer arena msg: []const u8, }; -pub const LEX_JS_OBJREF_PREFIX = "$__bun_"; +pub const LEX_JS_OBJREF_PREFIX = "~__bun_"; +pub const LEX_JS_STRING_PREFIX = "~__bunstr_"; pub fn NewLexer(comptime encoding: StringEncoding) type { const Chars = ShellCharIter(encoding); @@ -1300,6 +1301,10 @@ pub fn NewLexer(comptime encoding: StringEncoding) type { in_subshell: ?SubShellKind = null, errors: std.ArrayList(LexError), + /// Contains a list of strings we need to escape + /// Not owned by this struct + string_refs: []bun.String, + const SubShellKind = enum { /// (echo hi; echo hello) normal, @@ -1329,12 +1334,13 @@ pub fn NewLexer(comptime encoding: StringEncoding) type { delimit_quote: bool, }; - pub fn new(alloc: Allocator, src: []const u8) @This() { + pub fn new(alloc: Allocator, src: []const u8, strings_to_escape: []bun.String) @This() { return .{ .chars = Chars.init(src), .tokens = ArrayList(Token).init(alloc), .strpool = ArrayList(u8).init(alloc), .errors = ArrayList(LexError).init(alloc), + .string_refs = strings_to_escape, }; } @@ -1364,6 +1370,7 @@ pub fn NewLexer(comptime encoding: StringEncoding) type { .word_start = self.word_start, .j = self.j, + .string_refs = self.string_refs, }; sublexer.chars.state = .Normal; return sublexer; @@ -1411,11 +1418,31 @@ pub fn NewLexer(comptime encoding: StringEncoding) type { const char = input.char; const escaped = input.escaped; + // Special token to denote substituted JS variables + if (char == '~') { + if (self.looksLikeJSStringRef()) { + if (self.eatJSStringRef()) |bunstr| { + try self.break_word(false); + try self.handleJSStringRef(bunstr); + continue; + } + } else if (self.looksLikeJSObjRef()) { + if (self.eatJSObjRef()) |tok| { + if (self.chars.state == .Double) { + self.add_error("JS object reference not allowed in double quotes"); + return; + } + try self.break_word(false); + try self.tokens.append(tok); + continue; + } + } + } // Handle non-escaped chars: // 1. special syntax (operators, etc.) // 2. lexing state switchers (quotes) // 3. word breakers (spaces, etc.) - if (!escaped) escaped: { + else if (!escaped) escaped: { switch (char) { '#' => { if (self.chars.state == .Single or self.chars.state == .Double) break :escaped; @@ -1506,21 +1533,13 @@ pub fn NewLexer(comptime encoding: StringEncoding) type { // const snapshot = self.make_snapshot(); // Handle variable try self.break_word(false); - if (self.eat_js_obj_ref()) |ref| { - if (self.chars.state == .Double) { - try self.errors.append(.{ .msg = bun.default_allocator.dupe(u8, "JS object reference not allowed in double quotes") catch bun.outOfMemory() }); - return; - } - try self.tokens.append(ref); + const var_tok = try self.eat_var(); + // empty var + if (var_tok.start == var_tok.end) { + try self.appendCharToStrPool('$'); + try self.break_word(false); } else { - const var_tok = try self.eat_var(); - // empty var - if (var_tok.start == var_tok.end) { - try self.appendCharToStrPool('$'); - try self.break_word(false); - } else { - try self.tokens.append(.{ .Var = var_tok }); - } + try self.tokens.append(.{ .Var = var_tok }); } self.word_start = self.j; continue; @@ -1907,19 +1926,146 @@ pub fn NewLexer(comptime encoding: StringEncoding) type { self.continue_from_sublexer(&sublexer); } - fn eat_js_obj_ref(self: *@This()) ?Token { - const snap = self.make_snapshot(); - if (self.eat_literal(u8, LEX_JS_OBJREF_PREFIX)) { - if (self.eat_number_word()) |num| { - if (num <= std.math.maxInt(u32)) { - return .{ .JSObjRef = @intCast(num) }; + fn appendStringToStrPool(self: *@This(), bunstr: bun.String) !void { + const start = self.strpool.items.len; + if (bunstr.is8Bit() or bunstr.isUTF8()) { + try self.strpool.appendSlice(bunstr.byteSlice()); + } else { + const utf16 = bunstr.utf16(); + const additional = bun.simdutf.simdutf__utf8_length_from_utf16le(utf16.ptr, utf16.len); + try self.strpool.ensureUnusedCapacity(additional); + try bun.strings.convertUTF16ToUTF8Append(&self.strpool, bunstr.utf16()); + } + const end = self.strpool.items.len; + self.j += @intCast(end - start); + } + + fn handleJSStringRef(self: *@This(), bunstr: bun.String) !void { + try self.appendStringToStrPool(bunstr); + } + + fn looksLikeJSObjRef(self: *@This()) bool { + const bytes = self.chars.srcBytesAtCursor(); + if (LEX_JS_OBJREF_PREFIX.len - 1 >= bytes.len) return false; + return std.mem.eql(u8, bytes[0 .. LEX_JS_OBJREF_PREFIX.len - 1], LEX_JS_OBJREF_PREFIX[1..]); + } + + fn looksLikeJSStringRef(self: *@This()) bool { + const bytes = self.chars.srcBytesAtCursor(); + if (LEX_JS_STRING_PREFIX.len - 1 >= bytes.len) return false; + return std.mem.eql(u8, bytes[0 .. LEX_JS_STRING_PREFIX.len - 1], LEX_JS_STRING_PREFIX[1..]); + } + + fn eatJSSubstitutionIdx(self: *@This(), comptime literal: []const u8, comptime name: []const u8, comptime validate: *const fn (*@This(), usize) bool) ?usize { + const bytes = self.chars.srcBytesAtCursor(); + if (literal.len - 1 >= bytes.len) return null; + if (std.mem.eql(u8, bytes[0 .. literal.len - 1], literal[1..])) { + var i: usize = 0; + var digit_buf: [32]u8 = undefined; + var digit_buf_count: u8 = 0; + + i += literal.len - 1; + + while (i < bytes.len) : (i += 1) { + switch (bytes[i]) { + '0'...'9' => { + if (digit_buf_count >= digit_buf.len) { + const ERROR_STR = "Invalid " ++ name ++ " (number too high): "; + var error_buf: [ERROR_STR.len + digit_buf.len + 1]u8 = undefined; + const error_msg = std.fmt.bufPrint(error_buf[0..], "{s} {s}{c}", .{ ERROR_STR, digit_buf[0..digit_buf_count], bytes[i] }) catch @panic("Should not happen"); + self.add_error(error_msg); + return null; + } + digit_buf[digit_buf_count] = bytes[i]; + digit_buf_count += 1; + }, + else => break, } } + + if (digit_buf_count == 0) { + self.add_error("Invalid " ++ name ++ " (no idx)"); + return null; + } + + const idx = std.fmt.parseInt(usize, digit_buf[0..digit_buf_count], 10) catch { + self.add_error("Invalid " ++ name ++ " ref "); + return null; + }; + + if (!validate(self, idx)) return null; + // if (idx >= self.string_refs.len) { + // self.add_error("Invalid " ++ name ++ " (out of bounds"); + // return null; + // } + + // Bump the cursor + brk: { + const new_idx = self.chars.cursorPos() + i; + const prev_ascii_char: ?u7 = if (digit_buf_count == 1) null else @truncate(digit_buf[digit_buf_count - 2]); + const cur_ascii_char: u7 = @truncate(digit_buf[digit_buf_count - 1]); + if (comptime encoding == .ascii) { + self.chars.src.i = new_idx; + if (prev_ascii_char) |pc| self.chars.prev = .{ .char = pc }; + self.chars.current = .{ .char = cur_ascii_char }; + break :brk; + } + self.chars.src.cursor = CodepointIterator.Cursor{ + .i = @intCast(new_idx), + .c = cur_ascii_char, + .width = 1, + }; + self.chars.src.next_cursor = self.chars.src.cursor; + SrcUnicode.nextCursor(&self.chars.src.iter, &self.chars.src.next_cursor); + if (prev_ascii_char) |pc| self.chars.prev = .{ .char = pc }; + self.chars.current = .{ .char = cur_ascii_char }; + } + + // return self.string_refs[idx]; + return idx; } - self.backtrack(snap); return null; } + /// __NOTE__: Do not store references to the returned bun.String, it does not have its ref count incremented + fn eatJSStringRef(self: *@This()) ?bun.String { + if (self.eatJSSubstitutionIdx( + LEX_JS_STRING_PREFIX, + "JS string ref", + validateJSStringRefIdx, + )) |idx| { + return self.string_refs[idx]; + } + return null; + } + + fn validateJSStringRefIdx(self: *@This(), idx: usize) bool { + if (idx >= self.string_refs.len) { + self.add_error("Invalid JS string ref (out of bounds"); + return false; + } + return true; + } + + fn eatJSObjRef(self: *@This()) ?Token { + if (self.eatJSSubstitutionIdx( + LEX_JS_OBJREF_PREFIX, + "JS object ref", + validateJSObjRefIdx, + )) |idx| { + return .{ .JSObjRef = @intCast(idx) }; + } + return null; + } + + fn validateJSObjRefIdx(self: *@This(), idx: usize) bool { + if (idx >= std.math.maxInt(u32)) { + self.add_error("Invalid JS object ref (out of bounds)"); + return false; + } + return true; + } + fn eat_var(self: *@This()) !Token.TextRange { const start = self.j; var i: usize = 0; @@ -2087,7 +2233,7 @@ const SrcUnicode = struct { inline fn indexNext(this: *const SrcUnicode) ?IndexValue { if (this.next_cursor.width + this.next_cursor.i > this.iter.bytes.len) return null; - return .{ .char = this.next_cursor.c, .width = this.next_cursor.width }; + return .{ .char = @intCast(this.next_cursor.c), .width = this.next_cursor.width }; } inline fn eat(this: *SrcUnicode, escaped: bool) void { @@ -2147,6 +2293,27 @@ pub fn ShellCharIter(comptime encoding: StringEncoding) type { }; } + pub fn srcBytes(self: *@This()) []const u8 { + if (comptime encoding == .ascii) return self.src.bytes; + return self.src.iter.bytes; + } + + pub fn srcBytesAtCursor(self: *@This()) []const u8 { + const bytes = self.srcBytes(); + if (comptime encoding == .ascii) { + if (self.src.i >= bytes.len) return ""; + return bytes[self.src.i..]; + } + + if (self.src.iter.i >= bytes.len) return ""; + return bytes[self.src.iter.i..]; + } + + pub fn cursorPos(self: *@This()) usize { + if (comptime encoding == .ascii) return self.src.i; + return self.src.iter.i; + } + pub fn eat(self: *@This()) ?InputChar { if (self.read_char()) |result| { self.prev = self.current; @@ -2451,8 +2618,10 @@ pub fn shellCmdFromJS( string_args: JSValue, template_args: []const JSValue, out_jsobjs: *std.ArrayList(JSValue), + jsstrings: *std.ArrayList(bun.String), out_script: *std.ArrayList(u8), ) !bool { + var builder = ShellSrcBuilder.init(globalThis, out_script, jsstrings); var jsobjref_buf: [128]u8 = [_]u8{0} ** 128; var string_iter = string_args.arrayIterator(globalThis); @@ -2460,7 +2629,7 @@ pub fn shellCmdFromJS( const last = string_iter.len -| 1; while (string_iter.next()) |js_value| { defer i += 1; - if (!try appendJSValueStr(globalThis, js_value, out_script, false)) { + if (!try builder.appendJSValueStr(js_value, false)) { globalThis.throw("Shell script string contains invalid UTF-16", .{}); return false; } @@ -2468,7 +2637,7 @@ pub fn shellCmdFromJS( // try script.appendSlice(str.full()); if (i < last) { const template_value = template_args[i]; - if (!(try handleTemplateValue(globalThis, template_value, out_jsobjs, out_script, jsobjref_buf[0..]))) return false; + if (!(try handleTemplateValue(globalThis, template_value, out_jsobjs, out_script, jsstrings, jsobjref_buf[0..]))) return false; } } return true; @@ -2479,8 +2648,10 @@ pub fn handleTemplateValue( template_value: JSValue, out_jsobjs: *std.ArrayList(JSValue), out_script: *std.ArrayList(u8), + jsstrings: *std.ArrayList(bun.String), jsobjref_buf: []u8, ) !bool { + var builder = ShellSrcBuilder.init(globalThis, out_script, jsstrings); if (!template_value.isEmpty()) { if (template_value.asArrayBuffer(globalThis)) |array_buffer| { _ = array_buffer; @@ -2497,7 +2668,7 @@ pub fn handleTemplateValue( if (store.data == .file) { if (store.data.file.pathlike == .path) { const path = store.data.file.pathlike.path.slice(); - if (!try appendUTF8Text(path, out_script, true)) { + if (!try builder.appendUTF8(path, true)) { globalThis.throw("Shell script string contains invalid UTF-16", .{}); return false; } @@ -2537,7 +2708,7 @@ pub fn handleTemplateValue( } if (template_value.isString()) { - if (!try appendJSValueStr(globalThis, template_value, out_script, true)) { + if (!try builder.appendJSValueStr(template_value, true)) { globalThis.throw("Shell script string contains invalid UTF-16", .{}); return false; } @@ -2549,10 +2720,10 @@ pub fn handleTemplateValue( const last = array.len -| 1; var i: u32 = 0; while (array.next()) |arr| : (i += 1) { - if (!(try handleTemplateValue(globalThis, arr, out_jsobjs, out_script, jsobjref_buf))) return false; + if (!(try handleTemplateValue(globalThis, arr, out_jsobjs, out_script, jsstrings, jsobjref_buf))) return false; if (i < last) { - const str = bun.String.init(" "); - if (!try appendBunStr(str, out_script, false)) return false; + const str = bun.String.static(" "); + if (!try builder.appendBunStr(str, false)) return false; } } return true; @@ -2562,7 +2733,7 @@ pub fn handleTemplateValue( if (template_value.getTruthy(globalThis, "raw")) |maybe_str| { const bunstr = maybe_str.toBunString(globalThis); defer bunstr.deref(); - if (!try appendBunStr(bunstr, out_script, false)) { + if (!try builder.appendBunStr(bunstr, false)) { globalThis.throw("Shell script string contains invalid UTF-16", .{}); return false; } @@ -2571,7 +2742,7 @@ pub fn handleTemplateValue( } if (template_value.isPrimitive()) { - if (!try appendJSValueStr(globalThis, template_value, out_script, true)) { + if (!try builder.appendJSValueStr(template_value, true)) { globalThis.throw("Shell script string contains invalid UTF-16", .{}); return false; } @@ -2579,7 +2750,7 @@ pub fn handleTemplateValue( } if (template_value.implementsToString(globalThis)) { - if (!try appendJSValueStr(globalThis, template_value, out_script, true)) { + if (!try builder.appendJSValueStr(template_value, true)) { globalThis.throw("Shell script string contains invalid UTF-16", .{}); return false; } @@ -2593,57 +2764,127 @@ pub fn handleTemplateValue( return true; } -/// This will disallow invalid surrogate pairs -pub fn appendJSValueStr(globalThis: *JSC.JSGlobalObject, jsval: JSValue, outbuf: *std.ArrayList(u8), comptime allow_escape: bool) !bool { - const bunstr = jsval.toBunString(globalThis); - defer bunstr.deref(); +pub const ShellSrcBuilder = struct { + globalThis: *JSC.JSGlobalObject, + outbuf: *std.ArrayList(u8), + jsstrs_to_escape: *std.ArrayList(bun.String), + jsstr_ref_buf: [128]u8 = [_]u8{0} ** 128, - return try appendBunStr(bunstr, outbuf, allow_escape); -} - -pub fn appendUTF8Text(slice: []const u8, outbuf: *std.ArrayList(u8), comptime allow_escape: bool) !bool { - if (!bun.simdutf.validate.utf8(slice)) { - return false; + pub fn init( + globalThis: *JSC.JSGlobalObject, + outbuf: *std.ArrayList(u8), + jsstrs_to_escape: *std.ArrayList(bun.String), + ) ShellSrcBuilder { + return .{ + .globalThis = globalThis, + .outbuf = outbuf, + .jsstrs_to_escape = jsstrs_to_escape, + }; } - if (allow_escape and needsEscape(slice)) { - try escape(slice, outbuf); - } else { - try outbuf.appendSlice(slice); + pub fn appendJSValueStr(this: *ShellSrcBuilder, jsval: JSValue, comptime allow_escape: bool) !bool { + const bunstr = jsval.toBunString(this.globalThis); + defer bunstr.deref(); + + return try this.appendBunStr(bunstr, allow_escape); } - return true; -} - -pub fn appendBunStr(bunstr: bun.String, outbuf: *std.ArrayList(u8), comptime allow_escape: bool) !bool { - const str = bunstr.toUTF8WithoutRef(bun.default_allocator); - defer str.deinit(); - - // TODO: toUTF8 already validates. We shouldn't have to do this twice! - const is_ascii = str.isAllocated(); - if (!is_ascii and !bun.simdutf.validate.utf8(str.slice())) { - return false; + pub fn appendBunStr( + this: *ShellSrcBuilder, + bunstr: bun.String, + comptime allow_escape: bool, + ) !bool { + const invalid = (bunstr.isUTF16() and !bun.simdutf.validate.utf16le(bunstr.utf16())) or (bunstr.isUTF8() and !bun.simdutf.validate.utf8(bunstr.byteSlice())); + if (invalid) return false; + if (allow_escape) { + if (needsEscapeBunstr(bunstr)) { + try this.appendJSStrRef(bunstr); + return true; + } + } + if (bunstr.isUTF16()) { + try this.appendUTF16Impl(bunstr.utf16()); + return true; + } + if (bunstr.isUTF8() or bun.strings.isAllASCII(bunstr.byteSlice())) { + try this.appendUTF8Impl(bunstr.byteSlice()); + return true; + } + try this.appendLatin1Impl(bunstr.byteSlice()); + return true; } - if (allow_escape and needsEscape(str.slice())) { - try escape(str.slice(), outbuf); - } else { - try outbuf.appendSlice(str.slice()); + pub fn appendUTF8(this: *ShellSrcBuilder, utf8: []const u8, comptime allow_escape: bool) !bool { + const invalid = bun.simdutf.validate.utf8(utf8); + if (!invalid) return false; + if (allow_escape) { + if (needsEscapeUtf8AsciiLatin1(utf8)) { + const bunstr = bun.String.createUTF8(utf8); + defer bunstr.deref(); + try this.appendJSStrRef(bunstr); + return true; + } + } + + try this.appendUTF8Impl(utf8); + return true; } - return true; -} + pub fn appendUTF16Impl(this: *ShellSrcBuilder, utf16: []const u16) !void { + const size = bun.simdutf.simdutf__utf8_length_from_utf16le(utf16.ptr, utf16.len); + try this.outbuf.ensureUnusedCapacity(size); + try bun.strings.convertUTF16ToUTF8Append(this.outbuf, utf16); + } + + pub fn appendUTF8Impl(this: *ShellSrcBuilder, utf8: []const u8) !void { + try this.outbuf.appendSlice(utf8); + } + + pub fn appendLatin1Impl(this: *ShellSrcBuilder, latin1: []const u8) !void { + const non_ascii_idx = bun.strings.firstNonASCII(latin1) orelse 0; + + if (non_ascii_idx > 0) { + try this.appendUTF8Impl(latin1[0..non_ascii_idx]); + } + + this.outbuf.* = try bun.strings.allocateLatin1IntoUTF8WithList(this.outbuf.*, this.outbuf.items.len, []const u8, latin1); + } + + pub fn appendJSStrRef(this: *ShellSrcBuilder, bunstr: bun.String) !void { + const idx = this.jsstrs_to_escape.items.len; + const str = std.fmt.bufPrint(this.jsstr_ref_buf[0..], "{s}{d}", .{ LEX_JS_STRING_PREFIX, idx }) catch { + @panic("Impossible"); + }; + try this.outbuf.appendSlice(str); + bunstr.ref(); + try this.jsstrs_to_escape.append(bunstr); + } +}; /// Characters that need to escaped -const SPECIAL_CHARS = [_]u8{ '$', '>', '&', '|', '=', ';', '\n', '{', '}', ',', '(', ')', '\\', '\"', ' ' }; +const SPECIAL_CHARS = [_]u8{ '$', '>', '&', '|', '=', ';', '\n', '{', '}', ',', '(', ')', '\\', '\"', ' ', '\'' }; /// Characters that need to be backslashed inside double quotes const BACKSLASHABLE_CHARS = [_]u8{ '$', '`', '"', '\\' }; -/// assumes WTF-8 -pub fn escape(str: []const u8, outbuf: *std.ArrayList(u8)) !void { +pub fn escapeBunStr(bunstr: bun.String, outbuf: *std.ArrayList(u8), comptime add_quotes: bool) !bool { + // latin-1 or ascii + if (bunstr.is8Bit()) { + try escape8Bit(bunstr.byteSlice(), outbuf, add_quotes); + return true; + } + if (bunstr.isUTF16()) { + return try escapeUtf16(bunstr.utf16(), outbuf, add_quotes); + } + // Otherwise is utf-8 + try escapeWTF8(bunstr.byteSlice(), outbuf, add_quotes); + return true; +} + +/// works for latin-1 and ascii +pub fn escape8Bit(str: []const u8, outbuf: *std.ArrayList(u8), comptime add_quotes: bool) !void { try outbuf.ensureUnusedCapacity(str.len); - try outbuf.append('\"'); + if (add_quotes) try outbuf.append('\"'); loop: for (str) |c| { inline for (BACKSLASHABLE_CHARS) |spc| { @@ -2658,15 +2899,15 @@ pub fn escape(str: []const u8, outbuf: *std.ArrayList(u8)) !void { try outbuf.append(c); } - try outbuf.append('\"'); + if (add_quotes) try outbuf.append('\"'); } -pub fn escapeUnicode(str: []const u8, outbuf: *std.ArrayList(u8)) !void { +pub fn escapeWTF8(str: []const u8, outbuf: *std.ArrayList(u8), comptime add_quotes: bool) !void { try outbuf.ensureUnusedCapacity(str.len); var bytes: [8]u8 = undefined; - var n = bun.strings.encodeWTF8Rune(bytes[0..4], '"'); - try outbuf.appendSlice(bytes[0..n]); + var n: u3 = if (add_quotes) bun.strings.encodeWTF8Rune(bytes[0..4], '"') else 0; + if (add_quotes) try outbuf.appendSlice(bytes[0..n]); loop: for (str) |c| { inline for (BACKSLASHABLE_CHARS) |spc| { @@ -2686,18 +2927,84 @@ pub fn escapeUnicode(str: []const u8, outbuf: *std.ArrayList(u8)) !void { try outbuf.appendSlice(bytes[0..n]); } - n = bun.strings.encodeWTF8Rune(bytes[0..4], '"'); - try outbuf.appendSlice(bytes[0..n]); + if (add_quotes) { + n = bun.strings.encodeWTF8Rune(bytes[0..4], '"'); + try outbuf.appendSlice(bytes[0..n]); + } +} + +pub fn escapeUtf16(str: []const u16, outbuf: *std.ArrayList(u8), comptime add_quotes: bool) !bool { + if (add_quotes) try outbuf.append('"'); + + const non_ascii = bun.strings.firstNonASCII16([]const u16, str) orelse 0; + var cp_buf: [4]u8 = undefined; + + var i: usize = 0; + loop: while (i < str.len) { + const char: u32 = brk: { + if (i < non_ascii) { + i += 1; + break :brk str[i]; + } + const ret = bun.strings.utf16Codepoint([]const u16, str[i..]); + if (ret.fail) return false; + i += ret.len; + break :brk ret.code_point; + }; + + inline for (BACKSLASHABLE_CHARS) |bchar| { + if (@as(u32, @intCast(bchar)) == char) { + try outbuf.appendSlice(&[_]u8{ '\\', @intCast(char) }); + continue :loop; + } + } + + const len = bun.strings.encodeWTF8RuneT(&cp_buf, u32, char); + try outbuf.appendSlice(cp_buf[0..len]); + } + if (add_quotes) try outbuf.append('"'); + return true; +} + +pub fn needsEscapeBunstr(bunstr: bun.String) bool { + if (bunstr.isUTF16()) return needsEscapeUTF16(bunstr.utf16()); + // Otherwise is utf-8, ascii, or latin-1 + return needsEscapeUtf8AsciiLatin1(bunstr.byteSlice()); +} + +pub fn needsEscapeUTF16Slow(str: []const u16) bool { + for (str) |codeunit| { + inline for (SPECIAL_CHARS) |spc| { + if (@as(u16, @intCast(spc)) == codeunit) return true; + } + } + + return false; } pub fn needsEscapeUTF16(str: []const u16) bool { - for (str) |char| { - switch (char) { - '$', '>', '&', '|', '=', ';', '\n', '{', '}', ',', '(', ')', '\\', '\"', ' ' => return true, - else => {}, + if (str.len < 64) return needsEscapeUTF16Slow(str); + + const needles = comptime brk: { + var needles: [SPECIAL_CHARS.len]@Vector(8, u16) = undefined; + for (SPECIAL_CHARS, 0..) |c, i| { + needles[i] = @splat(@as(u16, @intCast(c))); + } + break :brk needles; + }; + + var i: usize = 0; + while (i + 8 <= str.len) : (i += 8) { + const haystack: @Vector(8, u16) = str[i..][0..8].*; + + inline for (needles) |needle| { + const result = haystack == needle; + if (std.simd.firstTrue(result) != null) return true; } } + if (i < str.len) return needsEscapeUTF16Slow(str[i..]); + return false; } @@ -2705,8 +3012,8 @@ pub fn needsEscapeUTF16(str: []const u16) bool { /// indicates the *possibility* that the string must be escaped, so it can have /// false positives, but it is faster than running the shell lexer through the /// input string for a more correct implementation. -pub fn needsEscape(str: []const u8) bool { - if (str.len < 128) return needsEscapeSlow(str); +pub fn needsEscapeUtf8AsciiLatin1(str: []const u8) bool { + if (str.len < 128) return needsEscapeUtf8AsciiLatin1Slow(str); const needles = comptime brk: { var needles: [SPECIAL_CHARS.len]@Vector(16, u8) = undefined; @@ -2726,12 +3033,12 @@ pub fn needsEscape(str: []const u8) bool { } } - if (i < str.len) return needsEscapeSlow(str[i..]); + if (i < str.len) return needsEscapeUtf8AsciiLatin1Slow(str[i..]); return false; } -pub fn needsEscapeSlow(str: []const u8) bool { +pub fn needsEscapeUtf8AsciiLatin1Slow(str: []const u8) bool { for (str) |c| { inline for (SPECIAL_CHARS) |spc| { if (spc == c) return true; diff --git a/src/string_immutable.zig b/src/string_immutable.zig index 7c4d80a2f6..59fffedfde 100644 --- a/src/string_immutable.zig +++ b/src/string_immutable.zig @@ -1884,7 +1884,7 @@ pub fn convertUTF16ToUTF8Append(list: *std.ArrayList(u8), utf16: []const u16) !v return; } - list.items.len = result.count; + list.items.len += result.count; } pub fn toUTF8AllocWithType(allocator: std.mem.Allocator, comptime Type: type, utf16: Type) ![]u8 { diff --git a/test/js/bun/shell/bunshell.test.ts b/test/js/bun/shell/bunshell.test.ts index 55fdd78142..8bfb96d303 100644 --- a/test/js/bun/shell/bunshell.test.ts +++ b/test/js/bun/shell/bunshell.test.ts @@ -69,6 +69,37 @@ describe("bunshell", () => { `"hello" "lol" "nice"lkasjf;jdfla<>SKDJFLKSF`, `"\\"hello\\" \\"lol\\" \\"nice\\"lkasjf;jdfla<>SKDJFLKSF"`, ); + + test("wrapped in quotes", async () => { + const url = "http://www.example.com?candy_name=M&M"; + await TestBuilder.command`echo url="${url}"`.stdout(`url=${url}\n`).run(); + await TestBuilder.command`echo url='${url}'`.stdout(`url=${url}\n`).run(); + await TestBuilder.command`echo url=${url}`.stdout(`url=${url}\n`).run(); + }); + + test("escape var", async () => { + const shellvar = "$FOO"; + await TestBuilder.command`FOO=bar && echo "${shellvar}"`.stdout(`$FOO\n`).run(); + await TestBuilder.command`FOO=bar && echo '${shellvar}'`.stdout(`$FOO\n`).run(); + await TestBuilder.command`FOO=bar && echo ${shellvar}`.stdout(`$FOO\n`).run(); + }); + + test("can't escape a js string/obj ref", async () => { + const shellvar = "$FOO"; + await TestBuilder.command`FOO=bar && echo \\${shellvar}`.stdout(`$FOO\n`).run(); + const buf = new Uint8Array(1); + await TestBuilder.command`echo hi > \\${buf}`.run(); + }); + + test("in command position", async () => { + const x = "echo hi"; + await TestBuilder.command`${x}`.exitCode(1).stderr("bun: command not found: echo hi\n").run(); + }); + + test("arrays", async () => { + const x = ["echo", "hi"]; + await TestBuilder.command`${x}`.stdout("hi\n").run(); + }); }); describe("quiet", async () => {