From 477aa56aa474088dd23d07c3f00379127eece280 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Wed, 10 Sep 2025 16:20:51 -0700 Subject: [PATCH] Ciro/fix onclose refactor (#22556) ### What does this PR do? ### How did you verify your code works? --- src/bun.js/api/server.zig | 9 +- src/bun.js/api/server/RequestContext.zig | 2 +- src/bun.js/api/sql.classes.ts | 3 +- src/bun.js/bindings/JSRef.zig | 43 ++++++++-- src/js/bun/sql.ts | 30 +++---- src/sql/mysql/MySQLConnection.zig | 105 +++++++++++------------ src/sql/mysql/MySQLQuery.zig | 48 ++++------- src/sql/mysql/protocol/AnyMySQLError.zig | 2 + src/sql/postgres/PostgresSQLQuery.zig | 34 +++----- src/sql/shared/ObjectIterator.zig | 4 + 10 files changed, 139 insertions(+), 141 deletions(-) diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index f26f455258..bf75c92b1a 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -627,8 +627,8 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d } pub fn jsValueAssertAlive(server: *ThisServer) jsc.JSValue { - // With JSRef, we can safely access the JS value even after stop() via weak reference - return server.js_value.get(); + bun.assert(server.js_value.isNotEmpty()); + return server.js_value.tryGet().?; } pub fn requestIP(this: *ThisServer, request: *jsc.WebCore.Request) bun.JSError!jsc.JSValue { @@ -1124,7 +1124,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d this.onReloadFromZig(&new_config, globalThis); - return this.js_value.get(); + return this.js_value.tryGet() orelse .js_undefined; } pub fn onFetch(this: *ThisServer, ctx: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { @@ -1539,8 +1539,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d } pub fn stop(this: *ThisServer, abrupt: bool) void { - const current_value = this.js_value.get(); - this.js_value.setWeak(current_value); + this.js_value.downgrade(); if (this.config.allow_hot and this.config.id.len > 0) { if (this.globalThis.bunVM().hotMap()) |hot| { diff --git a/src/bun.js/api/server/RequestContext.zig b/src/bun.js/api/server/RequestContext.zig index e1c0097107..aab5b503f2 100644 --- a/src/bun.js/api/server/RequestContext.zig +++ b/src/bun.js/api/server/RequestContext.zig @@ -1981,7 +1981,7 @@ pub fn NewRequestContext(comptime ssl_enabled: bool, comptime debug_mode: bool, this.flags.has_called_error_handler = true; const result = server.config.onError.call( server.globalThis, - server.js_value.get(), + server.js_value.tryGet() orelse .js_undefined, &.{value}, ) catch |err| server.globalThis.takeException(err); defer result.ensureStillAlive(); diff --git a/src/bun.js/api/sql.classes.ts b/src/bun.js/api/sql.classes.ts index db29a3dc1f..14d7b5bca3 100644 --- a/src/bun.js/api/sql.classes.ts +++ b/src/bun.js/api/sql.classes.ts @@ -9,7 +9,7 @@ for (const type of types) { construct: true, finalize: true, configurable: false, - hasPendingActivity: true, + hasPendingActivity: type === "PostgresSQL", klass: { // escapeString: { // fn: "escapeString", @@ -60,7 +60,6 @@ for (const type of types) { construct: true, finalize: true, configurable: false, - JSType: "0b11101110", klass: {}, proto: { diff --git a/src/bun.js/bindings/JSRef.zig b/src/bun.js/bindings/JSRef.zig index a4f079b98d..46555db798 100644 --- a/src/bun.js/bindings/JSRef.zig +++ b/src/bun.js/bindings/JSRef.zig @@ -8,6 +8,7 @@ pub const JSRef = union(enum) { } pub fn initStrong(value: jsc.JSValue, globalThis: *jsc.JSGlobalObject) @This() { + bun.assert(value != .zero); return .{ .strong = .create(value, globalThis) }; } @@ -15,15 +16,7 @@ pub const JSRef = union(enum) { return .{ .weak = .zero }; } - pub fn get(this: *@This()) jsc.JSValue { - return switch (this.*) { - .weak => this.weak, - .strong => this.strong.get() orelse .zero, - .finalized => .zero, - }; - } - - pub fn tryGet(this: *@This()) ?jsc.JSValue { + pub fn tryGet(this: *const @This()) ?jsc.JSValue { return switch (this.*) { .weak => if (this.weak != .zero) this.weak else null, .strong => this.strong.get(), @@ -44,6 +37,7 @@ pub const JSRef = union(enum) { } pub fn setStrong(this: *@This(), value: jsc.JSValue, globalThis: *jsc.JSGlobalObject) void { + bun.assert(value != .zero); if (this.* == .strong) { this.strong.set(globalThis, value); return; @@ -64,6 +58,37 @@ pub const JSRef = union(enum) { } } + pub fn downgrade(this: *@This()) void { + switch (this.*) { + .weak => {}, + .strong => |*strong| { + const value = strong.get() orelse .zero; + value.ensureStillAlive(); + strong.deinit(); + this.* = .{ .weak = value }; + }, + .finalized => { + bun.debugAssert(false); + }, + } + } + + pub fn isEmpty(this: *const @This()) bool { + return switch (this.*) { + .weak => this.weak == .zero, + .strong => !this.strong.has(), + .finalized => true, + }; + } + + pub fn isNotEmpty(this: *const @This()) bool { + return switch (this.*) { + .weak => this.weak != .zero, + .strong => this.strong.has(), + .finalized => false, + }; + } + pub fn deinit(this: *@This()) void { switch (this.*) { .weak => { diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index db7b0eb871..dc063c0a1a 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -291,7 +291,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.connect = () => { if (state.connectionState & ReservedConnectionState.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } return Promise.$resolve(reserved_sql); }; @@ -322,7 +322,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.beginDistributed = (name: string, fn: TransactionCallback) => { // begin is allowed the difference is that we need to make sure to use the same connection and never release it if (state.connectionState & ReservedConnectionState.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } let callback = fn; @@ -346,7 +346,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } let callback = fn; let options: string | undefined = options_or_fn as unknown as string; @@ -369,7 +369,7 @@ const SQL: typeof Bun.SQL = function SQL( reserved_sql.flush = () => { if (state.connectionState & ReservedConnectionState.closed) { - throw this.connectionClosedError(); + throw pool.connectionClosedError(); } // Use pooled connection's flush if available, otherwise use adapter's flush if (pooledConnection.flush) { @@ -429,7 +429,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } // just release the connection back to the pool state.connectionState |= ReservedConnectionState.closed; @@ -552,7 +552,7 @@ const SQL: typeof Bun.SQL = function SQL( function run_internal_transaction_sql(string) { if (state.connectionState & ReservedConnectionState.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } return unsafeQueryFromTransaction(string, [], pooledConnection, state.queries); } @@ -564,7 +564,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } if ($isArray(strings)) { // detect if is tagged template @@ -593,7 +593,7 @@ const SQL: typeof Bun.SQL = function SQL( transaction_sql.connect = () => { if (state.connectionState & ReservedConnectionState.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } return Promise.$resolve(transaction_sql); @@ -732,7 +732,7 @@ const SQL: typeof Bun.SQL = function SQL( state.connectionState & ReservedConnectionState.closed || !(state.connectionState & ReservedConnectionState.acceptQueries) ) { - throw this.connectionClosedError(); + throw pool.connectionClosedError(); } if ($isCallable(name)) { @@ -816,7 +816,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.reserve = () => { if (pool.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } // Check if adapter supports reserved connections @@ -831,7 +831,7 @@ const SQL: typeof Bun.SQL = function SQL( }; sql.rollbackDistributed = async function (name: string) { if (pool.closed) { - throw this.connectionClosedError(); + throw pool.connectionClosedError(); } if (!pool.getRollbackDistributedSQL) { @@ -844,7 +844,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.commitDistributed = async function (name: string) { if (pool.closed) { - throw this.connectionClosedError(); + throw pool.connectionClosedError(); } if (!pool.getCommitDistributedSQL) { @@ -857,7 +857,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.beginDistributed = (name: string, fn: TransactionCallback) => { if (pool.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } let callback = fn; @@ -876,7 +876,7 @@ const SQL: typeof Bun.SQL = function SQL( sql.begin = (options_or_fn: string | TransactionCallback, fn?: TransactionCallback) => { if (pool.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } let callback = fn; let options: string | undefined = options_or_fn as unknown as string; @@ -896,7 +896,7 @@ const SQL: typeof Bun.SQL = function SQL( }; sql.connect = () => { if (pool.closed) { - return Promise.$reject(this.connectionClosedError()); + return Promise.$reject(pool.connectionClosedError()); } if (pool.isConnected()) { diff --git a/src/sql/mysql/MySQLConnection.zig b/src/sql/mysql/MySQLConnection.zig index 64e98b11b3..bb66decc94 100644 --- a/src/sql/mysql/MySQLConnection.zig +++ b/src/sql/mysql/MySQLConnection.zig @@ -21,8 +21,7 @@ poll_ref: bun.Async.KeepAlive = .{}, globalObject: *jsc.JSGlobalObject, vm: *jsc.VirtualMachine, -has_pending_activity: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), -js_value: JSValue = .js_undefined, +js_value: jsc.JSRef = jsc.JSRef.empty(), server_version: bun.ByteList = .{}, connection_id: u32 = 0, @@ -122,20 +121,14 @@ pub const AuthState = union(enum) { }; }; -pub fn hasPendingActivity(this: *MySQLConnection) bool { - return this.has_pending_activity.load(.acquire); -} - -fn updateHasPendingActivity(this: *MySQLConnection) void { - if (this.requests.readableLength() > 0) { - this.has_pending_activity.store(true, .release); - return; +fn updateReferenceType(this: *MySQLConnection) void { + if (this.js_value.isNotEmpty()) { + if (this.requests.readableLength() > 0 or (this.status != .disconnected and this.status != .failed)) { + this.js_value.upgrade(this.globalObject); + return; + } + this.js_value.downgrade(); } - if (this.status != .disconnected and this.status != .failed) { - this.has_pending_activity.store(true, .release); - return; - } - this.has_pending_activity.store(false, .release); } fn hasDataToSend(this: *@This()) bool { @@ -276,19 +269,19 @@ pub fn finalize(this: *MySQLConnection) void { this.stopTimers(); debug("MySQLConnection finalize", .{}); - this.js_value = .zero; + this.js_value.deinit(); this.deref(); } pub fn doRef(this: *@This(), _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue { this.poll_ref.ref(this.vm); - this.updateHasPendingActivity(); + this.updateReferenceType(); return .js_undefined; } pub fn doUnref(this: *@This(), _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue { this.poll_ref.unref(this.vm); - this.updateHasPendingActivity(); + this.updateReferenceType(); return .js_undefined; } @@ -355,8 +348,10 @@ pub fn stopTimers(this: *@This()) void { } pub fn getQueriesArray(this: *const @This()) JSValue { - if (this.js_value == .zero) return .js_undefined; - return js.queriesGetCached(this.js_value) orelse .js_undefined; + if (this.js_value.tryGet()) |value| { + return js.queriesGetCached(value) orelse .js_undefined; + } + return .js_undefined; } pub fn failFmt(this: *@This(), error_code: AnyMySQLError.Error, comptime fmt: [:0]const u8, args: anytype) void { const message = bun.handleOom(std.fmt.allocPrint(bun.default_allocator, fmt, args)); @@ -366,7 +361,7 @@ pub fn failFmt(this: *@This(), error_code: AnyMySQLError.Error, comptime fmt: [: this.failWithJSValue(err); } pub fn failWithJSValue(this: *MySQLConnection, value: JSValue) void { - defer this.updateHasPendingActivity(); + defer this.updateReferenceType(); this.stopTimers(); if (this.status == .failed) return; @@ -437,7 +432,7 @@ fn refAndClose(this: *@This(), js_reason: ?jsc.JSValue) void { pub fn disconnect(this: *@This()) void { this.stopTimers(); if (this.status == .connected) { - defer this.updateHasPendingActivity(); + defer this.updateReferenceType(); this.status = .disconnected; this.poll_ref.disable(); @@ -965,7 +960,7 @@ pub fn call(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JS ptr.poll_ref.ref(vm); const js_value = ptr.toJS(globalObject); js_value.ensureStillAlive(); - ptr.js_value = js_value; + ptr.js_value.setStrong(js_value, globalObject); js.onconnectSetCached(js_value, globalObject, on_connect); js.oncloseSetCached(js_value, globalObject, on_close); @@ -1028,7 +1023,7 @@ pub fn onOpen(this: *MySQLConnection, socket: Socket) void { this.ref(); // keep a ref for the socket } this.poll_ref.ref(this.vm); - this.updateHasPendingActivity(); + this.updateReferenceType(); } pub fn onHandshake(this: *MySQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void { @@ -1166,11 +1161,12 @@ pub fn processPackets(this: *MySQLConnection, comptime Context: type, reader: Ne // Read packet header const header = PacketHeader.decode(reader.peek()) orelse return AnyMySQLError.Error.ShortRead; const header_length = header.length; + const packet_length: usize = header_length + PacketHeader.size; debug("sequence_id: {d} header: {d}", .{ this.sequence_id, header_length }); // Ensure we have the full packet - reader.ensureCapacity(header_length + PacketHeader.size) catch return AnyMySQLError.Error.ShortRead; + reader.ensureCapacity(packet_length) catch return AnyMySQLError.Error.ShortRead; // always skip the full packet, we dont care about padding or unreaded bytes - defer reader.setOffsetFromStart(header_length + PacketHeader.size); + defer reader.setOffsetFromStart(packet_length); reader.skip(PacketHeader.size); // Update sequence id @@ -1293,30 +1289,35 @@ fn handleHandshakeDecodePublicKey(this: *MySQLConnection, comptime Context: type pub fn consumeOnConnectCallback(this: *const @This(), globalObject: *jsc.JSGlobalObject) ?jsc.JSValue { debug("consumeOnConnectCallback", .{}); - if (this.js_value == .zero) { - return null; + if (this.js_value.tryGet()) |value| { + const on_connect = js.onconnectGetCached(value) orelse return null; + debug("consumeOnConnectCallback exists", .{}); + js.onconnectSetCached(value, globalObject, .zero); + if (on_connect == .zero) { + return null; + } + return on_connect; } - const on_connect = js.onconnectGetCached(this.js_value) orelse return null; - debug("consumeOnConnectCallback exists", .{}); - - js.onconnectSetCached(this.js_value, globalObject, .zero); - return on_connect; + return null; } pub fn consumeOnCloseCallback(this: *const @This(), globalObject: *jsc.JSGlobalObject) ?jsc.JSValue { debug("consumeOnCloseCallback", .{}); - if (this.js_value == .zero) { - return null; + if (this.js_value.tryGet()) |value| { + const on_close = js.oncloseGetCached(value) orelse return null; + debug("consumeOnCloseCallback exists", .{}); + js.oncloseSetCached(value, globalObject, .zero); + if (on_close == .zero) { + return null; + } + return on_close; } - const on_close = js.oncloseGetCached(this.js_value) orelse return null; - debug("consumeOnCloseCallback exists", .{}); - js.oncloseSetCached(this.js_value, globalObject, .zero); - return on_close; + return null; } pub fn setStatus(this: *@This(), status: ConnectionState) void { if (this.status == status) return; - defer this.updateHasPendingActivity(); + defer this.updateReferenceType(); this.status = status; this.resetConnectionTimeout(); @@ -1326,12 +1327,8 @@ pub fn setStatus(this: *@This(), status: ConnectionState) void { .connected => { const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return; on_connect.ensureStillAlive(); - var js_value = this.js_value; - if (js_value == .zero) { - js_value = .js_undefined; - } else { - js_value.ensureStillAlive(); - } + var js_value = this.js_value.tryGet() orelse .js_undefined; + js_value.ensureStillAlive(); this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ JSValue.jsNull(), js_value }); this.poll_ref.unref(this.vm); }, @@ -1340,8 +1337,8 @@ pub fn setStatus(this: *@This(), status: ConnectionState) void { } pub fn updateRef(this: *@This()) void { - this.updateHasPendingActivity(); - if (this.has_pending_activity.raw) { + this.updateReferenceType(); + if (this.js_value == .strong) { this.poll_ref.ref(this.vm); } else { this.poll_ref.unref(this.vm); @@ -1799,7 +1796,7 @@ fn handleResultSetOK(this: *MySQLConnection, request: *MySQLQuery, statement: *M request.onResult( statement.result_count, this.globalObject, - this.js_value, + this.js_value.tryGet() orelse .js_undefined, this.flags.is_ready_for_query, last_insert_id, affected_rows, @@ -1908,7 +1905,7 @@ pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: N var cached_structure: ?CachedStructure = null; switch (request.flags.result_mode) { .objects => { - cached_structure = if (this.js_value == .zero) null else statement.structure(this.js_value, this.globalObject); + cached_structure = if (this.js_value.tryGet()) |value| statement.structure(value, this.globalObject) else null; structure = cached_structure.?.jsValue() orelse .js_undefined; }, .raw, .values => { @@ -1918,7 +1915,7 @@ pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: N defer row.deinit(allocator); try row.decode(allocator, reader); - const pending_value = MySQLQuery.js.pendingValueGetCached(request.thisValue.get()) orelse .zero; + const pending_value = (if (request.thisValue.tryGet()) |value| MySQLQuery.js.pendingValueGetCached(value) else .js_undefined) orelse .js_undefined; // Process row data const row_value = row.toJS( @@ -1936,8 +1933,10 @@ pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: N } statement.result_count += 1; - if (pending_value == .zero) { - MySQLQuery.js.pendingValueSetCached(request.thisValue.get(), this.globalObject, row_value); + if (pending_value.isEmptyOrUndefinedOrNull()) { + if (request.thisValue.tryGet()) |value| { + MySQLQuery.js.pendingValueSetCached(value, this.globalObject, row_value); + } } } }, diff --git a/src/sql/mysql/MySQLQuery.zig b/src/sql/mysql/MySQLQuery.zig index 2fb9a5e840..2388240c9d 100644 --- a/src/sql/mysql/MySQLQuery.zig +++ b/src/sql/mysql/MySQLQuery.zig @@ -1,5 +1,5 @@ const MySQLQuery = @This(); -const RefCount = bun.ptr.ThreadSafeRefCount(@This(), "ref_count", deinit, .{}); +const RefCount = bun.ptr.RefCount(@This(), "ref_count", deinit, .{}); statement: ?*MySQLStatement = null, query: bun.String = bun.String.empty, @@ -42,10 +42,6 @@ pub const Status = enum(u8) { } }; -pub fn hasPendingActivity(this: *@This()) bool { - return this.ref_count.load(.monotonic) > 1; -} - pub fn deinit(this: *@This()) void { this.thisValue.deinit(); if (this.statement) |statement| { @@ -66,11 +62,7 @@ pub fn finalize(this: *@This()) void { this.statement = null; } - if (this.thisValue == .weak) { - // clean up if is a weak reference, if is a strong reference we need to wait until the query is done - // if we are a strong reference, here is probably a bug because GC'd should not happen - this.thisValue.weak = .zero; - } + this.thisValue.deinit(); this.deref(); } @@ -81,12 +73,9 @@ pub fn onWriteFail( queries_array: JSValue, ) void { this.status = .fail; - const thisValue = this.thisValue.get(); + const thisValue = this.thisValue.tryGet() orelse return; defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject, true); - if (thisValue == .zero or targetValue == .zero) { - return; - } + const targetValue = this.getTarget(globalObject, true) orelse return; const instance = AnyMySQLError.mysqlErrorToJS(globalObject, "Failed to bind query", err); @@ -122,9 +111,9 @@ pub fn bindAndExecute(this: *MySQLQuery, writer: anytype, statement: *MySQLState } fn bind(this: *MySQLQuery, execute: *PreparedStatement.Execute, globalObject: *jsc.JSGlobalObject) AnyMySQLError.Error!void { - const thisValue = this.thisValue.get(); - const binding_value = js.bindingGetCached(thisValue) orelse .zero; - const columns_value = js.columnsGetCached(thisValue) orelse .zero; + const thisValue = this.thisValue.tryGet() orelse return error.InvalidState; + const binding_value = js.bindingGetCached(thisValue) orelse .js_undefined; + const columns_value = js.columnsGetCached(thisValue) orelse .js_undefined; var iter = try QueryBindingIterator.init(binding_value, columns_value, globalObject); @@ -165,12 +154,9 @@ pub fn onJSError(this: *@This(), err: jsc.JSValue, globalObject: *jsc.JSGlobalOb this.ref(); defer this.deref(); this.status = .fail; - const thisValue = this.thisValue.get(); + const thisValue = this.thisValue.tryGet() orelse return; defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject, true); - if (thisValue == .zero or targetValue == .zero) { - return; - } + const targetValue = this.getTarget(globalObject, true) orelse return; var vm = jsc.VirtualMachine.get(); const function = vm.rareData().mysql_context.onQueryRejectFn.get().?; @@ -185,9 +171,9 @@ pub fn onJSError(this: *@This(), err: jsc.JSValue, globalObject: *jsc.JSGlobalOb js_error, }); } -pub fn getTarget(this: *@This(), globalObject: *jsc.JSGlobalObject, clean_target: bool) jsc.JSValue { - const thisValue = this.thisValue.tryGet() orelse return .zero; - const target = js.targetGetCached(thisValue) orelse return .zero; +pub fn getTarget(this: *@This(), globalObject: *jsc.JSGlobalObject, clean_target: bool) ?jsc.JSValue { + const thisValue = this.thisValue.tryGet() orelse return null; + const target = js.targetGetCached(thisValue) orelse return null; if (clean_target) { js.targetSetCached(thisValue, globalObject, .zero); } @@ -222,20 +208,18 @@ pub fn onResult(this: *@This(), result_count: u64, globalObject: *jsc.JSGlobalOb this.ref(); defer this.deref(); - const thisValue = this.thisValue.get(); - const targetValue = this.getTarget(globalObject, is_last); if (is_last) { this.status = .success; } else { this.status = .partial_response; } + const thisValue = this.thisValue.tryGet() orelse return; + defer if (is_last) { allowGC(thisValue, globalObject); this.thisValue.deinit(); }; - if (thisValue == .zero or targetValue == .zero) { - return; - } + const targetValue = this.getTarget(globalObject, is_last) orelse return; const vm = jsc.VirtualMachine.get(); const function = vm.rareData().mysql_context.onQueryResolveFn.get().?; @@ -373,7 +357,7 @@ pub fn doRun(this: *MySQLQuery, globalObject: *jsc.JSGlobalObject, callframe: *j return globalObject.throw("connection must be a MySQLConnection", .{}); }; - connection.poll_ref.ref(globalObject.bunVM()); + defer connection.updateRef(); var query = arguments[1]; if (!query.isObject()) { diff --git a/src/sql/mysql/protocol/AnyMySQLError.zig b/src/sql/mysql/protocol/AnyMySQLError.zig index 66e5b6aa26..f93a28f383 100644 --- a/src/sql/mysql/protocol/AnyMySQLError.zig +++ b/src/sql/mysql/protocol/AnyMySQLError.zig @@ -34,6 +34,7 @@ pub const Error = error{ UnexpectedPacket, ShortRead, UnknownError, + InvalidState, }; pub fn mysqlErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8, err: Error) JSValue { @@ -66,6 +67,7 @@ pub fn mysqlErrorToJS(globalObject: *jsc.JSGlobalObject, message: ?[]const u8, e error.FailedToEncryptPassword => "ERR_MYSQL_FAILED_TO_ENCRYPT_PASSWORD", error.InvalidPublicKey => "ERR_MYSQL_INVALID_PUBLIC_KEY", error.UnknownError => "ERR_MYSQL_UNKNOWN_ERROR", + error.InvalidState => "ERR_MYSQL_INVALID_STATE", error.JSError => { return globalObject.takeException(error.JSError); }, diff --git a/src/sql/postgres/PostgresSQLQuery.zig b/src/sql/postgres/PostgresSQLQuery.zig index 9860b0273a..12e3dd4fb3 100644 --- a/src/sql/postgres/PostgresSQLQuery.zig +++ b/src/sql/postgres/PostgresSQLQuery.zig @@ -1,5 +1,5 @@ const PostgresSQLQuery = @This(); -const RefCount = bun.ptr.ThreadSafeRefCount(@This(), "ref_count", deinit, .{}); +const RefCount = bun.ptr.RefCount(@This(), "ref_count", deinit, .{}); statement: ?*PostgresSQLStatement = null, query: bun.String = bun.String.empty, cursor_name: bun.String = bun.String.empty, @@ -23,9 +23,9 @@ flags: packed struct(u8) { pub const ref = RefCount.ref; pub const deref = RefCount.deref; -pub fn getTarget(this: *PostgresSQLQuery, globalObject: *jsc.JSGlobalObject, clean_target: bool) jsc.JSValue { - const thisValue = this.thisValue.tryGet() orelse return .zero; - const target = js.targetGetCached(thisValue) orelse return .zero; +pub fn getTarget(this: *PostgresSQLQuery, globalObject: *jsc.JSGlobalObject, clean_target: bool) ?jsc.JSValue { + const thisValue = this.thisValue.tryGet() orelse return null; + const target = js.targetGetCached(thisValue) orelse return null; if (clean_target) { js.targetSetCached(thisValue, globalObject, .zero); } @@ -51,10 +51,6 @@ pub const Status = enum(u8) { } }; -pub fn hasPendingActivity(this: *@This()) bool { - return this.ref_count.get() > 1; -} - pub fn deinit(this: *@This()) void { this.thisValue.deinit(); if (this.statement) |statement| { @@ -84,12 +80,9 @@ pub fn onWriteFail( this.ref(); defer this.deref(); this.status = .fail; - const thisValue = this.thisValue.get(); + const thisValue = this.thisValue.tryGet() orelse return; defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject, true); - if (thisValue == .zero or targetValue == .zero) { - return; - } + const targetValue = this.getTarget(globalObject, true) orelse return; const vm = jsc.VirtualMachine.get(); const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; @@ -105,12 +98,9 @@ pub fn onJSError(this: *@This(), err: jsc.JSValue, globalObject: *jsc.JSGlobalOb this.ref(); defer this.deref(); this.status = .fail; - const thisValue = this.thisValue.get(); + const thisValue = this.thisValue.tryGet() orelse return; defer this.thisValue.deinit(); - const targetValue = this.getTarget(globalObject, true); - if (thisValue == .zero or targetValue == .zero) { - return; - } + const targetValue = this.getTarget(globalObject, true) orelse return; var vm = jsc.VirtualMachine.get(); const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; @@ -145,21 +135,17 @@ fn consumePendingValue(thisValue: jsc.JSValue, globalObject: *jsc.JSGlobalObject pub fn onResult(this: *@This(), command_tag_str: []const u8, globalObject: *jsc.JSGlobalObject, connection: jsc.JSValue, is_last: bool) void { this.ref(); defer this.deref(); - - const thisValue = this.thisValue.get(); - const targetValue = this.getTarget(globalObject, is_last); if (is_last) { this.status = .success; } else { this.status = .partial_response; } + const thisValue = this.thisValue.tryGet() orelse return; defer if (is_last) { allowGC(thisValue, globalObject); this.thisValue.deinit(); }; - if (thisValue == .zero or targetValue == .zero) { - return; - } + const targetValue = this.getTarget(globalObject, is_last) orelse return; const vm = jsc.VirtualMachine.get(); const function = vm.rareData().postgresql_context.onQueryResolveFn.get().?; diff --git a/src/sql/shared/ObjectIterator.zig b/src/sql/shared/ObjectIterator.zig index 1e3fbb6de5..fb5f7eeb29 100644 --- a/src/sql/shared/ObjectIterator.zig +++ b/src/sql/shared/ObjectIterator.zig @@ -11,6 +11,10 @@ array_length: usize = 0, any_failed: bool = false, pub fn next(this: *ObjectIterator) ?jsc.JSValue { + if (this.array.isEmptyOrUndefinedOrNull() or this.columns.isEmptyOrUndefinedOrNull()) { + this.any_failed = true; + return null; + } if (this.row_i >= this.array_length) { return null; }