Ciro/fix onclose refactor (#22556)

### What does this PR do?

### How did you verify your code works?
This commit is contained in:
Ciro Spaciari
2025-09-10 16:20:51 -07:00
committed by GitHub
parent 1a5660ba39
commit 477aa56aa4
10 changed files with 139 additions and 141 deletions

View File

@@ -21,8 +21,7 @@ poll_ref: bun.Async.KeepAlive = .{},
globalObject: *jsc.JSGlobalObject,
vm: *jsc.VirtualMachine,
has_pending_activity: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),
js_value: JSValue = .js_undefined,
js_value: jsc.JSRef = jsc.JSRef.empty(),
server_version: bun.ByteList = .{},
connection_id: u32 = 0,
@@ -122,20 +121,14 @@ pub const AuthState = union(enum) {
};
};
pub fn hasPendingActivity(this: *MySQLConnection) bool {
return this.has_pending_activity.load(.acquire);
}
fn updateHasPendingActivity(this: *MySQLConnection) void {
if (this.requests.readableLength() > 0) {
this.has_pending_activity.store(true, .release);
return;
fn updateReferenceType(this: *MySQLConnection) void {
if (this.js_value.isNotEmpty()) {
if (this.requests.readableLength() > 0 or (this.status != .disconnected and this.status != .failed)) {
this.js_value.upgrade(this.globalObject);
return;
}
this.js_value.downgrade();
}
if (this.status != .disconnected and this.status != .failed) {
this.has_pending_activity.store(true, .release);
return;
}
this.has_pending_activity.store(false, .release);
}
fn hasDataToSend(this: *@This()) bool {
@@ -276,19 +269,19 @@ pub fn finalize(this: *MySQLConnection) void {
this.stopTimers();
debug("MySQLConnection finalize", .{});
this.js_value = .zero;
this.js_value.deinit();
this.deref();
}
pub fn doRef(this: *@This(), _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue {
this.poll_ref.ref(this.vm);
this.updateHasPendingActivity();
this.updateReferenceType();
return .js_undefined;
}
pub fn doUnref(this: *@This(), _: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!JSValue {
this.poll_ref.unref(this.vm);
this.updateHasPendingActivity();
this.updateReferenceType();
return .js_undefined;
}
@@ -355,8 +348,10 @@ pub fn stopTimers(this: *@This()) void {
}
pub fn getQueriesArray(this: *const @This()) JSValue {
if (this.js_value == .zero) return .js_undefined;
return js.queriesGetCached(this.js_value) orelse .js_undefined;
if (this.js_value.tryGet()) |value| {
return js.queriesGetCached(value) orelse .js_undefined;
}
return .js_undefined;
}
pub fn failFmt(this: *@This(), error_code: AnyMySQLError.Error, comptime fmt: [:0]const u8, args: anytype) void {
const message = bun.handleOom(std.fmt.allocPrint(bun.default_allocator, fmt, args));
@@ -366,7 +361,7 @@ pub fn failFmt(this: *@This(), error_code: AnyMySQLError.Error, comptime fmt: [:
this.failWithJSValue(err);
}
pub fn failWithJSValue(this: *MySQLConnection, value: JSValue) void {
defer this.updateHasPendingActivity();
defer this.updateReferenceType();
this.stopTimers();
if (this.status == .failed) return;
@@ -437,7 +432,7 @@ fn refAndClose(this: *@This(), js_reason: ?jsc.JSValue) void {
pub fn disconnect(this: *@This()) void {
this.stopTimers();
if (this.status == .connected) {
defer this.updateHasPendingActivity();
defer this.updateReferenceType();
this.status = .disconnected;
this.poll_ref.disable();
@@ -965,7 +960,7 @@ pub fn call(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JS
ptr.poll_ref.ref(vm);
const js_value = ptr.toJS(globalObject);
js_value.ensureStillAlive();
ptr.js_value = js_value;
ptr.js_value.setStrong(js_value, globalObject);
js.onconnectSetCached(js_value, globalObject, on_connect);
js.oncloseSetCached(js_value, globalObject, on_close);
@@ -1028,7 +1023,7 @@ pub fn onOpen(this: *MySQLConnection, socket: Socket) void {
this.ref(); // keep a ref for the socket
}
this.poll_ref.ref(this.vm);
this.updateHasPendingActivity();
this.updateReferenceType();
}
pub fn onHandshake(this: *MySQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
@@ -1166,11 +1161,12 @@ pub fn processPackets(this: *MySQLConnection, comptime Context: type, reader: Ne
// Read packet header
const header = PacketHeader.decode(reader.peek()) orelse return AnyMySQLError.Error.ShortRead;
const header_length = header.length;
const packet_length: usize = header_length + PacketHeader.size;
debug("sequence_id: {d} header: {d}", .{ this.sequence_id, header_length });
// Ensure we have the full packet
reader.ensureCapacity(header_length + PacketHeader.size) catch return AnyMySQLError.Error.ShortRead;
reader.ensureCapacity(packet_length) catch return AnyMySQLError.Error.ShortRead;
// always skip the full packet, we dont care about padding or unreaded bytes
defer reader.setOffsetFromStart(header_length + PacketHeader.size);
defer reader.setOffsetFromStart(packet_length);
reader.skip(PacketHeader.size);
// Update sequence id
@@ -1293,30 +1289,35 @@ fn handleHandshakeDecodePublicKey(this: *MySQLConnection, comptime Context: type
pub fn consumeOnConnectCallback(this: *const @This(), globalObject: *jsc.JSGlobalObject) ?jsc.JSValue {
debug("consumeOnConnectCallback", .{});
if (this.js_value == .zero) {
return null;
if (this.js_value.tryGet()) |value| {
const on_connect = js.onconnectGetCached(value) orelse return null;
debug("consumeOnConnectCallback exists", .{});
js.onconnectSetCached(value, globalObject, .zero);
if (on_connect == .zero) {
return null;
}
return on_connect;
}
const on_connect = js.onconnectGetCached(this.js_value) orelse return null;
debug("consumeOnConnectCallback exists", .{});
js.onconnectSetCached(this.js_value, globalObject, .zero);
return on_connect;
return null;
}
pub fn consumeOnCloseCallback(this: *const @This(), globalObject: *jsc.JSGlobalObject) ?jsc.JSValue {
debug("consumeOnCloseCallback", .{});
if (this.js_value == .zero) {
return null;
if (this.js_value.tryGet()) |value| {
const on_close = js.oncloseGetCached(value) orelse return null;
debug("consumeOnCloseCallback exists", .{});
js.oncloseSetCached(value, globalObject, .zero);
if (on_close == .zero) {
return null;
}
return on_close;
}
const on_close = js.oncloseGetCached(this.js_value) orelse return null;
debug("consumeOnCloseCallback exists", .{});
js.oncloseSetCached(this.js_value, globalObject, .zero);
return on_close;
return null;
}
pub fn setStatus(this: *@This(), status: ConnectionState) void {
if (this.status == status) return;
defer this.updateHasPendingActivity();
defer this.updateReferenceType();
this.status = status;
this.resetConnectionTimeout();
@@ -1326,12 +1327,8 @@ pub fn setStatus(this: *@This(), status: ConnectionState) void {
.connected => {
const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return;
on_connect.ensureStillAlive();
var js_value = this.js_value;
if (js_value == .zero) {
js_value = .js_undefined;
} else {
js_value.ensureStillAlive();
}
var js_value = this.js_value.tryGet() orelse .js_undefined;
js_value.ensureStillAlive();
this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ JSValue.jsNull(), js_value });
this.poll_ref.unref(this.vm);
},
@@ -1340,8 +1337,8 @@ pub fn setStatus(this: *@This(), status: ConnectionState) void {
}
pub fn updateRef(this: *@This()) void {
this.updateHasPendingActivity();
if (this.has_pending_activity.raw) {
this.updateReferenceType();
if (this.js_value == .strong) {
this.poll_ref.ref(this.vm);
} else {
this.poll_ref.unref(this.vm);
@@ -1799,7 +1796,7 @@ fn handleResultSetOK(this: *MySQLConnection, request: *MySQLQuery, statement: *M
request.onResult(
statement.result_count,
this.globalObject,
this.js_value,
this.js_value.tryGet() orelse .js_undefined,
this.flags.is_ready_for_query,
last_insert_id,
affected_rows,
@@ -1908,7 +1905,7 @@ pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: N
var cached_structure: ?CachedStructure = null;
switch (request.flags.result_mode) {
.objects => {
cached_structure = if (this.js_value == .zero) null else statement.structure(this.js_value, this.globalObject);
cached_structure = if (this.js_value.tryGet()) |value| statement.structure(value, this.globalObject) else null;
structure = cached_structure.?.jsValue() orelse .js_undefined;
},
.raw, .values => {
@@ -1918,7 +1915,7 @@ pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: N
defer row.deinit(allocator);
try row.decode(allocator, reader);
const pending_value = MySQLQuery.js.pendingValueGetCached(request.thisValue.get()) orelse .zero;
const pending_value = (if (request.thisValue.tryGet()) |value| MySQLQuery.js.pendingValueGetCached(value) else .js_undefined) orelse .js_undefined;
// Process row data
const row_value = row.toJS(
@@ -1936,8 +1933,10 @@ pub fn handleResultSet(this: *MySQLConnection, comptime Context: type, reader: N
}
statement.result_count += 1;
if (pending_value == .zero) {
MySQLQuery.js.pendingValueSetCached(request.thisValue.get(), this.globalObject, row_value);
if (pending_value.isEmptyOrUndefinedOrNull()) {
if (request.thisValue.tryGet()) |value| {
MySQLQuery.js.pendingValueSetCached(value, this.globalObject, row_value);
}
}
}
},