From 7173593a807dda3d795a09ff8e644fdb093cdebe Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Wed, 15 Jan 2025 01:02:31 -0800 Subject: [PATCH] Fix error handling bugs in HTMLRewriter API (#16368) Co-authored-by: Ciro Spaciari --- src/bun.js/api/html_rewriter.zig | 271 ++++++++++++++++++-------- test/js/workerd/html-rewriter.test.js | 44 +++++ 2 files changed, 235 insertions(+), 80 deletions(-) diff --git a/src/bun.js/api/html_rewriter.zig b/src/bun.js/api/html_rewriter.zig index 65ec201080..ced09dd55e 100644 --- a/src/bun.js/api/html_rewriter.zig +++ b/src/bun.js/api/html_rewriter.zig @@ -68,7 +68,7 @@ pub const HTMLRewriter = struct { const selector_slice = std.fmt.allocPrint(bun.default_allocator, "{}", .{selector_name}) catch bun.outOfMemory(); var selector = LOLHTML.HTMLSelector.parse(selector_slice) catch - return throwLOLHTMLError(global); + return createLOLHTMLError(global); const handler_ = try ElementHandler.init(global, listener); const handler = getAllocator(global).create(ElementHandler) catch bun.outOfMemory(); handler.* = handler_; @@ -98,7 +98,7 @@ pub const HTMLRewriter = struct { null, ) catch { selector.deinit(); - return throwLOLHTMLError(global); + return createLOLHTMLError(global); }; this.context.selectors.append(bun.default_allocator, selector) catch bun.outOfMemory(); @@ -152,7 +152,7 @@ pub const HTMLRewriter = struct { return callFrame.this(); } - pub fn finalize(this: *HTMLRewriter) callconv(.C) void { + pub fn finalize(this: *HTMLRewriter) void { this.finalizeWithoutDestroy(); bun.default_allocator.destroy(this); } @@ -394,15 +394,19 @@ pub const HTMLRewriter = struct { response_value: JSC.Strong = .{}, bodyValueBufferer: ?JSC.WebCore.BodyValueBufferer = null, tmp_sync_error: ?*JSC.JSValue = null, + ref_count: u32 = 1, + pub usingnamespace bun.NewRefCounted(BufferOutputSink, deinit); + // const log = bun.Output.scoped(.BufferOutputSink, false); pub fn init(context: *LOLHTMLContext, global: *JSGlobalObject, original: *Response, builder: *LOLHTML.HTMLRewriter.Builder) JSC.JSValue { - var sink = bun.new(BufferOutputSink, BufferOutputSink{ + var sink = BufferOutputSink.new(.{ .global = global, .bytes = bun.MutableString.initEmpty(bun.default_allocator), .rewriter = null, .context = context, .response = undefined, }); + defer sink.deref(); var result = bun.new(Response, .{ .init = .{ .status_code = 200, @@ -418,8 +422,24 @@ pub const HTMLRewriter = struct { }); sink.response = result; - + var sink_error: JSC.JSValue = .zero; const input_size = original.body.len(); + var vm = global.bunVM(); + + // Since we're still using vm.waitForPromise, we have to also + // override the error rejection handler. That way, we can propagate + // errors to the caller. + var scope = vm.unhandledRejectionScope(); + const prev_unhandled_pending_rejection_to_capture = vm.unhandled_pending_rejection_to_capture; + vm.unhandled_pending_rejection_to_capture = &sink_error; + sink.tmp_sync_error = &sink_error; + vm.onUnhandledRejection = &JSC.VirtualMachine.onQuietUnhandledRejectionHandlerCaptureValue; + defer { + sink_error.ensureStillAlive(); + vm.unhandled_pending_rejection_to_capture = prev_unhandled_pending_rejection_to_capture; + scope.apply(vm); + } + sink.rewriter = builder.build( .UTF8, .{ @@ -435,9 +455,8 @@ pub const HTMLRewriter = struct { BufferOutputSink.write, BufferOutputSink.done, ) catch { - sink.deinit(); result.finalize(); - return throwLOLHTMLError(global); + return createLOLHTMLError(global); }; result.init.method = original.init.method; @@ -454,12 +473,14 @@ pub const HTMLRewriter = struct { sink.response_value.set(global, response_js_value); result.url = original.url.clone(); - var sink_error: JSC.JSValue = .zero; - sink.tmp_sync_error = &sink_error; + const value = original.getBodyValue(); - sink.bodyValueBufferer = JSC.WebCore.BodyValueBufferer.init(sink, onFinishedBuffering, sink.global, bun.default_allocator); + sink.ref(); + sink.bodyValueBufferer = JSC.WebCore.BodyValueBufferer.init(sink, @ptrCast(&onFinishedBuffering), sink.global, bun.default_allocator); response_js_value.ensureStillAlive(); + sink.bodyValueBufferer.?.run(value) catch |buffering_error| { + defer sink.deref(); return switch (buffering_error) { error.StreamAlreadyUsed => { var err = JSC.SystemError{ @@ -482,7 +503,6 @@ pub const HTMLRewriter = struct { if (sink_error != .zero) { sink_error.ensureStillAlive(); sink_error.unprotect(); - defer sink.deinit(); return sink_error; } @@ -491,8 +511,8 @@ pub const HTMLRewriter = struct { return response_js_value; } - pub fn onFinishedBuffering(ctx: *anyopaque, bytes: []const u8, js_err: ?JSC.WebCore.Body.Value.ValueError, is_async: bool) void { - const sink = bun.cast(*BufferOutputSink, ctx); + pub fn onFinishedBuffering(sink: *BufferOutputSink, bytes: []const u8, js_err: ?JSC.WebCore.Body.Value.ValueError, is_async: bool) void { + defer sink.deref(); if (js_err) |err| { if (sink.response.body.value == .Locked and @intFromPtr(sink.response.body.value.Locked.task) == @intFromPtr(sink) and sink.response.body.value.Locked.promise == null) @@ -510,13 +530,13 @@ pub const HTMLRewriter = struct { if (is_async) { sink.response.body.value.toErrorInstance(err.dupe(sink.global), sink.global); } else { - var ret_err = throwLOLHTMLError(sink.global); + var ret_err = createLOLHTMLError(sink.global); ret_err.ensureStillAlive(); ret_err.protect(); sink.tmp_sync_error.?.* = ret_err; } sink.rewriter.?.end() catch {}; - sink.deinit(); + return; } @@ -524,9 +544,7 @@ pub const HTMLRewriter = struct { ret_err.ensureStillAlive(); ret_err.protect(); sink.tmp_sync_error.?.* = ret_err; - } else { - sink.deinit(); - } + } else {} } pub fn runOutputSink( @@ -539,27 +557,22 @@ pub const HTMLRewriter = struct { var response = sink.response; sink.rewriter.?.write(bytes) catch { - sink.deinit(); - if (is_async) { - response.body.value.toErrorInstance(.{ .Message = throwLOLHTMLStringError() }, global); - + response.body.value.toErrorInstance(.{ .Message = createLOLHTMLStringError() }, global); return null; } else { - return throwLOLHTMLError(global); + return createLOLHTMLError(global); } }; sink.rewriter.?.end() catch { if (!is_async) response.finalize(); sink.response = undefined; - sink.deinit(); - if (is_async) { - response.body.value.toErrorInstance(.{ .Message = throwLOLHTMLStringError() }, global); + response.body.value.toErrorInstance(.{ .Message = createLOLHTMLStringError() }, global); return null; } else { - return throwLOLHTMLError(global); + return createLOLHTMLError(global); } }; @@ -595,7 +608,7 @@ pub const HTMLRewriter = struct { this.bytes.append(bytes) catch bun.outOfMemory(); } - pub fn deinit(this: *BufferOutputSink) void { + fn deinit(this: *BufferOutputSink) void { this.bytes.deinit(); if (this.bodyValueBufferer) |*bufferer| { bufferer.deinit(); @@ -607,7 +620,7 @@ pub const HTMLRewriter = struct { rewriter.deinit(); } - bun.destroy(this); + this.destroy(); } }; @@ -649,7 +662,7 @@ pub const HTMLRewriter = struct { // sink.deinit(); // bun.default_allocator.destroy(result); - // return throwLOLHTMLError(global); + // return createLOLHTMLError(global); // }; // result.* = Response{ @@ -866,9 +879,18 @@ fn HandlerCallback( return struct { pub fn callback(this: *HandlerType, value: *LOLHTMLType) bool { JSC.markBinding(@src()); - var zig_element = bun.default_allocator.create(ZigType) catch bun.outOfMemory(); - @field(zig_element, field_name) = value; - defer @field(zig_element, field_name) = null; + + var wrapper = ZigType.init(value); + + // All of these start with a ref_count of 2. + // 1. For this scope. + // 2. For the JS value. + bun.debugAssert(wrapper.ref_count == 2); + + defer { + @field(wrapper, field_name) = null; + wrapper.deref(); + } const result = @field(this, callback_name).?.call( this.global, @@ -876,8 +898,11 @@ fn HandlerCallback( @field(this, "thisObject") else JSValue.zero, - &.{zig_element.toJS(this.global)}, - ) catch |err| this.global.takeException(err); + &.{wrapper.toJS(this.global)}, + ) catch { + // If there's an error, we'll propagate it to the caller. + return true; + }; if (!result.isUndefinedOrNull()) { if (result.isError() or result.isAggregateError(this.global)) { @@ -1006,15 +1031,30 @@ pub const ContentOptions = struct { html: bool = false, }; -fn throwLOLHTMLError(global: *JSGlobalObject) JSValue { - const err = LOLHTML.HTMLString.lastError(); - defer err.deinit(); - return ZigString.fromUTF8(err.slice()).toErrorInstance(global); +fn createLOLHTMLError(global: *JSGlobalObject) JSValue { + // If there was already a pending exception, we want to use that instead. + if (global.tryTakeException()) |err| { + // it's a synchronous error + return err; + } else if (global.bunVM().unhandled_pending_rejection_to_capture) |err_ptr| { + if (err_ptr.* != .zero) { + // it's a promise rejection + const result = err_ptr.*; + err_ptr.* = .zero; + return result; + } + } + + var err = createLOLHTMLStringError(); + const value = err.toErrorInstance(global); + value.put(global, "name", ZigString.init("HTMLRewriterError").toJS(global)); + return value; } -fn throwLOLHTMLStringError() bun.String { +fn createLOLHTMLStringError() bun.String { + // We must clone this string. const err = LOLHTML.HTMLString.lastError(); defer err.deinit(); - return bun.String.fromUTF8(err.slice()); + return bun.String.createUTF8(err.slice()); } fn htmlStringValue(input: LOLHTML.HTMLString, globalObject: *JSGlobalObject) JSValue { @@ -1023,8 +1063,13 @@ fn htmlStringValue(input: LOLHTML.HTMLString, globalObject: *JSGlobalObject) JSV pub const TextChunk = struct { text_chunk: ?*LOLHTML.TextChunk = null, + ref_count: u32 = 1, pub usingnamespace JSC.Codegen.JSTextChunk; + pub usingnamespace bun.NewRefCounted(@This(), deinit); + pub fn init(text_chunk: *LOLHTML.TextChunk) *TextChunk { + return TextChunk.new(.{ .text_chunk = text_chunk, .ref_count = 2 }); + } fn contentHandler(this: *TextChunk, comptime Callback: (fn (*LOLHTML.TextChunk, []const u8, bool) LOLHTML.Error!void), thisObject: JSValue, globalObject: *JSGlobalObject, content: ZigString, contentOptions: ?ContentOptions) JSValue { if (this.text_chunk == null) @@ -1036,7 +1081,7 @@ pub const TextChunk = struct { this.text_chunk.?, content_slice.slice(), contentOptions != null and contentOptions.?.html, - ) catch return throwLOLHTMLError(globalObject); + ) catch return createLOLHTMLError(globalObject); return thisObject; } @@ -1103,20 +1148,34 @@ pub const TextChunk = struct { return JSValue.jsBoolean(this.text_chunk.?.isLastInTextNode()); } - pub fn finalize(this: *TextChunk) callconv(.C) void { + pub fn finalize(this: *TextChunk) void { + this.deref(); + } + + pub fn deinit(this: *TextChunk) void { this.text_chunk = null; - bun.default_allocator.destroy(this); + this.destroy(); } }; pub const DocType = struct { doctype: ?*LOLHTML.DocType = null, + ref_count: u32 = 1, - pub fn finalize(this: *DocType) callconv(.C) void { + pub fn deinit(this: *DocType) void { this.doctype = null; - bun.default_allocator.destroy(this); + this.destroy(); } + pub fn finalize(this: *DocType) void { + this.deref(); + } + + pub fn init(doctype: *LOLHTML.DocType) *DocType { + return DocType.new(.{ .doctype = doctype, .ref_count = 2 }); + } + + pub usingnamespace bun.NewRefCounted(@This(), deinit); pub usingnamespace JSC.Codegen.JSDocType; /// The doctype name. @@ -1181,14 +1240,15 @@ pub const DocType = struct { pub const DocEnd = struct { doc_end: ?*LOLHTML.DocEnd, + ref_count: u32 = 1, - pub fn finalize(this: *DocEnd) callconv(.C) void { - this.doc_end = null; - bun.default_allocator.destroy(this); - } - + pub usingnamespace bun.NewRefCounted(@This(), deinit); pub usingnamespace JSC.Codegen.JSDocEnd; + pub fn init(doc_end: *LOLHTML.DocEnd) *DocEnd { + return DocEnd.new(.{ .doc_end = doc_end, .ref_count = 2 }); + } + fn contentHandler(this: *DocEnd, comptime Callback: (fn (*LOLHTML.DocEnd, []const u8, bool) LOLHTML.Error!void), thisObject: JSValue, globalObject: *JSGlobalObject, content: ZigString, contentOptions: ?ContentOptions) JSValue { if (this.doc_end == null) return JSValue.jsNull(); @@ -1200,7 +1260,7 @@ pub const DocEnd = struct { this.doc_end.?, content_slice.slice(), contentOptions != null and contentOptions.?.html, - ) catch return throwLOLHTMLError(globalObject); + ) catch return createLOLHTMLError(globalObject); return thisObject; } @@ -1216,18 +1276,28 @@ pub const DocEnd = struct { } pub const append = JSC.wrapInstanceMethod(DocEnd, "append_", false); + + pub fn finalize(this: *DocEnd) void { + this.deref(); + } + + pub fn deinit(this: *DocEnd) void { + this.doc_end = null; + this.destroy(); + } }; pub const Comment = struct { comment: ?*LOLHTML.Comment = null, + ref_count: u32 = 1, - pub fn finalize(this: *Comment) callconv(.C) void { - this.comment = null; - bun.default_allocator.destroy(this); - } - + pub usingnamespace bun.NewRefCounted(@This(), deinit); pub usingnamespace JSC.Codegen.JSComment; + pub fn init(comment: *LOLHTML.Comment) *Comment { + return Comment.new(.{ .comment = comment, .ref_count = 2 }); + } + fn contentHandler(this: *Comment, comptime Callback: (fn (*LOLHTML.Comment, []const u8, bool) LOLHTML.Error!void), thisObject: JSValue, globalObject: *JSGlobalObject, content: ZigString, contentOptions: ?ContentOptions) JSValue { if (this.comment == null) return JSValue.jsNull(); @@ -1238,7 +1308,7 @@ pub const Comment = struct { this.comment.?, content_slice.slice(), contentOptions != null and contentOptions.?.html, - ) catch return throwLOLHTMLError(globalObject); + ) catch return createLOLHTMLError(globalObject); return thisObject; } @@ -1307,7 +1377,7 @@ pub const Comment = struct { var text = value.toSlice(global, bun.default_allocator); defer text.deinit(); this.comment.?.setText(text.slice()) catch { - global.throwValue(throwLOLHTMLError(global)) catch {}; + global.throwValue(createLOLHTMLError(global)) catch {}; return false; }; @@ -1322,14 +1392,32 @@ pub const Comment = struct { return JSValue.jsUndefined(); return JSValue.jsBoolean(this.comment.?.isRemoved()); } + + pub fn finalize(this: *Comment) void { + this.deref(); + } + + pub fn deinit(this: *Comment) void { + this.comment = null; + this.destroy(); + } }; pub const EndTag = struct { end_tag: ?*LOLHTML.EndTag, + ref_count: u32 = 1, - pub fn finalize(this: *EndTag) callconv(.C) void { + pub fn init(end_tag: *LOLHTML.EndTag) *EndTag { + return EndTag.new(.{ .end_tag = end_tag, .ref_count = 2 }); + } + + pub fn finalize(this: *EndTag) void { + this.deref(); + } + + pub fn deinit(this: *EndTag) void { this.end_tag = null; - bun.default_allocator.destroy(this); + this.destroy(); } pub const Handler = struct { @@ -1348,6 +1436,7 @@ pub const EndTag = struct { }; pub usingnamespace JSC.Codegen.JSEndTag; + pub usingnamespace bun.NewRefCounted(@This(), deinit); fn contentHandler(this: *EndTag, comptime Callback: (fn (*LOLHTML.EndTag, []const u8, bool) LOLHTML.Error!void), thisObject: JSValue, globalObject: *JSGlobalObject, content: ZigString, contentOptions: ?ContentOptions) JSValue { if (this.end_tag == null) @@ -1360,7 +1449,7 @@ pub const EndTag = struct { this.end_tag.?, content_slice.slice(), contentOptions != null and contentOptions.?.html, - ) catch return throwLOLHTMLError(globalObject); + ) catch return createLOLHTMLError(globalObject); return thisObject; } @@ -1431,7 +1520,7 @@ pub const EndTag = struct { var text = value.toSlice(global, bun.default_allocator); defer text.deinit(); this.end_tag.?.setName(text.slice()) catch { - global.throwValue(throwLOLHTMLError(global)) catch {}; + global.throwValue(createLOLHTMLError(global)) catch {}; return false; }; @@ -1441,16 +1530,31 @@ pub const EndTag = struct { pub const AttributeIterator = struct { iterator: ?*LOLHTML.Attribute.Iterator = null, + ref_count: u32 = 1, - pub fn finalize(this: *AttributeIterator) callconv(.C) void { + pub fn init(iterator: *LOLHTML.Attribute.Iterator) *AttributeIterator { + return AttributeIterator.new(.{ .iterator = iterator, .ref_count = 2 }); + } + + fn detach(this: *AttributeIterator) void { if (this.iterator) |iter| { iter.deinit(); this.iterator = null; } - bun.default_allocator.destroy(this); + } + + pub fn finalize(this: *AttributeIterator) void { + this.detach(); + this.deref(); + } + + pub fn deinit(this: *AttributeIterator) void { + this.detach(); + this.destroy(); } pub usingnamespace JSC.Codegen.JSAttributeIterator; + pub usingnamespace bun.NewRefCounted(@This(), deinit); pub fn next(this: *AttributeIterator, globalObject: *JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { const done_label = JSC.ZigString.static("done"); @@ -1484,12 +1588,22 @@ pub const AttributeIterator = struct { }; pub const Element = struct { element: ?*LOLHTML.Element = null, + ref_count: u32 = 1, pub usingnamespace JSC.Codegen.JSElement; + pub usingnamespace bun.NewRefCounted(@This(), deinit); - pub fn finalize(this: *Element) callconv(.C) void { + pub fn init(element: *LOLHTML.Element) *Element { + return Element.new(.{ .element = element, .ref_count = 2 }); + } + + pub fn finalize(this: *Element) void { + this.deref(); + } + + pub fn deinit(this: *Element) void { this.element = null; - bun.default_allocator.destroy(this); + this.destroy(); } pub fn onEndTag_( @@ -1509,7 +1623,8 @@ pub const Element = struct { this.element.?.onEndTag(EndTag.Handler.onEndTagHandler, end_tag_handler) catch { bun.default_allocator.destroy(end_tag_handler); - return throwLOLHTMLError(globalObject); + const err = createLOLHTMLError(globalObject); + return globalObject.throwValue(err); }; function.protect(); @@ -1540,7 +1655,7 @@ pub const Element = struct { var slice = name.toSlice(bun.default_allocator); defer slice.deinit(); - return JSValue.jsBoolean(this.element.?.hasAttribute(slice.slice()) catch return throwLOLHTMLError(global)); + return JSValue.jsBoolean(this.element.?.hasAttribute(slice.slice()) catch return createLOLHTMLError(global)); } /// Sets an attribute to a provided value, creating the attribute if it does not exist. @@ -1553,7 +1668,7 @@ pub const Element = struct { var value_slice = value_.toSlice(bun.default_allocator); defer value_slice.deinit(); - this.element.?.setAttribute(name_slice.slice(), value_slice.slice()) catch return throwLOLHTMLError(globalObject); + this.element.?.setAttribute(name_slice.slice(), value_slice.slice()) catch return createLOLHTMLError(globalObject); return callFrame.this(); } @@ -1567,7 +1682,7 @@ pub const Element = struct { this.element.?.removeAttribute( name_slice.slice(), - ) catch return throwLOLHTMLError(globalObject); + ) catch return createLOLHTMLError(globalObject); return callFrame.this(); } @@ -1588,7 +1703,7 @@ pub const Element = struct { this.element.?, content_slice.slice(), contentOptions != null and contentOptions.?.html, - ) catch return throwLOLHTMLError(globalObject); + ) catch return createLOLHTMLError(globalObject); return thisObject; } @@ -1716,7 +1831,7 @@ pub const Element = struct { defer text.deinit(); this.element.?.setTagName(text.slice()) catch { - global.throwValue(throwLOLHTMLError(global)) catch {}; + global.throwValue(createLOLHTMLError(global)) catch {}; return false; }; @@ -1768,12 +1883,8 @@ pub const Element = struct { if (this.element == null) return JSValue.jsUndefined(); - const iter = this.element.?.attributes() orelse return throwLOLHTMLError(globalObject); - var attr_iter = bun.default_allocator.create(AttributeIterator) catch bun.outOfMemory(); - attr_iter.* = .{ .iterator = iter }; - var js_attr_iter = attr_iter.toJS(globalObject); - js_attr_iter.protect(); - defer js_attr_iter.unprotect(); - return js_attr_iter; + const iter = this.element.?.attributes() orelse return createLOLHTMLError(globalObject); + var attr_iter = AttributeIterator.new(.{ .iterator = iter, .ref_count = 1 }); + return attr_iter.toJS(globalObject); } }; diff --git a/test/js/workerd/html-rewriter.test.js b/test/js/workerd/html-rewriter.test.js index 2bdd093be4..4e58b43ac5 100644 --- a/test/js/workerd/html-rewriter.test.js +++ b/test/js/workerd/html-rewriter.test.js @@ -15,6 +15,50 @@ var setTimeoutAsync = (fn, delay) => { }; describe("HTMLRewriter", () => { + it("error handling", () => { + expect(() => new HTMLRewriter().transform(Symbol("ok"))).toThrow(); + }); + + it("error inside element handler", () => { + expect(() => + new HTMLRewriter() + .on("div", { + element(element) { + throw new Error("test"); + }, + }) + .transform(new Response("
hello
")), + ).toThrow("test"); + }); + + it("error inside element handler (string)", () => { + expect(() => + new HTMLRewriter() + .on("div", { + element(element) { + throw new Error("test"); + }, + }) + .transform("
hello
"), + ).toThrow("test"); + }); + + it("async error inside element handler", async () => { + try { + await new HTMLRewriter() + .on("div", { + async element(element) { + await Bun.sleep(0); + throw new Error("test"); + }, + }) + .transform(new Response("
hello
")); + expect.unreachable(); + } catch (e) { + expect(e.message).toBe("test"); + } + }); + it("HTMLRewriter: async replacement", async () => { await gcTick(); const res = new HTMLRewriter()