diff --git a/src/sql/mysql.zig b/src/sql/mysql.zig index 5315fa8fe4..f804a410a5 100644 --- a/src/sql/mysql.zig +++ b/src/sql/mysql.zig @@ -67,6 +67,21 @@ pub const Capabilities = packed struct(u32) { CLIENT_DEPRECATE_EOF: bool = false, _padding: u7 = 0, + pub fn format(self: @This(), comptime _: []const u8, _: anytype, writer: anytype) !void { + var first = true; + inline for (comptime std.meta.fieldNames(Capabilities)) |field| { + if (@TypeOf(@field(self, field)) == bool) { + if (@field(self, field)) { + if (!first) { + try writer.writeAll(", "); + } + first = false; + try writer.writeAll(field); + } + } + } + } + pub fn toInt(this: Capabilities) u32 { return @bitCast(this); } @@ -97,16 +112,28 @@ pub const AuthMethod = enum { caching_sha2_password, sha256_password, - pub fn fromString(str: []const u8) ?AuthMethod { - if (std.mem.eql(u8, str, "mysql_native_password")) { - return .mysql_native_password; - } else if (std.mem.eql(u8, str, "caching_sha2_password")) { - return .caching_sha2_password; - } else if (std.mem.eql(u8, str, "sha256_password")) { - return .sha256_password; + pub fn scramble(this: AuthMethod, password: []const u8, auth_data: []const u8, buf: *[32]u8) ![]u8 { + const len = scrambleLength(this); + switch (this) { + .mysql_native_password => @memcpy(buf[0..len], try protocol.Auth.mysql_native_password.scramble(password, auth_data)), + .caching_sha2_password => @memcpy(buf[0..len], try protocol.Auth.caching_sha2_password.scramble(password, auth_data)), + .sha256_password => @memcpy(buf[0..len], try protocol.Auth.mysql_native_password.scramble(password, auth_data)), } - return null; + + return buf[0..len]; } + + pub fn scrambleLength(this: AuthMethod) usize { + return switch (this) { + .mysql_native_password => 20, + .caching_sha2_password => 32, + .sha256_password => 20, + }; + } + + const Map = bun.ComptimeEnumMap(AuthMethod); + + pub const fromString = Map.get; }; // MySQL connection status flags @@ -127,6 +154,21 @@ pub const StatusFlags = packed struct { SERVER_SESSION_STATE_CHANGED: bool = false, _padding: u2 = 0, + pub fn format(self: @This(), comptime _: []const u8, _: anytype, writer: anytype) !void { + var first = true; + inline for (comptime std.meta.fieldNames(StatusFlags)) |field| { + if (@TypeOf(@field(self, field)) == bool) { + if (@field(self, field)) { + if (!first) { + try writer.writeAll(", "); + } + first = false; + try writer.writeAll(field); + } + } + } + } + pub fn toInt(this: StatusFlags) u16 { return @bitCast(this); } @@ -195,7 +237,7 @@ pub const MySQLConnection = struct { is_ready_for_query: bool = false, - server_version: Data = .{ .empty = {} }, + server_version: bun.ByteList = .{}, connection_id: u32 = 0, capabilities: Capabilities = .{}, character_set: u8 = 0, @@ -278,6 +320,53 @@ pub const MySQLConnection = struct { this.deref(); } + pub fn doRef(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + this.poll_ref.ref(this.globalObject.bunVM()); + this.updateHasPendingActivity(); + return .undefined; + } + + pub fn doUnref(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + this.poll_ref.unref(this.globalObject.bunVM()); + this.updateHasPendingActivity(); + return .undefined; + } + + pub fn doFlush(this: *MySQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + _ = callframe; + _ = globalObject; + _ = this; + + return .undefined; + } + + pub fn createQuery(this: *MySQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + _ = callframe; + _ = globalObject; + _ = this; + + return .undefined; + } + + pub fn getConnected(this: *MySQLConnection, _: *JSC.JSGlobalObject) JSValue { + return JSValue.jsBoolean(this.status == .connected); + } + + pub fn doClose(this: *MySQLConnection, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + _ = globalObject; + this.disconnect(); + this.write_buffer.deinit(bun.default_allocator); + + return .undefined; + } + + pub fn constructor(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*MySQLConnection { + _ = callframe; + + globalObject.ERR_ILLEGAL_CONSTRUCTOR("MySQLConnection cannot be constructed directly", .{}).throw(); + return error.JSError; + } + pub fn flushData(this: *MySQLConnection) void { const chunk = this.write_buffer.remaining(); if (chunk.len == 0) return; @@ -591,7 +680,7 @@ pub const MySQLConnection = struct { var requests = this.requests; defer requests.deinit(); - this.requests = .{}; + this.requests = Queue.init(bun.default_allocator); // Clear any pending requests first for (requests.readableSlice(0)) |request| { @@ -600,9 +689,11 @@ pub const MySQLConnection = struct { .error_message = .{ .temporary = "Connection closed" }, }, this.globalObject); } - this.write_buffer.deinit(); - this.read_buffer.deinit(); + this.write_buffer.deinit(bun.default_allocator); + this.read_buffer.deinit(bun.default_allocator); this.statements.deinit(bun.default_allocator); + bun.default_allocator.free(this.auth_data); + this.auth_data = ""; this.tls_config.deinit(); if (this.tls_ctx) |ctx| { ctx.deinit(true); @@ -617,7 +708,10 @@ pub const MySQLConnection = struct { this.poll_ref.ref(this.globalObject.bunVM()); this.updateHasPendingActivity(); - this.start(); + this.start() catch |err| { + this.fail("Failed to start connection", err); + return; + }; } pub fn onHandshake(this: *MySQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void { @@ -760,7 +854,7 @@ pub const MySQLConnection = struct { pub fn handleHandshake(this: *MySQLConnection, comptime Context: type, reader: protocol.NewReader(Context)) !void { var handshake = protocol.HandshakeV10{}; - try handshake.decode(Context, reader); + try handshake.decode(reader); defer handshake.deinit(); // Store server info @@ -770,15 +864,32 @@ pub const MySQLConnection = struct { this.character_set = handshake.character_set; this.status_flags = handshake.status_flags; + debug( + \\Handshake + \\ Server Version: {s} + \\ Connection ID: {d} + \\ Character Set: {d} + \\ Capabilities: [ {} ] + \\ Status Flags: [ {} ] + \\ + , .{ + this.server_version.slice(), + this.connection_id, + this.character_set, + this.capabilities, + this.status_flags, + }); + if (this.auth_data.len > 0) { bun.default_allocator.free(this.auth_data); this.auth_data = ""; } // Store auth data - this.auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len); - @memcpy(this.auth_data[0..8], &handshake.auth_plugin_data_part_1); - @memcpy(this.auth_data[8..], handshake.auth_plugin_data_part_2); + const auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len); + @memcpy(auth_data[0..8], &handshake.auth_plugin_data_part_1); + @memcpy(auth_data[8..], handshake.auth_plugin_data_part_2); + this.auth_data = auth_data; // Get auth plugin if (handshake.auth_plugin_name.slice().len > 0) { @@ -802,7 +913,7 @@ pub const MySQLConnection = struct { switch (first_byte) { @intFromEnum(protocol.PacketType.OK) => { var ok = protocol.OKPacket{}; - try ok.decode(Context, reader); + try ok.decode(reader); defer ok.deinit(); this.status = .connected; @@ -812,7 +923,7 @@ pub const MySQLConnection = struct { @intFromEnum(protocol.PacketType.ERROR) => { var err = protocol.ErrorPacket{}; - try err.decode(Context, reader); + try err.decode(reader); defer err.deinit(); this.fail("Authentication failed", error.AuthenticationFailed); @@ -820,7 +931,7 @@ pub const MySQLConnection = struct { @intFromEnum(protocol.PacketType.AUTH_SWITCH) => { var auth_switch = protocol.AuthSwitchRequest{}; - try auth_switch.decode(Context, reader); + try auth_switch.decode(reader); defer auth_switch.deinit(); // Update auth plugin and data @@ -862,7 +973,7 @@ pub const MySQLConnection = struct { }, .failed => { // Statement failed, clean up - if (this.requests.popOrNull()) |req| { + if (this.requests.readItem()) |req| { req.onError(statement.error_response, this.globalObject); } }, @@ -897,12 +1008,6 @@ pub const MySQLConnection = struct { // Generate auth response based on plugin if (this.auth_plugin) |plugin| { - switch (plugin) { - .mysql_native_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, this.auth_data)), - .caching_sha2_password => @memcpy(scrambled_buf[0..32], try protocol.Auth.caching_sha2_password.scramble(this.password, this.auth_data)), - .sha256_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, this.auth_data)), - } - response.auth_response = .{ .temporary = switch (plugin) { .mysql_native_password => scrambled_buf[0..20], @@ -912,7 +1017,7 @@ pub const MySQLConnection = struct { }; } - try response.write(Writer, this.writer()); + try response.writeInternal(Writer, this.writer()); this.flushData(); } @@ -922,22 +1027,11 @@ pub const MySQLConnection = struct { var scrambled_buf: [32]u8 = undefined; - // Generate auth response based on plugin - switch (auth_method) { - .mysql_native_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, plugin_data)), - .caching_sha2_password => @memcpy(scrambled_buf[0..32], try protocol.Auth.caching_sha2_password.scramble(this.password, plugin_data)), - .sha256_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, plugin_data)), - } - response.auth_response = .{ - .temporary = switch (auth_method) { - .mysql_native_password => scrambled_buf[0..20], - .caching_sha2_password => scrambled_buf[0..32], - .sha256_password => scrambled_buf[0..20], - }, + .temporary = try auth_method.scramble(this.password, plugin_data, &scrambled_buf), }; - try response.write(Writer, this.writer()); + try response.writeInternal(Writer, this.writer()); this.flushData(); } @@ -1043,7 +1137,7 @@ pub const MySQLConnection = struct { switch (first_byte) { @intFromEnum(protocol.PacketType.OK) => { var ok = protocol.StmtPrepareOKPacket{}; - try ok.decode(Context, reader); + try ok.decode(reader); // Get the current request const request = this.requests.peekItem(0); @@ -1058,7 +1152,7 @@ pub const MySQLConnection = struct { for (params) |*param| { var column = protocol.ColumnDefinition41{}; defer column.deinit(); - try column.decode(Context, reader); + try column.decode(reader); param.* = column.column_type; } @@ -1077,7 +1171,7 @@ pub const MySQLConnection = struct { } for (columns) |*column| { - try column.decode(Context, reader); + try column.decode(reader); consumed += 1; } @@ -1090,12 +1184,14 @@ pub const MySQLConnection = struct { try request.bindAndExecute(this.writer(), statement, this.globalObject); this.flushData(); } + } else { + debug("Unexpected prepared statement packet", .{}); } }, @intFromEnum(protocol.PacketType.ERROR) => { var err = protocol.ErrorPacket{}; - try err.decode(Context, reader); + try err.decode(reader); defer err.deinit(); if (this.requests.readItem()) |request| { @@ -1121,23 +1217,23 @@ pub const MySQLConnection = struct { switch (first_byte) { @intFromEnum(protocol.PacketType.OK) => { var ok = protocol.OKPacket{}; - try ok.decode(Context, reader); + try ok.decode(reader); defer ok.deinit(); - if (this.requests.popOrNull()) |request| { - request.onSuccess(ok.affected_rows, ok.last_insert_id, this.globalObject); - } - this.status_flags = ok.status_flags; this.is_ready_for_query = true; + + if (this.requests.readItem()) |request| { + request.onSuccess(this.globalObject); + } }, @intFromEnum(protocol.PacketType.ERROR) => { var err = protocol.ErrorPacket{}; - try err.decode(Context, reader); + try err.decode(reader); defer err.deinit(); - if (this.requests.popOrNull()) |request| { + if (this.requests.readItem()) |request| { request.onError(err, this.globalObject); } }, @@ -1145,7 +1241,7 @@ pub const MySQLConnection = struct { else => { // This is likely a result set header var header = protocol.ResultSetHeader{}; - try header.decode(Context, reader); + try header.decode(reader); if (this.requests.readableLength() > 0) { const request = this.requests.peekItem(0); @@ -1161,7 +1257,7 @@ pub const MySQLConnection = struct { } for (columns) |*column| { - try column.decode(Context, reader); + try column.decode(reader); columns_read += 1; } @@ -1174,7 +1270,7 @@ pub const MySQLConnection = struct { switch (row_first_byte) { @intFromEnum(protocol.PacketType.EOF) => { var eof = protocol.EOFPacket{}; - try eof.decode(Context, reader); + try eof.decode(reader); // Update status flags and finish this.status_flags = eof.status_flags; @@ -1187,7 +1283,7 @@ pub const MySQLConnection = struct { @intFromEnum(protocol.PacketType.ERROR) => { var err = protocol.ErrorPacket{}; - try err.decode(Context, reader); + try err.decode(reader); defer err.deinit(); this.requests.discard(1); request.onError(err, this.globalObject); @@ -1221,6 +1317,8 @@ pub const MySQLConnection = struct { }, } } + } else { + debug("Unexpected result set packet", .{}); } }, } @@ -1231,7 +1329,7 @@ pub const MySQLConnection = struct { .statement_id = statement.statement_id, }; - try close.write(Writer, this.writer()); + try close.writeInternal(Writer, this.writer()); this.flushData(); } @@ -1240,7 +1338,7 @@ pub const MySQLConnection = struct { .statement_id = statement.statement_id, }; - try reset.write(Writer, this.writer()); + try reset.writeInternal(Writer, this.writer()); this.flushData(); } }; @@ -1410,7 +1508,7 @@ pub const MySQLQuery = struct { }; defer execute.deinit(); try this.bind(&execute, globalObject); - try execute.write(writer); + try execute.writeInternal(writer); this.status = .written; } @@ -1430,7 +1528,14 @@ pub const MySQLQuery = struct { } while (iter.next()) |js_value| { const param = execute.param_types[i]; - const value = try Value.fromJS(js_value, globalObject, param, bun.default_allocator); + const value = try Value.fromJS( + js_value, + globalObject, + param, + // TODO: unsigned + false, + bun.default_allocator, + ); params[i] = try value.toData(param); i += 1; } @@ -1541,8 +1646,7 @@ pub const MySQLQuery = struct { } var ptr = bun.default_allocator.create(MySQLQuery) catch |err| { - globalThis.throwError(err, "failed to allocate query"); - return .zero; + return globalThis.throwError(err, "failed to allocate query"); }; const this_value = ptr.toJS(globalThis); @@ -1564,11 +1668,17 @@ pub const MySQLQuery = struct { return this_value; } + pub fn doDone(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue { + _ = globalObject; + this.is_done = true; + return .undefined; + } + pub fn doRun(this: *MySQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { var arguments_ = callframe.arguments_old(2); const arguments = arguments_.slice(); var connection: *MySQLConnection = arguments[0].as(MySQLConnection) orelse { - globalObject.throw("connection must be a PostgresSQLConnection", .{}); + globalObject.throw("connection must be a MySQLConnection", .{}); return error.JSError; }; var query = arguments[1]; @@ -1702,7 +1812,7 @@ pub const Signature = struct { return Signature{ .name = name.items, - .fields = fields.toOwnedSlice(), + .fields = fields.items, .query = try bun.default_allocator.dupe(u8, query), }; } diff --git a/src/sql/mysql/mysql_protocol.zig b/src/sql/mysql/mysql_protocol.zig index 74dbef67c2..dfa1c9f904 100644 --- a/src/sql/mysql/mysql_protocol.zig +++ b/src/sql/mysql/mysql_protocol.zig @@ -141,11 +141,11 @@ pub fn NewWriterWrap( } pub fn int4(this: @This(), value: MySQLInt32) !void { - try this.write(std.mem.asBytes(value)); + try this.write(&std.mem.toBytes(value)); } pub fn int8(this: @This(), value: MySQLInt64) !void { - try this.write(std.mem.asBytes(value)); + try this.write(&std.mem.toBytes(value)); } pub fn int1(this: @This(), value: u8) !void { diff --git a/src/sql/mysql/mysql_types.zig b/src/sql/mysql/mysql_types.zig index 99d22e5085..9f61c432d9 100644 --- a/src/sql/mysql/mysql_types.zig +++ b/src/sql/mysql/mysql_types.zig @@ -644,6 +644,7 @@ pub const Value = union(enum) { pub fn toJS(this: *const Value, globalObject: *JSC.JSGlobalObject) JSValue { return switch (this.*) { .null => JSValue.jsNull(), + .bool => |b| JSValue.jsBoolean(b), .string => |*str| { var out = bun.String.createUTF8(str.items); return out.transferToJS(globalObject); diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 1e5593a3a7..fe43064586 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -295,7 +295,8 @@ pub const PostgresSQLQuery = struct { pub fn constructor(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*PostgresSQLQuery { _ = callframe; - return globalThis.throw2("PostgresSQLQuery cannot be constructed directly", .{}); + globalThis.ERR_ILLEGAL_CONSTRUCTOR("PostgresSQLQuery cannot be constructed directly", .{}).throw(); + return error.JSError; } pub fn estimatedSize(this: *PostgresSQLQuery) usize {