diff --git a/src/sql/mysql.zig b/src/sql/mysql.zig index 186b182208..5315fa8fe4 100644 --- a/src/sql/mysql.zig +++ b/src/sql/mysql.zig @@ -37,7 +37,7 @@ pub const SSLMode = enum(u8) { verify_ca = 3, verify_full = 4, }; -const Data = sql.Data; +pub const Data = sql.Data; // MySQL capability flags pub const Capabilities = packed struct(u32) { CLIENT_LONG_PASSWORD: bool = false, @@ -212,9 +212,12 @@ pub const MySQLConnection = struct { on_connect: JSC.Strong = .{}, on_close: JSC.Strong = .{}, + auth_data: []const u8 = "", database: []const u8 = "", user: []const u8 = "", password: []const u8 = "", + options: []const u8 = "", + options_buf: []const u8 = "", pub const AuthState = union(enum) { pending: void, @@ -314,8 +317,8 @@ pub const MySQLConnection = struct { this.fail("Connection closed", error.ConnectionClosed); } - fn start(this: *MySQLConnection) void { - this.sendHandshakeResponse(); + fn start(this: *MySQLConnection) !void { + try this.sendHandshakeResponse(); const event_loop = this.globalObject.bunVM().eventLoop(); event_loop.enter(); @@ -359,27 +362,252 @@ pub const MySQLConnection = struct { } } - pub fn deinit(this: *@This()) void { + const Queue = std.fifo.LinearFifo(*MySQLQuery, .Dynamic); + + fn SocketHandler(comptime ssl: bool) type { + return struct { + const SocketType = uws.NewSocketHandler(ssl); + fn _socket(s: SocketType) Socket { + if (comptime ssl) { + return Socket{ .SocketTLS = s }; + } + + return Socket{ .SocketTCP = s }; + } + pub fn onOpen(this: *MySQLConnection, socket: SocketType) void { + this.onOpen(_socket(socket)); + } + + fn onHandshake_(this: *MySQLConnection, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void { + this.onHandshake(success, ssl_error); + } + + pub const onHandshake = if (ssl) onHandshake_ else null; + + pub fn onClose(this: *MySQLConnection, socket: SocketType, _: i32, _: ?*anyopaque) void { + _ = socket; + this.onClose(); + } + + pub fn onEnd(this: *MySQLConnection, socket: SocketType) void { + _ = socket; + this.onClose(); + } + + pub fn onConnectError(this: *MySQLConnection, socket: SocketType, _: i32) void { + _ = socket; + this.onClose(); + } + + pub fn onTimeout(this: *MySQLConnection, socket: SocketType) void { + _ = socket; + this.onTimeout(); + } + + pub fn onData(this: *MySQLConnection, socket: SocketType, data: []const u8) void { + _ = socket; + this.onData(data); + } + + pub fn onWritable(this: *MySQLConnection, socket: SocketType) void { + _ = socket; + this.onDrain(); + } + }; + } + + pub fn onTimeout(this: *MySQLConnection) void { + this.fail("Connection timed out", error.ConnectionTimedOut); + } + + pub fn onDrain(this: *MySQLConnection) void { + const event_loop = this.globalObject.bunVM().eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + this.flushData(); + } + + pub fn call(globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { + var vm = globalObject.bunVM(); + const arguments = callframe.arguments_old(10).slice(); + const hostname_str = arguments[0].toBunString(globalObject); + defer hostname_str.deref(); + const port = arguments[1].coerce(i32, globalObject); + + const username_str = arguments[2].toBunString(globalObject); + defer username_str.deref(); + const password_str = arguments[3].toBunString(globalObject); + defer password_str.deref(); + const database_str = arguments[4].toBunString(globalObject); + defer database_str.deref(); + // TODO: update this to match MySQL. + const ssl_mode: SSLMode = switch (arguments[5].toInt32()) { + 0 => .disable, + 1 => .prefer, + 2 => .require, + 3 => .verify_ca, + 4 => .verify_full, + else => .disable, + }; + + const tls_object = arguments[6]; + + var tls_config: JSC.API.ServerConfig.SSLConfig = .{}; + var tls_ctx: ?*uws.SocketContext = null; + if (ssl_mode != .disable) { + tls_config = if (tls_object.isBoolean() and tls_object.toBoolean()) + .{} + else if (tls_object.isObject()) + (JSC.API.ServerConfig.SSLConfig.fromJS(vm, globalObject, tls_object) catch return .zero) orelse .{} + else { + return globalObject.throwInvalidArguments("tls must be a boolean or an object", .{}); + }; + + if (globalObject.hasException()) { + tls_config.deinit(); + return .zero; + } + + if (tls_config.reject_unauthorized != 0) + tls_config.request_cert = 1; + + // We create it right here so we can throw errors early. + const context_options = tls_config.asUSockets(); + var err: uws.create_bun_socket_error_t = .none; + tls_ctx = uws.us_create_bun_socket_context(1, vm.uwsLoop(), @sizeOf(*MySQLConnection), context_options, &err) orelse { + if (err != .none) { + globalObject.throw("failed to create TLS context", .{}); + } else { + globalObject.throwValue(err.toJS(globalObject)); + } + return .zero; + }; + + if (err != .none) { + tls_config.deinit(); + globalObject.throwValue(err.toJS(globalObject)); + if (tls_ctx) |ctx| { + ctx.deinit(true); + } + return .zero; + } + + uws.NewSocketHandler(true).configure(tls_ctx.?, true, *MySQLConnection, SocketHandler(true)); + } + + var username: []const u8 = ""; + var password: []const u8 = ""; + var database: []const u8 = ""; + var options: []const u8 = ""; + + const options_str = arguments[7].toBunString(globalObject); + defer options_str.deref(); + + const options_buf: []u8 = brk: { + var b = bun.StringBuilder{}; + b.cap += username_str.utf8ByteLength() + 1 + password_str.utf8ByteLength() + 1 + database_str.utf8ByteLength() + 1 + options_str.utf8ByteLength() + 1; + + b.allocate(bun.default_allocator) catch {}; + var u = username_str.toUTF8WithoutRef(bun.default_allocator); + defer u.deinit(); + username = b.append(u.slice()); + + var p = password_str.toUTF8WithoutRef(bun.default_allocator); + defer p.deinit(); + password = b.append(p.slice()); + + var d = database_str.toUTF8WithoutRef(bun.default_allocator); + defer d.deinit(); + database = b.append(d.slice()); + + var o = options_str.toUTF8WithoutRef(bun.default_allocator); + defer o.deinit(); + options = b.append(o.slice()); + + break :brk b.allocatedSlice(); + }; + + const on_connect = arguments[8]; + const on_close = arguments[9]; + + var ptr = try bun.default_allocator.create(MySQLConnection); + + ptr.* = MySQLConnection{ + .globalObject = globalObject, + .on_connect = JSC.Strong.create(on_connect, globalObject), + .on_close = JSC.Strong.create(on_close, globalObject), + .database = database, + .user = username, + .password = password, + .options = options, + .options_buf = options_buf, + .socket = .{ + .SocketTCP = .{ .socket = .{ .detached = {} } }, + }, + .requests = Queue.init(bun.default_allocator), + .statements = PreparedStatementsMap{}, + .tls_config = tls_config, + .tls_ctx = tls_ctx, + .ssl_mode = ssl_mode, + .tls_status = if (ssl_mode != .disable) .pending else .none, + }; + + ptr.updateHasPendingActivity(); + ptr.poll_ref.ref(vm); + const js_value = ptr.toJS(globalObject); + js_value.ensureStillAlive(); + ptr.js_value = js_value; + + { + const hostname = hostname_str.toUTF8(bun.default_allocator); + defer hostname.deinit(); + + const ctx = vm.rareData().mysql_context.tcp orelse brk: { + var err: uws.create_bun_socket_error_t = .none; + const ctx_ = uws.us_create_bun_socket_context(0, vm.uwsLoop(), @sizeOf(*MySQLConnection), uws.us_bun_socket_context_options_t{}, &err).?; + uws.NewSocketHandler(false).configure(ctx_, true, *MySQLConnection, SocketHandler(false)); + vm.rareData().mysql_context.tcp = ctx_; + break :brk ctx_; + }; + ptr.socket = .{ + .SocketTCP = uws.SocketTCP.connectAnon(hostname.slice(), port, ctx, ptr, false) catch |err| { + tls_config.deinit(); + if (tls_ctx) |tls| { + tls.deinit(true); + } + ptr.deinit(); + return globalObject.throwError(err, "failed to connect to mysql"); + }, + }; + } + + return js_value; + } + + pub fn deinit(this: *MySQLConnection) void { debug("MySQLConnection deinit", .{}); bun.assert(this.ref_count == 0); + var requests = this.requests; + defer requests.deinit(); + this.requests = .{}; + // Clear any pending requests first - for (this.requests.readableSlice(0)) |request| { + for (requests.readableSlice(0)) |request| { request.onError(.{ .error_code = 2013, .error_message = .{ .temporary = "Connection closed" }, }, this.globalObject); } - - for (this.columns) |*column| { - @constCast(column).deinit(); + this.write_buffer.deinit(); + this.read_buffer.deinit(); + this.statements.deinit(bun.default_allocator); + this.tls_config.deinit(); + if (this.tls_ctx) |ctx| { + ctx.deinit(true); } - bun.default_allocator.free(this.columns); - bun.default_allocator.free(this.params); - this.cached_structure.deinit(); - this.error_response.deinit(); - this.signature.deinit(); + bun.default_allocator.free(this.options_buf); bun.default_allocator.destroy(this); } @@ -503,15 +731,18 @@ pub const MySQLConnection = struct { const header = protocol.PacketHeader.decode(reader.peek()) orelse break; try reader.skip(protocol.PACKET_HEADER_SIZE); + // Ensure we have the full packet + reader.ensureCapacity(header.length) catch |err| { + if (err == error.ShortRead) { + try reader.skip(-@as(isize, @intCast(protocol.PACKET_HEADER_SIZE))); + } + + return err; + }; + // Update sequence id this.sequence_id = header.sequence_id +% 1; - // Ensure we have the full packet - if (!reader.ensureCapacity(header.length)) { - try reader.skip(-@as(isize, @intCast(protocol.PACKET_HEADER_SIZE))); - return error.ShortRead; - } - // Process packet based on connection state switch (this.status) { .handshaking => try this.handleHandshake(Context, reader), @@ -539,6 +770,11 @@ pub const MySQLConnection = struct { this.character_set = handshake.character_set; this.status_flags = handshake.status_flags; + if (this.auth_data.len > 0) { + bun.default_allocator.free(this.auth_data); + this.auth_data = ""; + } + // Store auth data this.auth_data = try bun.default_allocator.alloc(u8, handshake.auth_plugin_data_part_1.len + handshake.auth_plugin_data_part_2.len); @memcpy(this.auth_data[0..8], &handshake.auth_plugin_data_part_1); @@ -813,7 +1049,6 @@ pub const MySQLConnection = struct { const request = this.requests.peekItem(0); if (request.statement) |statement| { statement.statement_id = ok.statement_id; - statement.status = .prepared; // Read parameter definitions if any if (ok.num_params > 0) { @@ -849,15 +1084,12 @@ pub const MySQLConnection = struct { statement.columns = columns; } - var execute = protocol.PreparedStatement.Execute{ - .statement_id = statement.statement_id, - .param_types = statement.params, - .iteration_count = 1, - }; - defer execute.deinit(); - try request.bind(&execute, this.globalObject); - try execute.writeInternal(Context, this.writer()); - this.flushData(); + statement.status = .prepared; + + if (request.status == .pending) { + try request.bindAndExecute(this.writer(), statement, this.globalObject); + this.flushData(); + } } }, @@ -1170,6 +1402,18 @@ pub const MySQLQuery = struct { }); } + pub fn bindAndExecute(this: *MySQLQuery, writer: anytype, statement: *MySQLStatement, globalObject: *JSC.JSGlobalObject) !void { + var execute = protocol.PreparedStatement.Execute{ + .statement_id = statement.statement_id, + .param_types = statement.params, + .iteration_count = 1, + }; + defer execute.deinit(); + try this.bind(&execute, globalObject); + try execute.write(writer); + this.status = .written; + } + pub fn bind(this: *MySQLQuery, execute: *protocol.PreparedStatement.Execute, globalObject: *JSC.JSGlobalObject) !void { const binding_value = MySQLQuery.bindingGetCached(this.thisValue) orelse .zero; const columns_value = MySQLQuery.columnsGetCached(this.thisValue) orelse .zero; @@ -1191,6 +1435,10 @@ pub const MySQLQuery = struct { i += 1; } + if (iter.anyFailed()) { + return error.InvalidQueryBinding; + } + this.status = .binding; execute.params = params; } @@ -1271,7 +1519,7 @@ pub const MySQLQuery = struct { } pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { - const arguments = callframe.arguments(4).slice(); + const arguments = callframe.argumentsUndef(4).slice(); const query = arguments[0]; const values = arguments[1]; const columns = arguments[3]; @@ -1316,6 +1564,86 @@ pub const MySQLQuery = struct { return this_value; } + pub fn doRun(this: *MySQLQuery, globalObject: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSValue { + var arguments_ = callframe.arguments_old(2); + const arguments = arguments_.slice(); + var connection: *MySQLConnection = arguments[0].as(MySQLConnection) orelse { + globalObject.throw("connection must be a PostgresSQLConnection", .{}); + return error.JSError; + }; + var query = arguments[1]; + + if (!query.isObject()) { + globalObject.throwInvalidArgumentType("run", "query", "Query"); + return error.JSError; + } + + this.target.set(globalObject, query); + const binding_value = MySQLQuery.bindingGetCached(callframe.this()) orelse .zero; + var query_str = this.query.toUTF8(bun.default_allocator); + defer query_str.deinit(); + const columns_value = MySQLQuery.columnsGetCached(callframe.this()) orelse .undefined; + + var signature = Signature.generate(globalObject, query_str.slice(), binding_value, columns_value) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwError(err, "failed to generate signature"); + return error.JSError; + }; + errdefer signature.deinit(); + + const writer = connection.writer(); + + const entry = connection.statements.getOrPut(bun.default_allocator, bun.hash(signature.name)) catch |err| { + return globalObject.throwError(err, "failed to allocate statement"); + }; + + const has_params = signature.fields.len > 0; + var did_write = false; + + enqueue: { + if (entry.found_existing) { + this.statement = entry.value_ptr.*; + this.statement.?.ref(); + signature.deinit(); + signature = Signature{}; + + if (has_params and this.statement.?.status == .parsing) { + // if it has params, we need to wait for PrepareOk to be received before we can write the data + } else { + this.binary = true; + this.bindAndExecute(writer, this.statement.?, globalObject) catch |err| { + if (!globalObject.hasException()) + return globalObject.throwError(err, "failed to bind and execute query"); + return error.JSError; + }; + did_write = true; + } + + break :enqueue; + } + + const stmt = bun.default_allocator.create(MySQLStatement) catch |err| { + return globalObject.throwError(err, "failed to allocate statement"); + }; + stmt.* = .{ + .signature = signature, + .ref_count = 2, + .status = .parsing, + }; + this.statement = stmt; + entry.value_ptr.* = stmt; + } + + try connection.requests.writeItem(this); + this.ref(); + this.status = if (did_write) .binding else .pending; + + if (connection.is_ready_for_query) + connection.flushData(); + + return .undefined; + } + comptime { if (!JSC.is_bindgen) { const jscall = JSC.toJSHostFunction(call); @@ -1325,9 +1653,9 @@ pub const MySQLQuery = struct { }; pub const Signature = struct { - fields: []const types.FieldType, - name: []const u8, - query: []const u8, + fields: []const types.FieldType = &.{}, + name: []const u8 = "", + query: []const u8 = "", pub fn deinit(this: *Signature) void { bun.default_allocator.free(this.fields); @@ -1387,3 +1715,23 @@ pub const TLSStatus = enum { ssl_not_available, ssl_ok, }; + +pub fn createBinding(globalObject: *JSC.JSGlobalObject) JSValue { + const binding = JSValue.createEmptyObjectWithNullPrototype(globalObject); + const ZigString = JSC.ZigString; + binding.put(globalObject, ZigString.static("MySQLConnection"), MySQLConnection.getConstructor(globalObject)); + binding.put(globalObject, ZigString.static("init"), JSC.JSFunction.create(globalObject, "init", MySQLContext.init, 0, .{})); + binding.put( + globalObject, + ZigString.static("createQuery"), + JSC.JSFunction.create(globalObject, "createQuery", MySQLQuery.call, 2, .{}), + ); + + binding.put( + globalObject, + ZigString.static("createConnection"), + JSC.JSFunction.create(globalObject, "createConnection", MySQLConnection.call, 10, .{}), + ); + + return binding; +} diff --git a/src/sql/mysql/mysql_protocol.zig b/src/sql/mysql/mysql_protocol.zig index f5b6e08993..74dbef67c2 100644 --- a/src/sql/mysql/mysql_protocol.zig +++ b/src/sql/mysql/mysql_protocol.zig @@ -132,28 +132,6 @@ pub fn NewWriterWrap( try writeFn(this.wrapped, data); } - pub const LengthWriter = struct { - index: usize, - context: WrappedWriter, - - pub fn write(this: LengthWriter) anyerror!void { - try this.context.pwrite(&Int32(this.context.offset() - this.index), this.index); - } - - pub fn writeExcludingSelf(this: LengthWriter) anyerror!void { - try this.context.pwrite(&Int32(this.context.offset() -| (this.index + 4)), this.index); - } - }; - - pub inline fn length(this: @This()) anyerror!LengthWriter { - const i = this.offset(); - try this.int4(0); - return LengthWriter{ - .index = i, - .context = this, - }; - } - pub inline fn offset(this: @This()) usize { return offsetFn(this.wrapped); } @@ -285,10 +263,6 @@ fn writeWrap(comptime Container: type, comptime writeFn: anytype) type { }; } -fn Int32(value: anytype) [4]u8 { - return -} - // MySQL packet types pub const PacketType = enum(u8) { // Server packets @@ -766,7 +740,13 @@ pub const StmtExecutePacket = struct { } } - pub fn writeInternal(this: *const StmtExecutePacket, comptime Context: type, writer: NewWriter(Context), iter: *sql.QueryBindingIterator, ) !void { + pub fn writeInternal( + this: *const StmtExecutePacket, + comptime Context: type, + writer: NewWriter(Context), + iter: *sql.QueryBindingIterator, + ) !void { + _ = iter; // autofix try writer.int1(@intFromEnum(this.command)); try writer.int4(this.statement_id); try writer.int1(this.flags); diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index c624285b74..1e5593a3a7 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -1359,7 +1359,11 @@ pub const PostgresSQLConnection = struct { .password = password, .options = options, .options_buf = options_buf, - .socket = undefined, + .socket = .{ + .SocketTCP = .{ + .socket = .{ .detached = {} }, + }, + }, .requests = PostgresRequest.Queue.init(bun.default_allocator), .statements = PreparedStatementsMap{}, .tls_config = tls_config, @@ -1662,7 +1666,9 @@ pub const PostgresSQLConnection = struct { switch (this.tag) { .string => { - this.value.string.deref(); + if (this.value.string != null) { + this.value.string.?.deref(); + } }, .json => { this.value.json.deref();