This commit is contained in:
Jarred Sumner
2024-11-25 02:55:52 -08:00
parent d1f3ceebdb
commit 42a63a128a
3 changed files with 396 additions and 62 deletions

View File

@@ -37,7 +37,7 @@ pub const SSLMode = enum(u8) {
verify_ca = 3,
verify_full = 4,
};
const Data = sql.Data;
pub const Data = sql.Data;
// MySQL capability flags
pub const Capabilities = packed struct(u32) {
CLIENT_LONG_PASSWORD: bool = false,
@@ -212,9 +212,12 @@ pub const MySQLConnection = struct {
on_connect: JSC.Strong = .{},
on_close: JSC.Strong = .{},
auth_data: []const u8 = "",
database: []const u8 = "",
user: []const u8 = "",
password: []const u8 = "",
options: []const u8 = "",
options_buf: []const u8 = "",
pub const AuthState = union(enum) {
pending: void,
@@ -314,8 +317,8 @@ pub const MySQLConnection = struct {
this.fail("Connection closed", error.ConnectionClosed);
}
fn start(this: *MySQLConnection) void {
this.sendHandshakeResponse();
fn start(this: *MySQLConnection) !void {
try this.sendHandshakeResponse();
const event_loop = this.globalObject.bunVM().eventLoop();
event_loop.enter();
@@ -359,27 +362,252 @@ pub const MySQLConnection = struct {
}
}
pub fn deinit(this: *@This()) void {
const Queue = std.fifo.LinearFifo(*MySQLQuery, .Dynamic);
fn SocketHandler(comptime ssl: bool) type {
return struct {
const SocketType = uws.NewSocketHandler(ssl);
fn _socket(s: SocketType) Socket {
if (comptime ssl) {
return Socket{ .SocketTLS = s };
}
return Socket{ .SocketTCP = s };
}
pub fn onOpen(this: *MySQLConnection, socket: SocketType) void {
this.onOpen(_socket(socket));
}
fn onHandshake_(this: *MySQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void {
this.onHandshake(success, ssl_error);
}
pub const onHandshake = if (ssl) onHandshake_ else null;
pub fn onClose(this: *MySQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void {
_ = socket;
this.onClose();
}
pub fn onEnd(this: *MySQLConnection, socket: SocketType) void {
_ = socket;
this.onClose();
}
pub fn onConnectError(this: *MySQLConnection, socket: SocketType, _: i32) void {
_ = socket;
this.onClose();
}
pub fn onTimeout(this: *MySQLConnection, socket: SocketType) void {
_ = socket;
this.onTimeout();
}
pub fn onData(this: *MySQLConnection, socket: SocketType, data: []const u8) void {
_ = socket;
this.onData(data);
}
pub fn onWritable(this: *MySQLConnection, socket: SocketType) void {
_ = socket;
this.onDrain();
}
};
}
pub fn onTimeout(this: *MySQLConnection) void {
this.fail("Connection timed out", error.ConnectionTimedOut);
}
pub fn onDrain(this: *MySQLConnection) void {
const event_loop = this.globalObject.bunVM().eventLoop();
event_loop.enter();
defer event_loop.exit();
this.flushData();
}
pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue {
var vm = globalObject.bunVM();
const arguments = callframe.arguments_old(10).slice();
const hostname_str = arguments[0].toBunString(globalObject);
defer hostname_str.deref();
const port = arguments[1].coerce(i32, globalObject);
const username_str = arguments[2].toBunString(globalObject);
defer username_str.deref();
const password_str = arguments[3].toBunString(globalObject);
defer password_str.deref();
const database_str = arguments[4].toBunString(globalObject);
defer database_str.deref();
// TODO: update this to match MySQL.
const ssl_mode: SSLMode = switch (arguments[5].toInt32()) {
0 => .disable,
1 => .prefer,
2 => .require,
3 => .verify_ca,
4 => .verify_full,
else => .disable,
};
const tls_object = arguments[6];
var tls_config: JSC.API.ServerConfig.SSLConfig = .{};
var tls_ctx: ?*uws.SocketContext = null;
if (ssl_mode != .disable) {
tls_config = if (tls_object.isBoolean() and tls_object.toBoolean())
.{}
else if (tls_object.isObject())
(JSC.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{}
else {
return globalObject.throwInvalidArguments("tls must be a boolean or an object", .{});
};
if (globalObject.hasException()) {
tls_config.deinit();
return .zero;
}
if (tls_config.reject_unauthorized != 0)
tls_config.request_cert = 1;
// We create it right here so we can throw errors early.
const context_options = tls_config.asUSockets();
var err: uws.create_bun_socket_error_t = .none;
tls_ctx = uws.us_create_bun_socket_context(1, vm.uwsLoop(), @sizeOf(*MySQLConnection), context_options, &err) orelse {
if (err != .none) {
globalObject.throw("failed to create TLS context", .{});
} else {
globalObject.throwValue(err.toJS(globalObject));
}
return .zero;
};
if (err != .none) {
tls_config.deinit();
globalObject.throwValue(err.toJS(globalObject));
if (tls_ctx) |ctx| {
ctx.deinit(true);
}
return .zero;
}
uws.NewSocketHandler(true).configure(tls_ctx.?, true, *MySQLConnection, SocketHandler(true));
}
var username: []const u8 = "";
var password: []const u8 = "";
var database: []const u8 = "";
var options: []const u8 = "";
const options_str = arguments[7].toBunString(globalObject);
defer options_str.deref();
const options_buf: []u8 = brk: {
var b = bun.StringBuilder{};
b.cap += username_str.utf8ByteLength() + 1 + password_str.utf8ByteLength() + 1 + database_str.utf8ByteLength() + 1 + options_str.utf8ByteLength() + 1;
b.allocate(bun.default_allocator) catch {};
var u = username_str.toUTF8WithoutRef(bun.default_allocator);
defer u.deinit();
username = b.append(u.slice());
var p = password_str.toUTF8WithoutRef(bun.default_allocator);
defer p.deinit();
password = b.append(p.slice());
var d = database_str.toUTF8WithoutRef(bun.default_allocator);
defer d.deinit();
database = b.append(d.slice());
var o = options_str.toUTF8WithoutRef(bun.default_allocator);
defer o.deinit();
options = b.append(o.slice());
break :brk b.allocatedSlice();
};
const on_connect = arguments[8];
const on_close = arguments[9];
var ptr = try bun.default_allocator.create(MySQLConnection);
ptr.* = MySQLConnection{
.globalObject = globalObject,
.on_connect = JSC.Strong.create(on_connect, globalObject),
.on_close = JSC.Strong.create(on_close, globalObject),
.database = database,
.user = username,
.password = password,
.options = options,
.options_buf = options_buf,
.socket = .{
.SocketTCP = .{ .socket = .{ .detached = {} } },
},
.requests = Queue.init(bun.default_allocator),
.statements = PreparedStatementsMap{},
.tls_config = tls_config,
.tls_ctx = tls_ctx,
.ssl_mode = ssl_mode,
.tls_status = if (ssl_mode != .disable) .pending else .none,
};
ptr.updateHasPendingActivity();
ptr.poll_ref.ref(vm);
const js_value = ptr.toJS(globalObject);
js_value.ensureStillAlive();
ptr.js_value = js_value;
{
const hostname = hostname_str.toUTF8(bun.default_allocator);
defer hostname.deinit();
const ctx = vm.rareData().mysql_context.tcp orelse brk: {
var err: uws.create_bun_socket_error_t = .none;
const ctx_ = uws.us_create_bun_socket_context(0, vm.uwsLoop(), @sizeOf(*MySQLConnection), uws.us_bun_socket_context_options_t{}, &err).?;
uws.NewSocketHandler(false).configure(ctx_, true, *MySQLConnection, SocketHandler(false));
vm.rareData().mysql_context.tcp = ctx_;
break :brk ctx_;
};
ptr.socket = .{
.SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| {
tls_config.deinit();
if (tls_ctx) |tls| {
tls.deinit(true);
}
ptr.deinit();
return globalObject.throwError(err, "failed to connect to mysql");
},
};
}
return js_value;
}
pub fn deinit(this: *MySQLConnection) void {
debug("MySQLConnection deinit", .{});
bun.assert(this.ref_count == 0);
var requests = this.requests;
defer requests.deinit();
this.requests = .{};
// Clear any pending requests first
for (this.requests.readableSlice(0)) |request| {
for (requests.readableSlice(0)) |request| {
request.onError(.{
.error_code = 2013,
.error_message = .{ .temporary = "Connection closed" },
}, this.globalObject);
}
for (this.columns) |*column| {
@constCast(column).deinit();
this.write_buffer.deinit();
this.read_buffer.deinit();
this.statements.deinit(bun.default_allocator);
this.tls_config.deinit();
if (this.tls_ctx) |ctx| {
ctx.deinit(true);
}
bun.default_allocator.free(this.columns);
bun.default_allocator.free(this.params);
this.cached_structure.deinit();
this.error_response.deinit();
this.signature.deinit();
bun.default_allocator.free(this.options_buf);
bun.default_allocator.destroy(this);
}
@@ -503,15 +731,18 @@ pub const MySQLConnection = struct {
const header = protocol.PacketHeader.decode(reader.peek()) orelse break;
try reader.skip(protocol.PACKET_HEADER_SIZE);
// Ensure we have the full packet
reader.ensureCapacity(header.length) catch |err| {
if (err == error.ShortRead) {
try reader.skip(-@as(isize, @intCast(protocol.PACKET_HEADER_SIZE)));
}
return err;
};
// Update sequence id
this.sequence_id = header.sequence_id +% 1;
// Ensure we have the full packet
if (!reader.ensureCapacity(header.length)) {
try reader.skip(-@as(isize, @intCast(protocol.PACKET_HEADER_SIZE)));
return error.ShortRead;
}
// Process packet based on connection state
switch (this.status) {
.handshaking => try this.handleHandshake(Context, reader),
@@ -539,6 +770,11 @@ pub const MySQLConnection = struct {
this.character_set = handshake.character_set;
this.status_flags = handshake.status_flags;
if (this.auth_data.len > 0) {
bun.default_allocator.free(this.auth_data);
this.auth_data = "";
}
// Store auth data
this.auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len);
@memcpy(this.auth_data[0..8], &handshake.auth_plugin_data_part_1);
@@ -813,7 +1049,6 @@ pub const MySQLConnection = struct {
const request = this.requests.peekItem(0);
if (request.statement) |statement| {
statement.statement_id = ok.statement_id;
statement.status = .prepared;
// Read parameter definitions if any
if (ok.num_params > 0) {
@@ -849,15 +1084,12 @@ pub const MySQLConnection = struct {
statement.columns = columns;
}
var execute = protocol.PreparedStatement.Execute{
.statement_id = statement.statement_id,
.param_types = statement.params,
.iteration_count = 1,
};
defer execute.deinit();
try request.bind(&execute, this.globalObject);
try execute.writeInternal(Context, this.writer());
this.flushData();
statement.status = .prepared;
if (request.status == .pending) {
try request.bindAndExecute(this.writer(), statement, this.globalObject);
this.flushData();
}
}
},
@@ -1170,6 +1402,18 @@ pub const MySQLQuery = struct {
});
}
pub fn bindAndExecute(this: *MySQLQuery, writer: anytype, statement: *MySQLStatement, globalObject: *JSC.JSGlobalObject) !void {
var execute = protocol.PreparedStatement.Execute{
.statement_id = statement.statement_id,
.param_types = statement.params,
.iteration_count = 1,
};
defer execute.deinit();
try this.bind(&execute, globalObject);
try execute.write(writer);
this.status = .written;
}
pub fn bind(this: *MySQLQuery, execute: *protocol.PreparedStatement.Execute, globalObject: *JSC.JSGlobalObject) !void {
const binding_value = MySQLQuery.bindingGetCached(this.thisValue) orelse .zero;
const columns_value = MySQLQuery.columnsGetCached(this.thisValue) orelse .zero;
@@ -1191,6 +1435,10 @@ pub const MySQLQuery = struct {
i += 1;
}
if (iter.anyFailed()) {
return error.InvalidQueryBinding;
}
this.status = .binding;
execute.params = params;
}
@@ -1271,7 +1519,7 @@ pub const MySQLQuery = struct {
}
pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue {
const arguments = callframe.arguments(4).slice();
const arguments = callframe.argumentsUndef(4).slice();
const query = arguments[0];
const values = arguments[1];
const columns = arguments[3];
@@ -1316,6 +1564,86 @@ pub const MySQLQuery = struct {
return this_value;
}
pub fn doRun(this: *MySQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue {
var arguments_ = callframe.arguments_old(2);
const arguments = arguments_.slice();
var connection: *MySQLConnection = arguments[0].as(MySQLConnection) orelse {
globalObject.throw("connection must be a PostgresSQLConnection", .{});
return error.JSError;
};
var query = arguments[1];
if (!query.isObject()) {
globalObject.throwInvalidArgumentType("run", "query", "Query");
return error.JSError;
}
this.target.set(globalObject, query);
const binding_value = MySQLQuery.bindingGetCached(callframe.this()) orelse .zero;
var query_str = this.query.toUTF8(bun.default_allocator);
defer query_str.deinit();
const columns_value = MySQLQuery.columnsGetCached(callframe.this()) orelse .undefined;
var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value) catch |err| {
if (!globalObject.hasException())
return globalObject.throwError(err, "failed to generate signature");
return error.JSError;
};
errdefer signature.deinit();
const writer = connection.writer();
const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| {
return globalObject.throwError(err, "failed to allocate statement");
};
const has_params = signature.fields.len > 0;
var did_write = false;
enqueue: {
if (entry.found_existing) {
this.statement = entry.value_ptr.*;
this.statement.?.ref();
signature.deinit();
signature = Signature{};
if (has_params and this.statement.?.status == .parsing) {
// if it has params, we need to wait for PrepareOk to be received before we can write the data
} else {
this.binary = true;
this.bindAndExecute(writer, this.statement.?, globalObject) catch |err| {
if (!globalObject.hasException())
return globalObject.throwError(err, "failed to bind and execute query");
return error.JSError;
};
did_write = true;
}
break :enqueue;
}
const stmt = bun.default_allocator.create(MySQLStatement) catch |err| {
return globalObject.throwError(err, "failed to allocate statement");
};
stmt.* = .{
.signature = signature,
.ref_count = 2,
.status = .parsing,
};
this.statement = stmt;
entry.value_ptr.* = stmt;
}
try connection.requests.writeItem(this);
this.ref();
this.status = if (did_write) .binding else .pending;
if (connection.is_ready_for_query)
connection.flushData();
return .undefined;
}
comptime {
if (!JSC.is_bindgen) {
const jscall = JSC.toJSHostFunction(call);
@@ -1325,9 +1653,9 @@ pub const MySQLQuery = struct {
};
pub const Signature = struct {
fields: []const types.FieldType,
name: []const u8,
query: []const u8,
fields: []const types.FieldType = &.{},
name: []const u8 = "",
query: []const u8 = "",
pub fn deinit(this: *Signature) void {
bun.default_allocator.free(this.fields);
@@ -1387,3 +1715,23 @@ pub const TLSStatus = enum {
ssl_not_available,
ssl_ok,
};
pub fn createBinding(globalObject: *JSC.JSGlobalObject) JSValue {
const binding = JSValue.createEmptyObjectWithNullPrototype(globalObject);
const ZigString = JSC.ZigString;
binding.put(globalObject, ZigString.static("MySQLConnection"), MySQLConnection.getConstructor(globalObject));
binding.put(globalObject, ZigString.static("init"), JSC.JSFunction.create(globalObject, "init", MySQLContext.init, 0, .{}));
binding.put(
globalObject,
ZigString.static("createQuery"),
JSC.JSFunction.create(globalObject, "createQuery", MySQLQuery.call, 2, .{}),
);
binding.put(
globalObject,
ZigString.static("createConnection"),
JSC.JSFunction.create(globalObject, "createConnection", MySQLConnection.call, 10, .{}),
);
return binding;
}

View File

@@ -132,28 +132,6 @@ pub fn NewWriterWrap(
try writeFn(this.wrapped, data);
}
pub const LengthWriter = struct {
index: usize,
context: WrappedWriter,
pub fn write(this: LengthWriter) anyerror!void {
try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index);
}
pub fn writeExcludingSelf(this: LengthWriter) anyerror!void {
try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index);
}
};
pub inline fn length(this: @This()) anyerror!LengthWriter {
const i = this.offset();
try this.int4(0);
return LengthWriter{
.index = i,
.context = this,
};
}
pub inline fn offset(this: @This()) usize {
return offsetFn(this.wrapped);
}
@@ -285,10 +263,6 @@ fn writeWrap(comptime Container: type, comptime writeFn: anytype) type {
};
}
fn Int32(value: anytype) [4]u8 {
return
}
// MySQL packet types
pub const PacketType = enum(u8) {
// Server packets
@@ -766,7 +740,13 @@ pub const StmtExecutePacket = struct {
}
}
pub fn writeInternal(this: *const StmtExecutePacket, comptime Context: type, writer: NewWriter(Context), iter: *sql.QueryBindingIterator, ) !void {
pub fn writeInternal(
this: *const StmtExecutePacket,
comptime Context: type,
writer: NewWriter(Context),
iter: *sql.QueryBindingIterator,
) !void {
_ = iter; // autofix
try writer.int1(@intFromEnum(this.command));
try writer.int4(this.statement_id);
try writer.int1(this.flags);

View File

@@ -1359,7 +1359,11 @@ pub const PostgresSQLConnection = struct {
.password = password,
.options = options,
.options_buf = options_buf,
.socket = undefined,
.socket = .{
.SocketTCP = .{
.socket = .{ .detached = {} },
},
},
.requests = PostgresRequest.Queue.init(bun.default_allocator),
.statements = PreparedStatementsMap{},
.tls_config = tls_config,
@@ -1662,7 +1666,9 @@ pub const PostgresSQLConnection = struct {
switch (this.tag) {
.string => {
this.value.string.deref();
if (this.value.string != null) {
this.value.string.?.deref();
}
},
.json => {
this.value.json.deref();