This commit is contained in:
Jarred Sumner
2024-11-25 03:23:27 -08:00
parent 42a63a128a
commit 7bbaa9f3df
4 changed files with 179 additions and 67 deletions

View File

@@ -67,6 +67,21 @@ pub const Capabilities = packed struct(u32) {
CLIENT_DEPRECATE_EOF: bool = false,
_padding: u7 = 0,
pub fn format(self: @This(), comptime _: []const u8, _: anytype, writer: anytype) !void {
var first = true;
inline for (comptime std.meta.fieldNames(Capabilities)) |field| {
if (@TypeOf(@field(self, field)) == bool) {
if (@field(self, field)) {
if (!first) {
try writer.writeAll(", ");
}
first = false;
try writer.writeAll(field);
}
}
}
}
pub fn toInt(this: Capabilities) u32 {
return @bitCast(this);
}
@@ -97,16 +112,28 @@ pub const AuthMethod = enum {
caching_sha2_password,
sha256_password,
pub fn fromString(str: []const u8) ?AuthMethod {
if (std.mem.eql(u8, str, "mysql_native_password")) {
return .mysql_native_password;
} else if (std.mem.eql(u8, str, "caching_sha2_password")) {
return .caching_sha2_password;
} else if (std.mem.eql(u8, str, "sha256_password")) {
return .sha256_password;
pub fn scramble(this: AuthMethod, password: []const u8, auth_data: []const u8, buf: *[32]u8) ![]u8 {
const len = scrambleLength(this);
switch (this) {
.mysql_native_password => @memcpy(buf[0..len], try protocol.Auth.mysql_native_password.scramble(password, auth_data)),
.caching_sha2_password => @memcpy(buf[0..len], try protocol.Auth.caching_sha2_password.scramble(password, auth_data)),
.sha256_password => @memcpy(buf[0..len], try protocol.Auth.mysql_native_password.scramble(password, auth_data)),
}
return null;
return buf[0..len];
}
pub fn scrambleLength(this: AuthMethod) usize {
return switch (this) {
.mysql_native_password => 20,
.caching_sha2_password => 32,
.sha256_password => 20,
};
}
const Map = bun.ComptimeEnumMap(AuthMethod);
pub const fromString = Map.get;
};
// MySQL connection status flags
@@ -127,6 +154,21 @@ pub const StatusFlags = packed struct {
SERVER_SESSION_STATE_CHANGED: bool = false,
_padding: u2 = 0,
pub fn format(self: @This(), comptime _: []const u8, _: anytype, writer: anytype) !void {
var first = true;
inline for (comptime std.meta.fieldNames(StatusFlags)) |field| {
if (@TypeOf(@field(self, field)) == bool) {
if (@field(self, field)) {
if (!first) {
try writer.writeAll(", ");
}
first = false;
try writer.writeAll(field);
}
}
}
}
pub fn toInt(this: StatusFlags) u16 {
return @bitCast(this);
}
@@ -195,7 +237,7 @@ pub const MySQLConnection = struct {
is_ready_for_query: bool = false,
server_version: Data = .{ .empty = {} },
server_version: bun.ByteList = .{},
connection_id: u32 = 0,
capabilities: Capabilities = .{},
character_set: u8 = 0,
@@ -278,6 +320,53 @@ pub const MySQLConnection = struct {
this.deref();
}
pub fn doRef(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue {
this.poll_ref.ref(this.globalObject.bunVM());
this.updateHasPendingActivity();
return .undefined;
}
pub fn doUnref(this: *@This(), _: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue {
this.poll_ref.unref(this.globalObject.bunVM());
this.updateHasPendingActivity();
return .undefined;
}
pub fn doFlush(this: *MySQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
_ = callframe;
_ = globalObject;
_ = this;
return .undefined;
}
pub fn createQuery(this: *MySQLConnection, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
_ = callframe;
_ = globalObject;
_ = this;
return .undefined;
}
pub fn getConnected(this: *MySQLConnection, _: *JSC.JSGlobalObject) JSValue {
return JSValue.jsBoolean(this.status == .connected);
}
pub fn doClose(this: *MySQLConnection, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue {
_ = globalObject;
this.disconnect();
this.write_buffer.deinit(bun.default_allocator);
return .undefined;
}
pub fn constructor(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*MySQLConnection {
_ = callframe;
globalObject.ERR_ILLEGAL_CONSTRUCTOR("MySQLConnection cannot be constructed directly", .{}).throw();
return error.JSError;
}
pub fn flushData(this: *MySQLConnection) void {
const chunk = this.write_buffer.remaining();
if (chunk.len == 0) return;
@@ -591,7 +680,7 @@ pub const MySQLConnection = struct {
var requests = this.requests;
defer requests.deinit();
this.requests = .{};
this.requests = Queue.init(bun.default_allocator);
// Clear any pending requests first
for (requests.readableSlice(0)) |request| {
@@ -600,9 +689,11 @@ pub const MySQLConnection = struct {
.error_message = .{ .temporary = "Connection closed" },
}, this.globalObject);
}
this.write_buffer.deinit();
this.read_buffer.deinit();
this.write_buffer.deinit(bun.default_allocator);
this.read_buffer.deinit(bun.default_allocator);
this.statements.deinit(bun.default_allocator);
bun.default_allocator.free(this.auth_data);
this.auth_data = "";
this.tls_config.deinit();
if (this.tls_ctx) |ctx| {
ctx.deinit(true);
@@ -617,7 +708,10 @@ pub const MySQLConnection = struct {
this.poll_ref.ref(this.globalObject.bunVM());
this.updateHasPendingActivity();
this.start();
this.start() catch |err| {
this.fail("Failed to start connection", err);
return;
};
}
pub fn onHandshake(this: *MySQLConnection, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
@@ -760,7 +854,7 @@ pub const MySQLConnection = struct {
pub fn handleHandshake(this: *MySQLConnection, comptime Context: type, reader: protocol.NewReader(Context)) !void {
var handshake = protocol.HandshakeV10{};
try handshake.decode(Context, reader);
try handshake.decode(reader);
defer handshake.deinit();
// Store server info
@@ -770,15 +864,32 @@ pub const MySQLConnection = struct {
this.character_set = handshake.character_set;
this.status_flags = handshake.status_flags;
debug(
\\Handshake
\\ Server Version: {s}
\\ Connection ID: {d}
\\ Character Set: {d}
\\ Capabilities: [ {} ]
\\ Status Flags: [ {} ]
\\
, .{
this.server_version.slice(),
this.connection_id,
this.character_set,
this.capabilities,
this.status_flags,
});
if (this.auth_data.len > 0) {
bun.default_allocator.free(this.auth_data);
this.auth_data = "";
}
// Store auth data
this.auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len);
@memcpy(this.auth_data[0..8], &handshake.auth_plugin_data_part_1);
@memcpy(this.auth_data[8..], handshake.auth_plugin_data_part_2);
const auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len);
@memcpy(auth_data[0..8], &handshake.auth_plugin_data_part_1);
@memcpy(auth_data[8..], handshake.auth_plugin_data_part_2);
this.auth_data = auth_data;
// Get auth plugin
if (handshake.auth_plugin_name.slice().len > 0) {
@@ -802,7 +913,7 @@ pub const MySQLConnection = struct {
switch (first_byte) {
@intFromEnum(protocol.PacketType.OK) => {
var ok = protocol.OKPacket{};
try ok.decode(Context, reader);
try ok.decode(reader);
defer ok.deinit();
this.status = .connected;
@@ -812,7 +923,7 @@ pub const MySQLConnection = struct {
@intFromEnum(protocol.PacketType.ERROR) => {
var err = protocol.ErrorPacket{};
try err.decode(Context, reader);
try err.decode(reader);
defer err.deinit();
this.fail("Authentication failed", error.AuthenticationFailed);
@@ -820,7 +931,7 @@ pub const MySQLConnection = struct {
@intFromEnum(protocol.PacketType.AUTH_SWITCH) => {
var auth_switch = protocol.AuthSwitchRequest{};
try auth_switch.decode(Context, reader);
try auth_switch.decode(reader);
defer auth_switch.deinit();
// Update auth plugin and data
@@ -862,7 +973,7 @@ pub const MySQLConnection = struct {
},
.failed => {
// Statement failed, clean up
if (this.requests.popOrNull()) |req| {
if (this.requests.readItem()) |req| {
req.onError(statement.error_response, this.globalObject);
}
},
@@ -897,12 +1008,6 @@ pub const MySQLConnection = struct {
// Generate auth response based on plugin
if (this.auth_plugin) |plugin| {
switch (plugin) {
.mysql_native_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, this.auth_data)),
.caching_sha2_password => @memcpy(scrambled_buf[0..32], try protocol.Auth.caching_sha2_password.scramble(this.password, this.auth_data)),
.sha256_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, this.auth_data)),
}
response.auth_response = .{
.temporary = switch (plugin) {
.mysql_native_password => scrambled_buf[0..20],
@@ -912,7 +1017,7 @@ pub const MySQLConnection = struct {
};
}
try response.write(Writer, this.writer());
try response.writeInternal(Writer, this.writer());
this.flushData();
}
@@ -922,22 +1027,11 @@ pub const MySQLConnection = struct {
var scrambled_buf: [32]u8 = undefined;
// Generate auth response based on plugin
switch (auth_method) {
.mysql_native_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, plugin_data)),
.caching_sha2_password => @memcpy(scrambled_buf[0..32], try protocol.Auth.caching_sha2_password.scramble(this.password, plugin_data)),
.sha256_password => @memcpy(scrambled_buf[0..20], try protocol.Auth.mysql_native_password.scramble(this.password, plugin_data)),
}
response.auth_response = .{
.temporary = switch (auth_method) {
.mysql_native_password => scrambled_buf[0..20],
.caching_sha2_password => scrambled_buf[0..32],
.sha256_password => scrambled_buf[0..20],
},
.temporary = try auth_method.scramble(this.password, plugin_data, &scrambled_buf),
};
try response.write(Writer, this.writer());
try response.writeInternal(Writer, this.writer());
this.flushData();
}
@@ -1043,7 +1137,7 @@ pub const MySQLConnection = struct {
switch (first_byte) {
@intFromEnum(protocol.PacketType.OK) => {
var ok = protocol.StmtPrepareOKPacket{};
try ok.decode(Context, reader);
try ok.decode(reader);
// Get the current request
const request = this.requests.peekItem(0);
@@ -1058,7 +1152,7 @@ pub const MySQLConnection = struct {
for (params) |*param| {
var column = protocol.ColumnDefinition41{};
defer column.deinit();
try column.decode(Context, reader);
try column.decode(reader);
param.* = column.column_type;
}
@@ -1077,7 +1171,7 @@ pub const MySQLConnection = struct {
}
for (columns) |*column| {
try column.decode(Context, reader);
try column.decode(reader);
consumed += 1;
}
@@ -1090,12 +1184,14 @@ pub const MySQLConnection = struct {
try request.bindAndExecute(this.writer(), statement, this.globalObject);
this.flushData();
}
} else {
debug("Unexpected prepared statement packet", .{});
}
},
@intFromEnum(protocol.PacketType.ERROR) => {
var err = protocol.ErrorPacket{};
try err.decode(Context, reader);
try err.decode(reader);
defer err.deinit();
if (this.requests.readItem()) |request| {
@@ -1121,23 +1217,23 @@ pub const MySQLConnection = struct {
switch (first_byte) {
@intFromEnum(protocol.PacketType.OK) => {
var ok = protocol.OKPacket{};
try ok.decode(Context, reader);
try ok.decode(reader);
defer ok.deinit();
if (this.requests.popOrNull()) |request| {
request.onSuccess(ok.affected_rows, ok.last_insert_id, this.globalObject);
}
this.status_flags = ok.status_flags;
this.is_ready_for_query = true;
if (this.requests.readItem()) |request| {
request.onSuccess(this.globalObject);
}
},
@intFromEnum(protocol.PacketType.ERROR) => {
var err = protocol.ErrorPacket{};
try err.decode(Context, reader);
try err.decode(reader);
defer err.deinit();
if (this.requests.popOrNull()) |request| {
if (this.requests.readItem()) |request| {
request.onError(err, this.globalObject);
}
},
@@ -1145,7 +1241,7 @@ pub const MySQLConnection = struct {
else => {
// This is likely a result set header
var header = protocol.ResultSetHeader{};
try header.decode(Context, reader);
try header.decode(reader);
if (this.requests.readableLength() > 0) {
const request = this.requests.peekItem(0);
@@ -1161,7 +1257,7 @@ pub const MySQLConnection = struct {
}
for (columns) |*column| {
try column.decode(Context, reader);
try column.decode(reader);
columns_read += 1;
}
@@ -1174,7 +1270,7 @@ pub const MySQLConnection = struct {
switch (row_first_byte) {
@intFromEnum(protocol.PacketType.EOF) => {
var eof = protocol.EOFPacket{};
try eof.decode(Context, reader);
try eof.decode(reader);
// Update status flags and finish
this.status_flags = eof.status_flags;
@@ -1187,7 +1283,7 @@ pub const MySQLConnection = struct {
@intFromEnum(protocol.PacketType.ERROR) => {
var err = protocol.ErrorPacket{};
try err.decode(Context, reader);
try err.decode(reader);
defer err.deinit();
this.requests.discard(1);
request.onError(err, this.globalObject);
@@ -1221,6 +1317,8 @@ pub const MySQLConnection = struct {
},
}
}
} else {
debug("Unexpected result set packet", .{});
}
},
}
@@ -1231,7 +1329,7 @@ pub const MySQLConnection = struct {
.statement_id = statement.statement_id,
};
try close.write(Writer, this.writer());
try close.writeInternal(Writer, this.writer());
this.flushData();
}
@@ -1240,7 +1338,7 @@ pub const MySQLConnection = struct {
.statement_id = statement.statement_id,
};
try reset.write(Writer, this.writer());
try reset.writeInternal(Writer, this.writer());
this.flushData();
}
};
@@ -1410,7 +1508,7 @@ pub const MySQLQuery = struct {
};
defer execute.deinit();
try this.bind(&execute, globalObject);
try execute.write(writer);
try execute.writeInternal(writer);
this.status = .written;
}
@@ -1430,7 +1528,14 @@ pub const MySQLQuery = struct {
}
while (iter.next()) |js_value| {
const param = execute.param_types[i];
const value = try Value.fromJS(js_value, globalObject, param, bun.default_allocator);
const value = try Value.fromJS(
js_value,
globalObject,
param,
// TODO: unsigned
false,
bun.default_allocator,
);
params[i] = try value.toData(param);
i += 1;
}
@@ -1541,8 +1646,7 @@ pub const MySQLQuery = struct {
}
var ptr = bun.default_allocator.create(MySQLQuery) catch |err| {
globalThis.throwError(err, "failed to allocate query");
return .zero;
return globalThis.throwError(err, "failed to allocate query");
};
const this_value = ptr.toJS(globalThis);
@@ -1564,11 +1668,17 @@ pub const MySQLQuery = struct {
return this_value;
}
pub fn doDone(this: *@This(), globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) bun.JSError!JSValue {
_ = globalObject;
this.is_done = true;
return .undefined;
}
pub fn doRun(this: *MySQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
var arguments_ = callframe.arguments_old(2);
const arguments = arguments_.slice();
var connection: *MySQLConnection = arguments[0].as(MySQLConnection) orelse {
globalObject.throw("connection must be a PostgresSQLConnection", .{});
globalObject.throw("connection must be a MySQLConnection", .{});
return error.JSError;
};
var query = arguments[1];
@@ -1702,7 +1812,7 @@ pub const Signature = struct {
return Signature{
.name = name.items,
.fields = fields.toOwnedSlice(),
.fields = fields.items,
.query = try bun.default_allocator.dupe(u8, query),
};
}

View File

@@ -141,11 +141,11 @@ pub fn NewWriterWrap(
}
pub fn int4(this: @This(), value: MySQLInt32) !void {
try this.write(std.mem.asBytes(value));
try this.write(&std.mem.toBytes(value));
}
pub fn int8(this: @This(), value: MySQLInt64) !void {
try this.write(std.mem.asBytes(value));
try this.write(&std.mem.toBytes(value));
}
pub fn int1(this: @This(), value: u8) !void {

View File

@@ -644,6 +644,7 @@ pub const Value = union(enum) {
pub fn toJS(this: *const Value, globalObject: *JSC.JSGlobalObject) JSValue {
return switch (this.*) {
.null => JSValue.jsNull(),
.bool => |b| JSValue.jsBoolean(b),
.string => |*str| {
var out = bun.String.createUTF8(str.items);
return out.transferToJS(globalObject);

View File

@@ -295,7 +295,8 @@ pub const PostgresSQLQuery = struct {
pub fn constructor(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!*PostgresSQLQuery {
_ = callframe;
return globalThis.throw2("PostgresSQLQuery cannot be constructed directly", .{});
globalThis.ERR_ILLEGAL_CONSTRUCTOR("PostgresSQLQuery cannot be constructed directly", .{}).throw();
return error.JSError;
}
pub fn estimatedSize(this: *PostgresSQLQuery) usize {