diff --git a/src/bake/production.zig b/src/bake/production.zig index 37dbb5ef1b..2b20cc3cf1 100644 --- a/src/bake/production.zig +++ b/src/bake/production.zig @@ -44,6 +44,7 @@ pub fn buildCommand(ctx: bun.CLI.Command.Context) !void { vm.event_loop.ensureWaker(); const b = &vm.transpiler; vm.preload = ctx.preloads; + vm.snapshot_serializers = ctx.snapshot_serializers; vm.argv = ctx.passthrough; vm.arena = &arena; vm.allocator = arena.allocator(); diff --git a/src/bun.js/VirtualMachine.zig b/src/bun.js/VirtualMachine.zig index b2dc00d0d2..5c73301216 100644 --- a/src/bun.js/VirtualMachine.zig +++ b/src/bun.js/VirtualMachine.zig @@ -38,7 +38,9 @@ node_fs: ?*bun.api.node.fs.NodeFS = null, timer: bun.api.Timer.All, event_loop_handle: ?*JSC.PlatformEventLoop = null, pending_unref_counter: i32 = 0, -preload: []const []const u8 = &.{}, + preload: []const []const u8 = &.{}, + snapshot_serializers: []const []const u8 = &.{}, + loaded_snapshot_serializers: []JSC.Strong.Optional = &.{}, unhandled_pending_rejection_to_capture: ?*JSValue = null, standalone_module_graph: ?*bun.StandaloneModuleGraph = null, smol: bool = false, @@ -2053,6 +2055,120 @@ fn loadPreloads(this: *VirtualMachine) !?*JSInternalPromise { return null; } +fn loadSnapshotSerializers(this: *VirtualMachine) !void { + if (this.snapshot_serializers.len == 0) { + return; + } + + // Allocate array for strong references + var serializers = try this.allocator.alloc(JSC.Strong.Optional, this.snapshot_serializers.len); + var loaded_count: usize = 0; + + for (this.snapshot_serializers) |serializer_path| { + var result = switch (this.transpiler.resolver.resolveAndAutoInstall( + this.transpiler.fs.top_level_dir, + normalizeSource(serializer_path), + .stmt, + if (this.standalone_module_graph == null) .read_only else .disable, + )) { + .success => |r| r, + .failure => |e| { + this.log.addErrorFmt( + null, + logger.Loc.Empty, + this.allocator, + "{s} resolving snapshot serializer {}", + .{ + @errorName(e), + bun.fmt.formatJSONStringLatin1(serializer_path), + }, + ) catch unreachable; + return e; + }, + .pending, .not_found => { + this.log.addErrorFmt( + null, + logger.Loc.Empty, + this.allocator, + "snapshot serializer not found {}", + .{ + bun.fmt.formatJSONStringLatin1(serializer_path), + }, + ) catch unreachable; + return error.ModuleNotFound; + }, + }; + + var promise = try JSModuleLoader.import(this.global, &String.fromBytes(result.path().?.text)); + + this.pending_internal_promise = promise; + JSValue.fromCell(promise).protect(); + defer JSValue.fromCell(promise).unprotect(); + + if (this.isWatcherEnabled()) { + this.eventLoop().performGC(); + switch (this.pending_internal_promise.?.status(this.global.vm())) { + .pending => { + while (this.pending_internal_promise.?.status(this.global.vm()) == .pending) { + this.eventLoop().tick(); + + if (this.pending_internal_promise.?.status(this.global.vm()) == .pending) { + this.eventLoop().autoTick(); + } + } + }, + else => {}, + } + } else { + this.eventLoop().performGC(); + this.waitForPromise(JSC.AnyPromise{ + .internal = promise, + }); + } + + if (promise.status(this.global.vm()) == .rejected) { + this.log.addErrorFmt( + null, + logger.Loc.Empty, + this.allocator, + "snapshot serializer failed to load {}", + .{ + bun.fmt.formatJSONStringLatin1(serializer_path), + }, + ) catch unreachable; + continue; + } + + // Get the module's exports + const module_result = promise.result(this.global.vm()); + var default_export = module_result.fastGet(this.global, .default) orelse module_result; + + // Check if it's a valid serializer (has test and serialize methods) + if (default_export.isObject()) { + const has_test = default_export.fastGet(this.global, .test) != null; + const has_serialize = default_export.fastGet(this.global, .serialize) != null; + + if (has_test and has_serialize) { + serializers[loaded_count] = JSC.Strong.Optional.create(default_export, this.global); + loaded_count += 1; + } else { + this.log.addErrorFmt( + null, + logger.Loc.Empty, + this.allocator, + "snapshot serializer must export test and serialize methods {}", + .{ + bun.fmt.formatJSONStringLatin1(serializer_path), + }, + ) catch unreachable; + } + } + } + + // Store only the successfully loaded serializers + this.loaded_snapshot_serializers = serializers[0..loaded_count]; +} + pub fn ensureDebugger(this: *VirtualMachine, block_until_connected: bool) !void { if (this.debugger != null) { try JSC.Debugger.create(this, this.global); @@ -2150,6 +2266,9 @@ pub fn reloadEntryPointForTestRunner(this: *VirtualMachine, entry_path: []const return promise; } + + // Load snapshot serializers after preloads + try this.loadSnapshotSerializers(); } const promise = JSModuleLoader.loadAndEvaluateModule(this.global, &String.fromBytes(this.main)) orelse return error.JSError; diff --git a/src/bun.js/test/pretty_format.zig b/src/bun.js/test/pretty_format.zig index b544956606..65edf63f5f 100644 --- a/src/bun.js/test/pretty_format.zig +++ b/src/bun.js/test/pretty_format.zig @@ -1982,11 +1982,113 @@ pub const JestPrettyFormat = struct { } } + pub fn trySnapshotSerializers(this: *JestPrettyFormat.Formatter, comptime Writer: type, writer: Writer, value: JSValue, globalThis: *JSGlobalObject, comptime enable_ansi_colors: bool) bun.JSError!bool { + const vm = globalThis.bunVM(); + for (vm.loaded_snapshot_serializers) |serializer_strong| { + if (serializer_strong.get()) |serializer| { + // Call the test function to check if this serializer should handle this value + const test_function = serializer.fastGet(globalThis, .test) orelse continue; + const test_result = test_function.call(globalThis, serializer, &[_]JSValue{value}) catch continue; + + if (test_result.toBoolean()) { + // This serializer should handle this value, call the serialize function + const serialize_function = serializer.fastGet(globalThis, .serialize) orelse continue; + + // Create printer context + const printer_context = PrinterContext(Writer){ + .formatter = this, + .writer = writer, + .globalThis = globalThis, + .enable_ansi_colors = enable_ansi_colors, + }; + + // Use threadlocal storage to pass the context + printer_context_storage = @ptrCast(&printer_context); + defer printer_context_storage = null; + + // Create a printer function that this serializer can use + const printer = JSC.JSFunction.create(globalThis, "printer", 1, printerCallback, false, false); + + // Call serialize(val, config, indentation, depth, refs, printer) + const config = JSC.JSValue.createEmptyObject(globalThis, 0); + const indentation = JSC.JSValue.jsNumberFromInt32(@as(i32, @intCast(this.indent))); + const depth = JSC.JSValue.jsNumberFromInt32(0); // TODO: track depth + const refs = JSC.JSValue.createEmptyObject(globalThis, 0); // TODO: track refs + + const result = serialize_function.call(globalThis, serializer, &[_]JSValue{ + value, + config, + indentation, + depth, + refs, + printer, + }) catch continue; + + if (result.isString()) { + const str = result.toSlice(globalThis, globalThis.allocator()); + defer str.deinit(); + writer.writeAll(str.slice()) catch {}; + return true; + } + } + } + } + return false; + } + + // Context for the printer callback + fn PrinterContext(comptime Writer: type) type { + return struct { + formatter: *JestPrettyFormat.Formatter, + writer: Writer, + globalThis: *JSGlobalObject, + enable_ansi_colors: bool, + }; + } + + // Threadlocal storage for printer context + threadlocal var printer_context_storage: ?*anyopaque = null; + + // Printer callback function that serializers can use + fn printerCallback(globalThis: *JSGlobalObject, callFrame: *JSC.CallFrame) callconv(JSC.conv) JSValue { + const args = callFrame.arguments(1); + if (args.len < 1) return JSValue.jsUndefined(); + + const value = args.ptr[0]; + + // Get the context from threadlocal storage + const context_ptr = printer_context_storage orelse return JSValue.jsUndefined(); + + // We need to handle this generically since we don't know the Writer type at compile time + // For now, just format the value as a string and return it + var temp_formatter = JestPrettyFormat.Formatter{ + .remaining_values = &[_]JSValue{}, + .globalThis = globalThis, + .quote_strings = true, + .indent = 0, + }; + + const tag = Tag.get(value, globalThis) catch return JSValue.jsUndefined(); + + // Create a string buffer to capture the formatted output + var buffer = std.ArrayList(u8).init(globalThis.allocator()); + defer buffer.deinit(); + + temp_formatter.format(tag, @TypeOf(buffer.writer()), buffer.writer(), value, globalThis, false) catch {}; + + return JSC.ZigString.fromUTF8(buffer.items).toValueGC(globalThis); + } + pub fn format(this: *JestPrettyFormat.Formatter, result: Tag.Result, comptime Writer: type, writer: Writer, value: JSValue, globalThis: *JSGlobalObject, comptime enable_ansi_colors: bool) bun.JSError!void { const prevGlobalThis = this.globalThis; defer this.globalThis = prevGlobalThis; this.globalThis = globalThis; + // Try snapshot serializers first + if (try this.trySnapshotSerializers(Writer, writer, value, globalThis, enable_ansi_colors)) { + return; + } + // This looks incredibly redundant. We make the JestPrettyFormat.Formatter.Tag a // comptime var so we have to repeat it here. The rationale there is // it _should_ limit the stack usage because each version of the diff --git a/src/bun.js/web_worker.zig b/src/bun.js/web_worker.zig index f545a8bc3d..466f696b34 100644 --- a/src/bun.js/web_worker.zig +++ b/src/bun.js/web_worker.zig @@ -441,7 +441,8 @@ fn spin(this: *WebWorker) void { var vm = this.vm.?; assert(this.status.load(.acquire) == .start); this.setStatus(.starting); - vm.preload = this.preloads; + vm.preload = this.preloads; + vm.snapshot_serializers = &.{}; // resolve entrypoint var resolve_error = bun.String.empty; defer resolve_error.deref(); diff --git a/src/bun_js.zig b/src/bun_js.zig index b9ee1001a7..d626c6dd65 100644 --- a/src/bun_js.zig +++ b/src/bun_js.zig @@ -65,6 +65,7 @@ pub const Run = struct { var vm = run.vm; var b = &vm.transpiler; vm.preload = ctx.preloads; + vm.snapshot_serializers = ctx.snapshot_serializers; vm.argv = ctx.passthrough; vm.arena = &run.arena; vm.allocator = arena.allocator(); @@ -204,6 +205,7 @@ pub const Run = struct { var vm = run.vm; var b = &vm.transpiler; vm.preload = ctx.preloads; + vm.snapshot_serializers = ctx.snapshot_serializers; vm.argv = ctx.passthrough; vm.arena = &run.arena; vm.allocator = arena.allocator(); diff --git a/src/bunfig.zig b/src/bunfig.zig index 40cc8ec7a0..1f2c4b0971 100644 --- a/src/bunfig.zig +++ b/src/bunfig.zig @@ -165,6 +165,32 @@ pub const Bunfig = struct { } } + fn loadSnapshotSerializers( + this: *Parser, + allocator: std.mem.Allocator, + expr: js_ast.Expr, + ) !void { + if (expr.asArray()) |array_| { + var array = array_; + var serializers = try std.ArrayList(string).initCapacity(allocator, array.array.items.len); + errdefer serializers.deinit(); + while (array.next()) |item| { + try this.expectString(item); + if (item.data.e_string.len() > 0) + serializers.appendAssumeCapacity(try item.data.e_string.string(allocator)); + } + this.ctx.snapshot_serializers = serializers.items; + } else if (expr.data == .e_string) { + if (expr.data.e_string.len() > 0) { + var serializers = try allocator.alloc(string, 1); + serializers[0] = try expr.data.e_string.string(allocator); + this.ctx.snapshot_serializers = serializers; + } + } else if (expr.data != .e_null) { + try this.addError(expr.loc, "Expected snapshotSerializers to be an array"); + } + } + pub fn parse(this: *Parser, comptime cmd: Command.Tag) !void { bun.analytics.Features.bunfig += 1; @@ -246,6 +272,10 @@ pub const Bunfig = struct { try this.loadPreload(allocator, expr); } + if (test_.get("snapshotSerializers")) |expr| { + try this.loadSnapshotSerializers(allocator, expr); + } + if (test_.get("smol")) |expr| { try this.expect(expr, .e_boolean); this.ctx.runtime_options.smol = expr.data.e_boolean.value; diff --git a/src/cli.zig b/src/cli.zig index 96c3f2d4b1..5fb247ae03 100644 --- a/src/cli.zig +++ b/src/cli.zig @@ -412,6 +412,7 @@ pub const Command = struct { filters: []const []const u8 = &.{}, preloads: []const string = &.{}, + snapshot_serializers: []const string = &.{}, has_loaded_global_config: bool = false, pub const BundlerOptions = struct { diff --git a/src/cli/test_command.zig b/src/cli/test_command.zig index d2e73e588d..56b505ed54 100644 --- a/src/cli/test_command.zig +++ b/src/cli/test_command.zig @@ -1109,6 +1109,7 @@ pub const TestCommand = struct { ); vm.argv = ctx.passthrough; vm.preload = ctx.preloads; + vm.snapshot_serializers = ctx.snapshot_serializers; vm.transpiler.options.rewrite_jest_for_tests = true; vm.transpiler.options.env.behavior = .load_all_without_inlining; diff --git a/test_snapshot_serializers.md b/test_snapshot_serializers.md new file mode 100644 index 0000000000..0b05c83dc1 --- /dev/null +++ b/test_snapshot_serializers.md @@ -0,0 +1,75 @@ +# Custom Snapshot Serializer Support in Bun Test + +This implementation adds support for custom snapshot serializers in bun:test, following Jest's API. + +## Configuration + +Add snapshot serializers to your `bunfig.toml`: + +```toml +[test] +snapshotSerializers = ["./my-serializer.js"] +``` + +## API + +Snapshot serializers should export an object with `test` and `serialize` methods: + +```javascript +// my-serializer.js +module.exports = { + test(val) { + return val && Object.prototype.hasOwnProperty.call(val, 'foo'); + }, + + serialize(val, config, indentation, depth, refs, printer) { + return `Pretty foo: ${printer(val.foo)}`; + } +}; +``` + +Or using ES modules: + +```javascript +// my-serializer.js +export default { + test(val) { + return val && Object.prototype.hasOwnProperty.call(val, 'foo'); + }, + + serialize(val, config, indentation, depth, refs, printer) { + return `Pretty foo: ${printer(val.foo)}`; + } +}; +``` + +## Test Example + +```javascript +// test.js +import { expect, test } from 'bun:test'; + +test('snapshot serializer', () => { + const obj = { foo: 'bar', baz: 123 }; + expect(obj).toMatchSnapshot(); + // Output: Pretty foo: "bar" +}); +``` + +## Implementation Details + +1. **Configuration Parsing**: Added `snapshotSerializers` parsing in `bunfig.zig` +2. **Module Loading**: Added `loadSnapshotSerializers()` function in `VirtualMachine.zig` +3. **Pretty Format Integration**: Added `trySnapshotSerializers()` in `pretty_format.zig` +4. **Strong References**: Loaded serializers are stored as `JSC.Strong.Optional` to prevent garbage collection + +## Files Modified + +- `src/bunfig.zig`: Added configuration parsing +- `src/cli.zig`: Added snapshot_serializers field to context +- `src/bun.js/VirtualMachine.zig`: Added loading and storage of serializers +- `src/bun.js/test/pretty_format.zig`: Added serializer integration +- `src/cli/test_command.zig`: Added serializer assignment +- `src/bun_js.zig`: Added serializer assignment +- `src/bake/production.zig`: Added serializer assignment +- `src/bun.js/web_worker.zig`: Added empty serializer assignment \ No newline at end of file