From e3783c244fb939680385e58920052a6d625f94f6 Mon Sep 17 00:00:00 2001
From: Marko Vejnovic
Date: Wed, 24 Sep 2025 20:55:25 -0700
Subject: [PATCH 01/43] chore(libuv): Update to 1.51.0 (#22942)

### What does this PR do?

Uprevs `libuv` to version `1.51.0`.

### How did you verify your code works?

CI passes.

---------

Co-authored-by: Jarred Sumner
---
 cmake/targets/BuildLibuv.cmake         |  3 ++-
 src/bun.js/bindings/libuv/uv/version.h |  8 ++++----
 src/bun.js/webcore/FileReader.zig      |  8 ++++++++
 src/deps/libuv.zig                     | 10 +++++-----
 test/js/bun/spawn/spawn.test.ts        |  2 +-
 5 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/cmake/targets/BuildLibuv.cmake b/cmake/targets/BuildLibuv.cmake
index feba612c44..de95e20955 100644
--- a/cmake/targets/BuildLibuv.cmake
+++ b/cmake/targets/BuildLibuv.cmake
@@ -4,7 +4,8 @@ register_repository(
   REPOSITORY
     libuv/libuv
   COMMIT
-    da527d8d2a908b824def74382761566371439003
+    # Corresponds to v1.51.0
+    5152db2cbfeb5582e9c27c5ea1dba2cd9e10759b
 )
 
 if(WIN32)
diff --git a/src/bun.js/bindings/libuv/uv/version.h b/src/bun.js/bindings/libuv/uv/version.h
index 6356e1ee44..77432f2595 100644
--- a/src/bun.js/bindings/libuv/uv/version.h
+++ b/src/bun.js/bindings/libuv/uv/version.h
@@ -31,10 +31,10 @@
  */
 
 #define UV_VERSION_MAJOR 1
-#define UV_VERSION_MINOR 50
-#define UV_VERSION_PATCH 1
-#define UV_VERSION_IS_RELEASE 0
-#define UV_VERSION_SUFFIX "dev"
+#define UV_VERSION_MINOR 51
+#define UV_VERSION_PATCH 0
+#define UV_VERSION_IS_RELEASE 1
+#define UV_VERSION_SUFFIX ""
 
 #define UV_VERSION_HEX  ((UV_VERSION_MAJOR << 16) | \
                          (UV_VERSION_MINOR <<  8) | \
diff --git a/src/bun.js/webcore/FileReader.zig b/src/bun.js/webcore/FileReader.zig
index 33dcc82036..d91d5ffb3a 100644
--- a/src/bun.js/webcore/FileReader.zig
+++ b/src/bun.js/webcore/FileReader.zig
@@ -351,6 +351,14 @@ pub fn onReadChunk(this: *@This(), init_buf: []const u8, state: bun.io.ReadState
             else => @panic("Invalid state"),
         }
     } else if (this.pending.state == .pending) {
+        // Certain readers (such as pipes) may return 0-byte reads even when
+        // not at EOF. Consequently, we need to check whether the reader is
+        // actually done or not.
+        if (buf.len == 0 and state == .drained) {
+            // If the reader is not done, we still want to keep reading.
+            return true;
+        }
+
        defer {
            this.pending_value.clearWithoutDeallocation();
            this.pending_view = &.{};
diff --git a/src/deps/libuv.zig b/src/deps/libuv.zig
index 220d439973..e3b1837f1b 100644
--- a/src/deps/libuv.zig
+++ b/src/deps/libuv.zig
@@ -143,10 +143,10 @@ pub const UV__ENODATA = -@as(c_int, 4024);
 pub const UV__EUNATCH = -@as(c_int, 4023);
 pub const UV_VERSION_H = "";
 pub const UV_VERSION_MAJOR = @as(c_int, 1);
-pub const UV_VERSION_MINOR = @as(c_int, 46);
-pub const UV_VERSION_PATCH = @as(c_int, 1);
-pub const UV_VERSION_IS_RELEASE = @as(c_int, 0);
-pub const UV_VERSION_SUFFIX = "dev";
+pub const UV_VERSION_MINOR = @as(c_int, 51);
+pub const UV_VERSION_PATCH = @as(c_int, 0);
+pub const UV_VERSION_IS_RELEASE = @as(c_int, 1);
+pub const UV_VERSION_SUFFIX = "";
 pub const UV_VERSION_HEX = ((UV_VERSION_MAJOR << @as(c_int, 16)) |
     (UV_VERSION_MINOR << @as(c_int, 8))) |
     UV_VERSION_PATCH;
 pub const UV_THREADPOOL_H_ = "";
@@ -2981,7 +2981,7 @@ fn StreamMixin(comptime Type: type) type {
                 req.readStop();
                 error_cb(context_data, ReturnCodeI64.init(nreads).errEnum() orelse bun.sys.E.CANCELED);
             } else {
-                read_cb(context_data, buffer.slice());
+                read_cb(context_data, buffer.base[0..@intCast(nreads)]);
             }
         }
     };
diff --git a/test/js/bun/spawn/spawn.test.ts b/test/js/bun/spawn/spawn.test.ts
index 7cf2780fc7..e611f2b924 100644
--- a/test/js/bun/spawn/spawn.test.ts
+++ b/test/js/bun/spawn/spawn.test.ts
@@ -448,7 +448,7 @@ for (let [gcTick, label] of [
 
   for (const [callback, fixture] of fixtures) {
     describe(fixture.slice(0, 12), () => {
-      describe("should should allow reading stdout", () => {
+      describe("should allow reading stdout", () => {
        it("before exit", async () => {
          const process = callback();
          const output = await process.stdout.text();

From 7798e6638bed5f37b801cced0ee4110fe7d5131b Mon Sep 17 00:00:00 2001
From: Ciro Spaciari
Date: Wed, 24 Sep 2025 21:55:57 -0700
Subject: [PATCH 02/43] Implement NODE_USE_SYSTEM_CA with --use-system-ca CLI
 flag (#22441)

### What does this PR do?

Resume work on https://github.com/oven-sh/bun/pull/21898

### How did you verify your code works?

Manually tested on macOS, Windows 11 and Ubuntu 25.04.
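A minimal usage sketch of the two opt-in paths wired up here (`app.ts` is a placeholder script name):

```sh
# Opt in via the new CLI flag
bun --use-system-ca app.ts

# Or via the environment variable; the CLI flag takes precedence
NODE_USE_SYSTEM_CA=1 bun app.ts
```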
CI changes are needed for the tests

---------

Co-authored-by: Claude Bot
Co-authored-by: Claude
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Jarred Sumner
---
 cmake/targets/BuildBun.cmake                  |   2 +
 .../bun-usockets/src/crypto/root_certs.cpp    | 129 +++++-
 .../src/crypto/root_certs_darwin.cpp          | 431 ++++++++++++++++++
 .../src/crypto/root_certs_header.h            |   1 +
 .../src/crypto/root_certs_linux.cpp           | 170 +++++++
 .../src/crypto/root_certs_platform.h          |  18 +
 .../src/crypto/root_certs_windows.cpp         |  53 +++
 src/bun.js/bindings/NodeTLS.cpp               |  63 ++-
 src/bun.js/bindings/NodeTLS.h                 |   1 +
 src/bun.zig                                   |   8 +
 src/cli/Arguments.zig                         |  34 ++
 src/js/node/tls.ts                            |  15 +-
 .../fetch/node-use-system-ca-complete.test.ts | 238 ++++++++++
 test/js/bun/fetch/node-use-system-ca.test.ts  | 255 +++++++++++
 ...-get-ca-certificates-node-use-system-ca.js |  29 ++
 .../test-tls-get-ca-certificates-system.js    |  32 ++
 .../node/tls/test-node-extra-ca-certs.test.ts |  94 ++++
 test/js/node/tls/test-system-ca-https.test.ts | 149 ++++++
 test/js/node/tls/test-use-system-ca.test.ts   |  69 +++
 test/no-validate-exceptions.txt               |   1 +
 20 files changed, 1782 insertions(+), 10 deletions(-)
 create mode 100644 packages/bun-usockets/src/crypto/root_certs_darwin.cpp
 create mode 100644 packages/bun-usockets/src/crypto/root_certs_linux.cpp
 create mode 100644 packages/bun-usockets/src/crypto/root_certs_platform.h
 create mode 100644 packages/bun-usockets/src/crypto/root_certs_windows.cpp
 create mode 100644 test/js/bun/fetch/node-use-system-ca-complete.test.ts
 create mode 100644 test/js/bun/fetch/node-use-system-ca.test.ts
 create mode 100644 test/js/node/test/parallel/test-tls-get-ca-certificates-node-use-system-ca.js
 create mode 100644 test/js/node/test/parallel/test-tls-get-ca-certificates-system.js
 create mode 100644 test/js/node/tls/test-node-extra-ca-certs.test.ts
 create mode 100644 test/js/node/tls/test-system-ca-https.test.ts
 create mode 100644 test/js/node/tls/test-use-system-ca.test.ts

diff --git a/cmake/targets/BuildBun.cmake b/cmake/targets/BuildBun.cmake
index a12cb54d6e..ac6104c398 100644
--- a/cmake/targets/BuildBun.cmake
+++ b/cmake/targets/BuildBun.cmake
@@ -969,6 +969,7 @@ if(WIN32)
       /delayload:WSOCK32.dll
       /delayload:ADVAPI32.dll
       /delayload:IPHLPAPI.dll
+      /delayload:CRYPT32.dll
     )
   endif()
 endif()
@@ -1188,6 +1189,7 @@ if(WIN32)
     ntdll
     userenv
     dbghelp
+    crypt32
     wsock32 # ws2_32 required by TransmitFile aka sendfile on windows
     delayimp.lib
   )
diff --git a/packages/bun-usockets/src/crypto/root_certs.cpp b/packages/bun-usockets/src/crypto/root_certs.cpp
index ba935a5a0c..80f8dd0138 100644
--- a/packages/bun-usockets/src/crypto/root_certs.cpp
+++ b/packages/bun-usockets/src/crypto/root_certs.cpp
@@ -6,10 +6,46 @@
 #include
 #include
 #include "./default_ciphers.h"
+
+// System-specific includes for certificate loading
+#include "./root_certs_platform.h"
+#ifdef _WIN32
+#include
+#include
+#else
+// Linux/Unix includes
+#include
+#include
+#include
+#endif
 
 static const int root_certs_size = sizeof(root_certs) / sizeof(root_certs[0]);
 
 extern "C" void BUN__warn__extra_ca_load_failed(const char* filename, const char* error_msg);
 
+// Forward declarations for platform-specific functions
+// (Actual implementations are in platform-specific files)
+
+// External variable from Zig CLI arguments
+extern "C" bool Bun__Node__UseSystemCA;
+
+// Helper function to check if system CA should be used
+// Checks both CLI flag (--use-system-ca) and environment variable (NODE_USE_SYSTEM_CA=1)
+static bool us_should_use_system_ca() {
+  // Check CLI flag first
+  if (Bun__Node__UseSystemCA) {
+    return true;
+  }
+
+  // Check environment variable
+  const char *use_system_ca = getenv("NODE_USE_SYSTEM_CA");
+  return use_system_ca && strcmp(use_system_ca, "1") == 0;
+}
+
+// Platform-specific system certificate loading implementations are separated:
+// - macOS: root_certs_darwin.cpp (Security framework with dynamic loading)
+// - Windows: root_certs_windows.cpp (Windows CryptoAPI)
+// - Linux/Unix: us_load_system_certificates_linux() below
+
 // This callback is used to avoid the default passphrase callback in OpenSSL
 // which will typically prompt for the passphrase. The prompting is designed
 // for the OpenSSL CLI, but works poorly for this case because it involves
@@ -101,7 +137,8 @@ end:
 
 static void us_internal_init_root_certs(
     X509 *root_cert_instances[root_certs_size],
-    STACK_OF(X509) *&root_extra_cert_instances) {
+    STACK_OF(X509) *&root_extra_cert_instances,
+    STACK_OF(X509) *&root_system_cert_instances) {
 
   static std::atomic_flag root_cert_instances_lock = ATOMIC_FLAG_INIT;
   static std::atomic_bool root_cert_instances_initialized = 0;
@@ -123,6 +160,17 @@
     if (extra_certs && extra_certs[0]) {
       root_extra_cert_instances = us_ssl_ctx_load_all_certs_from_file(extra_certs);
     }
+
+    // load system certificates if NODE_USE_SYSTEM_CA=1
+    if (us_should_use_system_ca()) {
+#ifdef __APPLE__
+      us_load_system_certificates_macos(&root_system_cert_instances);
+#elif defined(_WIN32)
+      us_load_system_certificates_windows(&root_system_cert_instances);
+#else
+      us_load_system_certificates_linux(&root_system_cert_instances);
+#endif
+    }
   }
 
   atomic_flag_clear_explicit(&root_cert_instances_lock,
@@ -137,12 +185,15 @@ extern "C" int us_internal_raw_root_certs(struct us_cert_string_t **out) {
 
 struct us_default_ca_certificates {
   X509 *root_cert_instances[root_certs_size];
   STACK_OF(X509) *root_extra_cert_instances;
+  STACK_OF(X509) *root_system_cert_instances;
 };
 
 us_default_ca_certificates* us_get_default_ca_certificates() {
-  static us_default_ca_certificates default_ca_certificates = {{NULL}, NULL};
+  static us_default_ca_certificates default_ca_certificates = {{NULL}, NULL, NULL};
 
-  us_internal_init_root_certs(default_ca_certificates.root_cert_instances, default_ca_certificates.root_extra_cert_instances);
+  us_internal_init_root_certs(default_ca_certificates.root_cert_instances,
+                              default_ca_certificates.root_extra_cert_instances,
+                              default_ca_certificates.root_system_cert_instances);
 
   return &default_ca_certificates;
 }
@@ -151,20 +202,33 @@ STACK_OF(X509) *us_get_root_extra_cert_instances() {
   return us_get_default_ca_certificates()->root_extra_cert_instances;
 }
 
+STACK_OF(X509) *us_get_root_system_cert_instances() {
+  if (!us_should_use_system_ca())
+    return NULL;
+  // Ensure single-path initialization via us_internal_init_root_certs
+  auto certs = us_get_default_ca_certificates();
+  return certs->root_system_cert_instances;
+}
+
 extern "C" X509_STORE *us_get_default_ca_store() {
   X509_STORE *store = X509_STORE_new();
   if (store == NULL) {
     return NULL;
   }
 
-  if (!X509_STORE_set_default_paths(store)) {
-    X509_STORE_free(store);
-    return NULL;
+  // Only load system default paths when NODE_USE_SYSTEM_CA=1
+  // Otherwise, rely on bundled certificates only (like Node.js behavior)
+  if (us_should_use_system_ca()) {
+    if (!X509_STORE_set_default_paths(store)) {
+      X509_STORE_free(store);
+      return NULL;
+    }
   }
 
   us_default_ca_certificates *default_ca_certificates = us_get_default_ca_certificates();
   X509** root_cert_instances = default_ca_certificates->root_cert_instances;
   STACK_OF(X509) *root_extra_cert_instances = default_ca_certificates->root_extra_cert_instances;
+  STACK_OF(X509) *root_system_cert_instances = default_ca_certificates->root_system_cert_instances;
 
   // load all root_cert_instances on the default ca store
   for (size_t i = 0; i < root_certs_size; i++) {
@@ -183,8 +247,59 @@ extern "C" X509_STORE *us_get_default_ca_store() {
     }
   }
 
+  if (us_should_use_system_ca() && root_system_cert_instances) {
+    for (int i = 0; i < sk_X509_num(root_system_cert_instances); i++) {
+      X509 *cert = sk_X509_value(root_system_cert_instances, i);
+      X509_up_ref(cert);
+      X509_STORE_add_cert(store, cert);
+    }
+  }
+
   return store;
 }
 
 extern "C" const char *us_get_default_ciphers() {
   return DEFAULT_CIPHER_LIST;
-}
\ No newline at end of file
+}
+
+// Platform-specific implementations for loading system certificates
+
+#if defined(_WIN32)
+// Windows implementation is split to avoid header conflicts:
+// - root_certs_windows.cpp loads raw certificate data (uses Windows headers)
+// - This file converts raw data to X509* (uses OpenSSL headers)
+
+#include <vector>
+
+struct RawCertificate {
+  std::vector<unsigned char> data;
+};
+
+// Defined in root_certs_windows.cpp - loads raw certificate data
+extern void us_load_system_certificates_windows_raw(
+    std::vector<RawCertificate>& raw_certs);
+
+// Convert raw Windows certificates to OpenSSL X509 format
+void us_load_system_certificates_windows(STACK_OF(X509) **system_certs) {
+  *system_certs = sk_X509_new_null();
+  if (*system_certs == NULL) {
+    return;
+  }
+
+  // Load raw certificates from Windows stores
+  std::vector<RawCertificate> raw_certs;
+  us_load_system_certificates_windows_raw(raw_certs);
+
+  // Convert each raw certificate to X509
+  for (const auto& raw_cert : raw_certs) {
+    const unsigned char* data = raw_cert.data.data();
+    X509* x509_cert = d2i_X509(NULL, &data, raw_cert.data.size());
+    if (x509_cert != NULL) {
+      sk_X509_push(*system_certs, x509_cert);
+    }
+  }
+}
+
+#else
+// Linux and other Unix-like systems - implementation is in root_certs_linux.cpp
+extern "C" void us_load_system_certificates_linux(STACK_OF(X509) **system_certs);
+#endif
\ No newline at end of file
diff --git a/packages/bun-usockets/src/crypto/root_certs_darwin.cpp b/packages/bun-usockets/src/crypto/root_certs_darwin.cpp
new file mode 100644
index 0000000000..c9256a828c
--- /dev/null
+++ b/packages/bun-usockets/src/crypto/root_certs_darwin.cpp
@@ -0,0 +1,431 @@
+#ifdef __APPLE__
+
+#include <CoreFoundation/CoreFoundation.h>
+#include <dlfcn.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <atomic>
+#include <openssl/x509.h>
+
+// Security framework types and constants - dynamically loaded
+typedef struct OpaqueSecCertificateRef* SecCertificateRef;
+typedef struct OpaqueSecTrustRef* SecTrustRef;
+typedef struct OpaqueSecPolicyRef* SecPolicyRef;
+typedef int32_t OSStatus;
+typedef uint32_t SecTrustSettingsDomain;
+
+// Security framework constants
+enum {
+  errSecSuccess = 0,
+  errSecItemNotFound = -25300,
+};
+
+// Trust settings domains
+enum {
+  kSecTrustSettingsDomainUser = 0,
+  kSecTrustSettingsDomainAdmin = 1,
+  kSecTrustSettingsDomainSystem = 2,
+};
+
+// Trust status enumeration
+enum class TrustStatus {
+  TRUSTED,
+  DISTRUSTED,
+  UNSPECIFIED
+};
+
+// Dynamic Security framework loader
+class SecurityFramework {
+public:
+  void* handle;
+  void* cf_handle;
+
+  // Core Foundation constants
+  CFStringRef kSecClass;
+  CFStringRef kSecClassCertificate;
+  CFStringRef kSecMatchLimit;
+  CFStringRef kSecMatchLimitAll;
+  CFStringRef kSecReturnRef;
+  CFStringRef kSecMatchTrustedOnly;
+  CFBooleanRef kCFBooleanTrue;
+  CFAllocatorRef kCFAllocatorDefault;
+  CFArrayCallBacks* kCFTypeArrayCallBacks;
+  CFDictionaryKeyCallBacks* kCFTypeDictionaryKeyCallBacks;
+  CFDictionaryValueCallBacks* kCFTypeDictionaryValueCallBacks;
+
+  // Core Foundation function pointers
+  CFMutableArrayRef (*CFArrayCreateMutable)(CFAllocatorRef allocator, CFIndex capacity, const CFArrayCallBacks *callBacks);
+  CFArrayRef (*CFArrayCreate)(CFAllocatorRef allocator, const void **values, CFIndex numValues, const CFArrayCallBacks *callBacks);
+  void (*CFArraySetValueAtIndex)(CFMutableArrayRef theArray, CFIndex idx, const void *value);
+  const void* (*CFArrayGetValueAtIndex)(CFArrayRef theArray, CFIndex idx);
+  CFIndex (*CFArrayGetCount)(CFArrayRef theArray);
+  void (*CFRelease)(CFTypeRef cf);
+  CFDictionaryRef (*CFDictionaryCreate)(CFAllocatorRef allocator, const void **keys, const void **values, CFIndex numValues, const CFDictionaryKeyCallBacks *keyCallBacks, const CFDictionaryValueCallBacks *valueCallBacks);
+  const UInt8* (*CFDataGetBytePtr)(CFDataRef theData);
+  CFIndex (*CFDataGetLength)(CFDataRef theData);
+
+  // Security framework function pointers
+  OSStatus (*SecItemCopyMatching)(CFDictionaryRef query, CFTypeRef *result);
+  CFDataRef (*SecCertificateCopyData)(SecCertificateRef certificate);
+  OSStatus (*SecTrustCreateWithCertificates)(CFArrayRef certificates, CFArrayRef policies, SecTrustRef *trust);
+  SecPolicyRef (*SecPolicyCreateSSL)(Boolean server, CFStringRef hostname);
+  Boolean (*SecTrustEvaluateWithError)(SecTrustRef trust, CFErrorRef *error);
+  OSStatus (*SecTrustSettingsCopyTrustSettings)(SecCertificateRef certRef, SecTrustSettingsDomain domain, CFArrayRef *trustSettings);
+
+  SecurityFramework() : handle(nullptr), cf_handle(nullptr),
+    kSecClass(nullptr), kSecClassCertificate(nullptr),
+    kSecMatchLimit(nullptr), kSecMatchLimitAll(nullptr),
+    kSecReturnRef(nullptr), kSecMatchTrustedOnly(nullptr), kCFBooleanTrue(nullptr),
+    kCFAllocatorDefault(nullptr), kCFTypeArrayCallBacks(nullptr),
+    kCFTypeDictionaryKeyCallBacks(nullptr), kCFTypeDictionaryValueCallBacks(nullptr),
+    CFArrayCreateMutable(nullptr), CFArrayCreate(nullptr),
+    CFArraySetValueAtIndex(nullptr), CFArrayGetValueAtIndex(nullptr),
+    CFArrayGetCount(nullptr), CFRelease(nullptr),
+    CFDictionaryCreate(nullptr), CFDataGetBytePtr(nullptr), CFDataGetLength(nullptr),
+    SecItemCopyMatching(nullptr), SecCertificateCopyData(nullptr),
+    SecTrustCreateWithCertificates(nullptr), SecPolicyCreateSSL(nullptr),
+    SecTrustEvaluateWithError(nullptr), SecTrustSettingsCopyTrustSettings(nullptr) {}
+
+  ~SecurityFramework() {
+    if (handle) {
+      dlclose(handle);
+    }
+    if (cf_handle) {
+      dlclose(cf_handle);
+    }
+  }
+
+  bool load() {
+    if (handle && cf_handle) return true; // Already loaded
+
+    // Load CoreFoundation framework
+    cf_handle = dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", RTLD_LAZY | RTLD_LOCAL);
+    if (!cf_handle) {
+      fprintf(stderr, "Failed to load CoreFoundation framework: %s\n", dlerror());
+      return false;
+    }
+
+    // Load Security framework
+    handle = dlopen("/System/Library/Frameworks/Security.framework/Security", RTLD_LAZY | RTLD_LOCAL);
+    if (!handle) {
+      fprintf(stderr, "Failed to load Security framework: %s\n", dlerror());
+      dlclose(cf_handle);
+      cf_handle = nullptr;
+      return false;
+    }
+
+    // Load constants and functions
+    if (!load_constants()) {
+      if (handle) {
+        dlclose(handle);
+        handle = nullptr;
+      }
+      if (cf_handle) {
+        dlclose(cf_handle);
+        cf_handle = nullptr;
+      }
+      return false;
+    }
+
+    if (!load_functions()) {
+      if (handle) {
+        dlclose(handle);
+        handle = nullptr;
+      }
+      if (cf_handle) {
+        dlclose(cf_handle);
+        cf_handle = nullptr;
+      }
+      return false;
+    }
+
+    return true;
+  }
+
+private:
+  bool load_constants() {
+    // Load Security framework constants
+    void* ptr = dlsym(handle, "kSecClass");
+    if (!ptr) { fprintf(stderr, "DEBUG: kSecClass not found\n"); return false; }
+    kSecClass = *(CFStringRef*)ptr;
+
+    ptr = dlsym(handle, "kSecClassCertificate");
+    if (!ptr) { fprintf(stderr, "DEBUG: kSecClassCertificate not found\n"); return false; }
+    kSecClassCertificate = *(CFStringRef*)ptr;
+
+    ptr = dlsym(handle, "kSecMatchLimit");
+    if (!ptr) { fprintf(stderr, "DEBUG: kSecMatchLimit not found\n"); return false; }
+    kSecMatchLimit = *(CFStringRef*)ptr;
+
+    ptr = dlsym(handle, "kSecMatchLimitAll");
+    if (!ptr) { fprintf(stderr, "DEBUG: kSecMatchLimitAll not found\n"); return false; }
+    kSecMatchLimitAll = *(CFStringRef*)ptr;
+
+    ptr = dlsym(handle, "kSecReturnRef");
+    if (!ptr) { fprintf(stderr, "DEBUG: kSecReturnRef not found\n"); return false; }
+    kSecReturnRef = *(CFStringRef*)ptr;
+
+    ptr = dlsym(handle, "kSecMatchTrustedOnly");
+    if (!ptr) { fprintf(stderr, "DEBUG: kSecMatchTrustedOnly not found\n"); return false; }
+    kSecMatchTrustedOnly = *(CFStringRef*)ptr;
+
+    // Load CoreFoundation constants
+    ptr = dlsym(cf_handle, "kCFBooleanTrue");
+    if (!ptr) { fprintf(stderr, "DEBUG: kCFBooleanTrue not found\n"); return false; }
+    kCFBooleanTrue = *(CFBooleanRef*)ptr;
+
+    ptr = dlsym(cf_handle, "kCFAllocatorDefault");
+    if (!ptr) { fprintf(stderr, "DEBUG: kCFAllocatorDefault not found\n"); return false; }
+    kCFAllocatorDefault = *(CFAllocatorRef*)ptr;
+
+    ptr = dlsym(cf_handle, "kCFTypeArrayCallBacks");
+    if (!ptr) { fprintf(stderr, "DEBUG: kCFTypeArrayCallBacks not found\n"); return false; }
+    kCFTypeArrayCallBacks = (CFArrayCallBacks*)ptr;
+
+    ptr = dlsym(cf_handle, "kCFTypeDictionaryKeyCallBacks");
+    if (!ptr) { fprintf(stderr, "DEBUG: kCFTypeDictionaryKeyCallBacks not found\n"); return false; }
+    kCFTypeDictionaryKeyCallBacks = (CFDictionaryKeyCallBacks*)ptr;
+
+    ptr = dlsym(cf_handle, "kCFTypeDictionaryValueCallBacks");
+    if (!ptr) { fprintf(stderr, "DEBUG: kCFTypeDictionaryValueCallBacks not found\n"); return false; }
+    kCFTypeDictionaryValueCallBacks = (CFDictionaryValueCallBacks*)ptr;
+
+    return true;
+  }
+
+  bool load_functions() {
+    // Load CoreFoundation functions
+    CFArrayCreateMutable = (CFMutableArrayRef (*)(CFAllocatorRef, CFIndex, const CFArrayCallBacks*))dlsym(cf_handle, "CFArrayCreateMutable");
+    CFArrayCreate = (CFArrayRef (*)(CFAllocatorRef, const void**, CFIndex, const CFArrayCallBacks*))dlsym(cf_handle, "CFArrayCreate");
+    CFArraySetValueAtIndex = (void (*)(CFMutableArrayRef, CFIndex, const void*))dlsym(cf_handle, "CFArraySetValueAtIndex");
+    CFArrayGetValueAtIndex = (const void* (*)(CFArrayRef, CFIndex))dlsym(cf_handle, "CFArrayGetValueAtIndex");
+    CFArrayGetCount = (CFIndex (*)(CFArrayRef))dlsym(cf_handle, "CFArrayGetCount");
+    CFRelease = (void (*)(CFTypeRef))dlsym(cf_handle, "CFRelease");
+    CFDictionaryCreate = (CFDictionaryRef (*)(CFAllocatorRef, const void**, const void**, CFIndex, const CFDictionaryKeyCallBacks*, const CFDictionaryValueCallBacks*))dlsym(cf_handle, "CFDictionaryCreate");
+    CFDataGetBytePtr = (const UInt8* (*)(CFDataRef))dlsym(cf_handle, "CFDataGetBytePtr");
+    CFDataGetLength = (CFIndex (*)(CFDataRef))dlsym(cf_handle, "CFDataGetLength");
+
+    // Load Security framework functions
+    SecItemCopyMatching = (OSStatus (*)(CFDictionaryRef, CFTypeRef*))dlsym(handle, "SecItemCopyMatching");
+    SecCertificateCopyData = (CFDataRef (*)(SecCertificateRef))dlsym(handle, "SecCertificateCopyData");
+    SecTrustCreateWithCertificates = (OSStatus (*)(CFArrayRef, CFArrayRef, SecTrustRef*))dlsym(handle, "SecTrustCreateWithCertificates");
+    SecPolicyCreateSSL = (SecPolicyRef (*)(Boolean, CFStringRef))dlsym(handle, "SecPolicyCreateSSL");
+    SecTrustEvaluateWithError = (Boolean (*)(SecTrustRef, CFErrorRef*))dlsym(handle, "SecTrustEvaluateWithError");
+    SecTrustSettingsCopyTrustSettings = (OSStatus (*)(SecCertificateRef, SecTrustSettingsDomain, CFArrayRef*))dlsym(handle, "SecTrustSettingsCopyTrustSettings");
+
+    return CFArrayCreateMutable && CFArrayCreate && CFArraySetValueAtIndex &&
+           CFArrayGetValueAtIndex && CFArrayGetCount && CFRelease &&
+           CFDictionaryCreate && CFDataGetBytePtr && CFDataGetLength &&
+           SecItemCopyMatching && SecCertificateCopyData &&
+           SecTrustCreateWithCertificates && SecPolicyCreateSSL &&
+           SecTrustEvaluateWithError && SecTrustSettingsCopyTrustSettings;
+  }
+};
+
+// Global instance for dynamic loading
+static std::atomic<SecurityFramework*> g_security_framework{nullptr};
+
+static SecurityFramework* get_security_framework() {
+  SecurityFramework* framework = g_security_framework.load();
+  if (!framework) {
+    SecurityFramework* new_framework = new SecurityFramework();
+    if (new_framework->load()) {
+      SecurityFramework* expected = nullptr;
+      if (g_security_framework.compare_exchange_strong(expected, new_framework)) {
+        framework = new_framework;
+      } else {
+        delete new_framework;
+        framework = expected;
+      }
+    } else {
+      delete new_framework;
+      framework = nullptr;
+    }
+  }
+  return framework;
+}
+
+// Helper function to determine if a certificate is self-issued
+static bool is_certificate_self_issued(X509* cert) {
+  X509_NAME* subject = X509_get_subject_name(cert);
+  X509_NAME* issuer = X509_get_issuer_name(cert);
+
+  return subject && issuer && X509_NAME_cmp(subject, issuer) == 0;
+}
+
+// Validate certificate trust using Security framework
+static bool is_certificate_trust_valid(SecurityFramework* security, SecCertificateRef cert_ref) {
+  CFMutableArrayRef subj_certs = security->CFArrayCreateMutable(nullptr, 1, security->kCFTypeArrayCallBacks);
+  if (!subj_certs) return false;
+
+  security->CFArraySetValueAtIndex(subj_certs, 0, cert_ref);
+
+  SecPolicyRef policy = security->SecPolicyCreateSSL(true, nullptr);
+  if (!policy) {
+    security->CFRelease(subj_certs);
+    return false;
+  }
+
+  CFArrayRef policies = security->CFArrayCreate(nullptr, (const void**)&policy, 1, security->kCFTypeArrayCallBacks);
+  if (!policies) {
+    security->CFRelease(policy);
+    security->CFRelease(subj_certs);
+    return false;
+  }
+
+  SecTrustRef sec_trust = nullptr;
+  OSStatus ortn = security->SecTrustCreateWithCertificates(subj_certs, policies, &sec_trust);
+
+  bool result = false;
+  if (ortn == errSecSuccess && sec_trust) {
+    result = security->SecTrustEvaluateWithError(sec_trust, nullptr);
+  }
+
+  // Cleanup
+  if (sec_trust) security->CFRelease(sec_trust);
+  security->CFRelease(policies);
+  security->CFRelease(policy);
+  security->CFRelease(subj_certs);
+
+  return result;
+}
+
+// Check trust settings for policy (simplified version)
+static TrustStatus is_trust_settings_trusted_for_policy(SecurityFramework* security, CFArrayRef trust_settings, bool is_self_issued) {
+  if (!trust_settings) {
+    return TrustStatus::UNSPECIFIED;
+  }
+
+  // Empty trust settings array means "always trust this certificate"
+  if (security->CFArrayGetCount(trust_settings) == 0) {
+    return is_self_issued ? TrustStatus::TRUSTED : TrustStatus::UNSPECIFIED;
+  }
+
+  // For simplicity, we'll do basic checking here
+  // A full implementation would parse the trust dictionary entries
+  return TrustStatus::UNSPECIFIED;
+}
+
+// Check if certificate is trusted for server auth policy
+static bool is_certificate_trusted_for_policy(SecurityFramework* security, X509* cert, SecCertificateRef cert_ref) {
+  bool is_self_issued = is_certificate_self_issued(cert);
+  bool trust_evaluated = false;
+
+  // Check the user, admin, and system trust domains, in that order
+  for (const auto& trust_domain : {kSecTrustSettingsDomainUser, kSecTrustSettingsDomainAdmin, kSecTrustSettingsDomainSystem}) {
+    CFArrayRef trust_settings = nullptr;
+    OSStatus err = security->SecTrustSettingsCopyTrustSettings(cert_ref, trust_domain, &trust_settings);
+
+    if (err != errSecSuccess && err != errSecItemNotFound) {
+      continue;
+    }
+
+    if (err == errSecSuccess && trust_settings) {
+      TrustStatus result = is_trust_settings_trusted_for_policy(security, trust_settings, is_self_issued);
+      security->CFRelease(trust_settings);
+
+      if (result == TrustStatus::TRUSTED) {
+        return true;
+      } else if (result == TrustStatus::DISTRUSTED) {
+        return false;
+      }
+    }
+
+    // If no trust settings and we haven't evaluated trust yet, check trust validity
+    if (!trust_settings && !trust_evaluated) {
+      if (is_certificate_trust_valid(security, cert_ref)) {
+        return true;
+      }
+      trust_evaluated = true;
+    }
+  }
+
+  return false;
+}
+
+// Main function to load system certificates on macOS
+extern "C" void us_load_system_certificates_macos(STACK_OF(X509) **system_certs) {
+  *system_certs = sk_X509_new_null();
+  if (!*system_certs) {
+    return;
+  }
+
+  SecurityFramework* security = get_security_framework();
+  if (!security) {
+    return; // Fail silently
+  }
+
+  // Create search dictionary for certificates
+  CFTypeRef search_keys[] = {
+    security->kSecClass,
+    security->kSecMatchLimit,
+    security->kSecReturnRef,
+    security->kSecMatchTrustedOnly,
+  };
+  CFTypeRef search_values[] = {
+    security->kSecClassCertificate,
+    security->kSecMatchLimitAll,
+    security->kCFBooleanTrue,
+    security->kCFBooleanTrue,
+  };
+
+  CFDictionaryRef search = security->CFDictionaryCreate(
+    security->kCFAllocatorDefault,
+    search_keys,
+    search_values,
+    4,
+    security->kCFTypeDictionaryKeyCallBacks,
+    security->kCFTypeDictionaryValueCallBacks
+  );
+
+  if (!search) {
+    return;
+  }
+
+  CFArrayRef certificates = nullptr;
+  OSStatus status = security->SecItemCopyMatching(search, (CFTypeRef*)&certificates);
+  security->CFRelease(search);
+
+  if (status != errSecSuccess || !certificates) {
+    return;
+  }
+
+  CFIndex count = security->CFArrayGetCount(certificates);
+
+  for (CFIndex i = 0; i < count; ++i) {
+    SecCertificateRef cert_ref = (SecCertificateRef)security->CFArrayGetValueAtIndex(certificates, i);
+    if (!cert_ref) continue;
+
+    // Get certificate data
+    CFDataRef cert_data = security->SecCertificateCopyData(cert_ref);
+    if (!cert_data) continue;
+
+    // Convert to X509
+    const unsigned char* data_ptr = security->CFDataGetBytePtr(cert_data);
+    long data_len = security->CFDataGetLength(cert_data);
+    X509* x509_cert = d2i_X509(nullptr, &data_ptr, data_len);
+    security->CFRelease(cert_data);
+
+    if (!x509_cert) continue;
+
+    // Only consider CA certificates
+    if (X509_check_ca(x509_cert) == 1 &&
+        is_certificate_trusted_for_policy(security, x509_cert, cert_ref)) {
+      sk_X509_push(*system_certs, x509_cert);
+    } else {
+      X509_free(x509_cert);
+    }
+  }
+
+  security->CFRelease(certificates);
+}
+
+// Cleanup function for Security framework
+extern "C" void us_cleanup_security_framework() {
+  SecurityFramework* framework = g_security_framework.exchange(nullptr);
+  if (framework) {
+    delete framework;
+  }
+}
+
+#endif // __APPLE__
\ No newline at end of file
diff --git a/packages/bun-usockets/src/crypto/root_certs_header.h b/packages/bun-usockets/src/crypto/root_certs_header.h
index 2a10adf077..0d95a6b584 100644
--- a/packages/bun-usockets/src/crypto/root_certs_header.h
+++ b/packages/bun-usockets/src/crypto/root_certs_header.h
@@ -5,6 +5,7 @@
 #define CPPDECL extern "C"
 
 STACK_OF(X509) *us_get_root_extra_cert_instances();
+STACK_OF(X509) *us_get_root_system_cert_instances();
 #else
 #define CPPDECL extern
 
diff --git a/packages/bun-usockets/src/crypto/root_certs_linux.cpp b/packages/bun-usockets/src/crypto/root_certs_linux.cpp
new file mode 100644
index 0000000000..8a54e98ad9
--- /dev/null
+++ b/packages/bun-usockets/src/crypto/root_certs_linux.cpp
@@ -0,0 +1,170 @@
+#ifndef _WIN32
+#ifndef __APPLE__
+
+#include <dirent.h>
+#include <limits.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <openssl/err.h>
+#include <openssl/pem.h>
+#include <openssl/x509.h>
+
+extern "C" void BUN__warn__extra_ca_load_failed(const char* filename, const char* error_msg);
+
+// Helper function to load certificates from a directory
+static void load_certs_from_directory(const char* dir_path, STACK_OF(X509)* cert_stack) {
+  DIR* dir = opendir(dir_path);
+  if (!dir) {
+    return;
+  }
+
+  struct dirent* entry;
+  while ((entry = readdir(dir)) != NULL) {
+    // Skip . and ..
+    if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) {
+      continue;
+    }
+
+    // Check if file has .crt, .pem, or .cer extension
+    const char* ext = strrchr(entry->d_name, '.');
+    if (!ext || (strcmp(ext, ".crt") != 0 && strcmp(ext, ".pem") != 0 && strcmp(ext, ".cer") != 0)) {
+      continue;
+    }
+
+    // Build full path
+    char filepath[PATH_MAX];
+    snprintf(filepath, sizeof(filepath), "%s/%s", dir_path, entry->d_name);
+
+    // Try to load certificate
+    FILE* file = fopen(filepath, "r");
+    if (file) {
+      X509* cert = PEM_read_X509(file, NULL, NULL, NULL);
+      fclose(file);
+
+      if (cert) {
+        if (!sk_X509_push(cert_stack, cert)) {
+          X509_free(cert);
+        }
+      }
+    }
+  }
+
+  closedir(dir);
+}
+
+// Helper function to load certificates from a bundle file
+static void load_certs_from_bundle(const char* bundle_path, STACK_OF(X509)* cert_stack) {
+  FILE* file = fopen(bundle_path, "r");
+  if (!file) {
+    return;
+  }
+
+  X509* cert;
+  while ((cert = PEM_read_X509(file, NULL, NULL, NULL)) != NULL) {
+    if (!sk_X509_push(cert_stack, cert)) {
+      X509_free(cert);
+      break;
+    }
+  }
+  ERR_clear_error();
+
+  fclose(file);
+}
+
+// Main function to load system certificates on Linux and other Unix-like systems
+extern "C" void us_load_system_certificates_linux(STACK_OF(X509) **system_certs) {
+  *system_certs = sk_X509_new_null();
+  if (*system_certs == NULL) {
+    return;
+  }
+
+  // First check environment variables (same as Node.js and OpenSSL)
+  const char* ssl_cert_file = getenv("SSL_CERT_FILE");
+  const char* ssl_cert_dir = getenv("SSL_CERT_DIR");
+
+  // If SSL_CERT_FILE is set, load from it
+  if (ssl_cert_file && strlen(ssl_cert_file) > 0) {
+    load_certs_from_bundle(ssl_cert_file, *system_certs);
+  }
+
+  // If SSL_CERT_DIR is set, load from each directory (colon-separated)
+  if (ssl_cert_dir && strlen(ssl_cert_dir) > 0) {
+    char* dir_copy = strdup(ssl_cert_dir);
+    if (dir_copy) {
+      char* token = strtok(dir_copy, ":");
+      while (token != NULL) {
+        // Skip empty tokens
+        if (strlen(token) > 0) {
+          load_certs_from_directory(token, *system_certs);
+        }
+        token = strtok(NULL, ":");
+      }
+      free(dir_copy);
+    }
+  }
+
+  // If environment variables were set, use only those (even if they yield zero certs)
+  if (ssl_cert_file || ssl_cert_dir) {
+    return;
+  }
+
+  // Otherwise, load certificates from standard Linux/Unix paths
+  // These are the common locations for system certificates
+
+  // Common certificate bundle locations (single file with multiple certs)
+  // These paths are based on common Linux distributions and OpenSSL defaults
+  static const char* bundle_paths[] = {
+    "/etc/ssl/certs/ca-certificates.crt",                    // Debian/Ubuntu/Gentoo
+    "/etc/pki/tls/certs/ca-bundle.crt",                      // Fedora/RHEL 6
+    "/etc/ssl/ca-bundle.pem",                                // OpenSUSE
+    "/etc/pki/tls/cert.pem",                                 // Fedora/RHEL 7+
+    "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem",     // CentOS/RHEL 7+
+    "/etc/ssl/cert.pem",                                     // Alpine Linux, macOS OpenSSL
+    "/usr/local/etc/openssl/cert.pem",                       // Homebrew OpenSSL on macOS
+    "/usr/local/share/ca-certificates/ca-certificates.crt", // Custom CA installs
+    NULL
+  };
+
+  // Common certificate directory locations (multiple files)
+  // Note: OpenSSL expects hashed symlinks in directories (c_rehash format)
+  static const char* dir_paths[] = {
+    "/etc/ssl/certs",                  // Common location (Debian/Ubuntu with hashed links)
+    "/etc/pki/tls/certs",              // RHEL/Fedora
+    "/usr/share/ca-certificates",      // Debian/Ubuntu (original certs, not hashed)
+    "/usr/local/share/certs",          // FreeBSD
+    "/etc/openssl/certs",              // NetBSD
+    "/var/ssl/certs",                  // AIX
+    "/usr/local/etc/openssl/certs",    // Homebrew OpenSSL on macOS
+    "/System/Library/OpenSSL/certs",   // macOS system OpenSSL (older versions)
+    NULL
+  };
+
+  // Try loading from bundle files first
+  for (const char** path = bundle_paths; *path != NULL; path++) {
+    load_certs_from_bundle(*path, *system_certs);
+  }
+
+  // Then try loading from directories
+  for (const char** path = dir_paths; *path != NULL; path++) {
+    load_certs_from_directory(*path, *system_certs);
+  }
+
+  // Also check NODE_EXTRA_CA_CERTS environment variable
+  const char* extra_ca_certs = getenv("NODE_EXTRA_CA_CERTS");
+  if (extra_ca_certs && strlen(extra_ca_certs) > 0) {
+    FILE* file = fopen(extra_ca_certs, "r");
+    if (file) {
+      X509* cert;
+      while ((cert = PEM_read_X509(file, NULL, NULL, NULL)) != NULL) {
+        sk_X509_push(*system_certs, cert);
+      }
+      fclose(file);
+    } else {
+      BUN__warn__extra_ca_load_failed(extra_ca_certs, "Failed to open file");
+    }
+  }
+}
+
+#endif // !__APPLE__
+#endif // !_WIN32
\ No newline at end of file
diff --git a/packages/bun-usockets/src/crypto/root_certs_platform.h b/packages/bun-usockets/src/crypto/root_certs_platform.h
new file mode 100644
index 0000000000..e357b63ffb
--- /dev/null
+++ b/packages/bun-usockets/src/crypto/root_certs_platform.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <openssl/x509.h>
+
+// Platform-specific certificate loading functions
+extern "C" {
+
+// Load system certificates for the current platform
+void us_load_system_certificates_linux(STACK_OF(X509) **system_certs);
+void us_load_system_certificates_macos(STACK_OF(X509) **system_certs);
+void us_load_system_certificates_windows(STACK_OF(X509) **system_certs);
+
+// Platform-specific cleanup functions
+#ifdef __APPLE__
+void us_cleanup_security_framework();
+#endif
+
+}
\ No newline at end of file
diff --git a/packages/bun-usockets/src/crypto/root_certs_windows.cpp b/packages/bun-usockets/src/crypto/root_certs_windows.cpp
new file mode 100644
index 0000000000..1015a282bf
--- /dev/null
+++ b/packages/bun-usockets/src/crypto/root_certs_windows.cpp
@@ -0,0 +1,53 @@
+#ifdef _WIN32
+
+#include <windows.h>
+#include <wincrypt.h>
+#include <vector>
+#include <utility>
+
+// Forward declaration to avoid including OpenSSL headers here
+// This prevents conflicts with Windows macros like X509_NAME
+// Note: We don't use STACK_OF macro here since we don't have OpenSSL headers
+
+// Structure to hold raw certificate data
+struct RawCertificate {
+  std::vector<unsigned char> data;
+};
+
+// Helper function to load raw certificates from a Windows certificate store
+static void LoadRawCertsFromStore(std::vector<RawCertificate>& raw_certs,
+                                  DWORD store_flags,
+                                  const wchar_t* store_name) {
+  HCERTSTORE cert_store = CertOpenStore(
+    CERT_STORE_PROV_SYSTEM_W,
+    0,
+    0,
+    store_flags | CERT_STORE_READONLY_FLAG,
+    store_name
+  );
+
+  if (cert_store == NULL) {
+    return;
+  }
+
+  PCCERT_CONTEXT cert_context = NULL;
+  while ((cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) != NULL) {
+    RawCertificate raw_cert;
+    raw_cert.data.assign(cert_context->pbCertEncoded,
+                         cert_context->pbCertEncoded + cert_context->cbCertEncoded);
+    raw_certs.push_back(std::move(raw_cert));
+  }
+
+  CertCloseStore(cert_store, 0);
+}
+
+// Main function to load raw system certificates on Windows
+// Returns certificates as raw DER data to avoid OpenSSL header conflicts
+extern void us_load_system_certificates_windows_raw(
+    std::vector<RawCertificate>& raw_certs) {
+  // Load only from ROOT by default
+  LoadRawCertsFromStore(raw_certs, CERT_SYSTEM_STORE_CURRENT_USER, L"ROOT");
+  LoadRawCertsFromStore(raw_certs, CERT_SYSTEM_STORE_LOCAL_MACHINE, L"ROOT");
+}
+
+#endif // _WIN32
\ No newline at end of file
diff --git a/src/bun.js/bindings/NodeTLS.cpp b/src/bun.js/bindings/NodeTLS.cpp
index 0fbce49ec9..218c78cd99 100644
--- a/src/bun.js/bindings/NodeTLS.cpp
+++ b/src/bun.js/bindings/NodeTLS.cpp
@@ -9,6 +9,7 @@
 #include "ErrorCode.h"
 #include "openssl/base.h"
 #include "openssl/bio.h"
+#include "openssl/x509.h"
 #include "../../packages/bun-usockets/src/crypto/root_certs_header.h"
 
 namespace Bun {
@@ -44,7 +45,7 @@ JSC_DEFINE_HOST_FUNCTION(getExtraCACertificates, (JSC::JSGlobalObject * globalOb
     auto size = sk_X509_num(root_extra_cert_instances);
     if (size < 0) size = 0; // root_extra_cert_instances is nullptr
 
-    auto rootCertificates = JSC::JSArray::create(vm, globalObject->arrayStructureForIndexingTypeDuringAllocation(JSC::ArrayWithContiguous), size);
+    JSC::MarkedArgumentBuffer args;
     for (auto i = 0; i < size; i++) {
         BIO* bio = BIO_new(BIO_s_mem());
         if (!bio) {
@@ -65,10 +66,68 @@ JSC_DEFINE_HOST_FUNCTION(getExtraCACertificates, (JSC::JSGlobalObject * globalOb
         }
 
         auto str = WTF::String::fromUTF8(std::span { bioData, static_cast<size_t>(bioLen) });
-        rootCertificates->putDirectIndex(globalObject, i, JSC::jsString(vm, str));
+        args.append(JSC::jsString(vm, str));
         BIO_free(bio);
     }
 
+    if (args.hasOverflowed()) {
+        throwOutOfMemoryError(globalObject, scope);
+        return {};
+    }
+
+    auto rootCertificates = JSC::constructArray(globalObject, static_cast<JSC::ArrayAllocationProfile*>(nullptr), args);
+    RETURN_IF_EXCEPTION(scope, {});
+
+    RELEASE_AND_RETURN(scope, JSValue::encode(JSC::objectConstructorFreeze(globalObject, rootCertificates)));
+}
+
+JSC_DEFINE_HOST_FUNCTION(getSystemCACertificates, (JSC::JSGlobalObject * globalObject, JSC::CallFrame* callFrame))
+{
+    auto scope = DECLARE_THROW_SCOPE(globalObject->vm());
+    VM& vm = globalObject->vm();
+
+    STACK_OF(X509)* root_system_cert_instances = us_get_root_system_cert_instances();
+
+    auto size = sk_X509_num(root_system_cert_instances);
+    if (size < 0) size = 0; // root_system_cert_instances is nullptr
+
+    JSC::MarkedArgumentBuffer args;
+    for (auto i = 0; i < size; i++) {
+        BIO* bio = BIO_new(BIO_s_mem());
+        if (!bio) {
+            throwOutOfMemoryError(globalObject, scope);
+            return {};
+        }
+        X509* cert = sk_X509_value(root_system_cert_instances, i);
+        if (!cert) {
+            BIO_free(bio);
+            continue;
+        }
+        if (!PEM_write_bio_X509(bio, cert)) {
+            BIO_free(bio);
+            continue;
+        }
+
+        char* bioData;
+        long bioLen = BIO_get_mem_data(bio, &bioData);
+        if (bioLen <= 0) {
+            BIO_free(bio);
+            continue;
+        }
+
+        auto str = WTF::String::fromUTF8(std::span { bioData, static_cast<size_t>(bioLen) });
+        args.append(JSC::jsString(vm, str));
+        BIO_free(bio);
+    }
+
+    if (args.hasOverflowed()) {
+        throwOutOfMemoryError(globalObject, scope);
+        return {};
+    }
+
+    auto rootCertificates = JSC::constructArray(globalObject, static_cast<JSC::ArrayAllocationProfile*>(nullptr), args);
+    RETURN_IF_EXCEPTION(scope, {});
+
     RELEASE_AND_RETURN(scope, JSValue::encode(JSC::objectConstructorFreeze(globalObject, rootCertificates)));
 }
 
diff --git a/src/bun.js/bindings/NodeTLS.h b/src/bun.js/bindings/NodeTLS.h
index 9def4bca54..c8948b6bf9 100644
--- a/src/bun.js/bindings/NodeTLS.h
+++ b/src/bun.js/bindings/NodeTLS.h
@@ -6,6 +6,7 @@ namespace Bun {
 BUN_DECLARE_HOST_FUNCTION(Bun__canonicalizeIP);
 JSC_DECLARE_HOST_FUNCTION(getBundledRootCertificates);
 JSC_DECLARE_HOST_FUNCTION(getExtraCACertificates);
+JSC_DECLARE_HOST_FUNCTION(getSystemCACertificates);
 JSC_DECLARE_HOST_FUNCTION(getDefaultCiphers);
 JSC_DECLARE_HOST_FUNCTION(setDefaultCiphers);
 
diff --git a/src/bun.zig b/src/bun.zig
index ee17cd40d7..2c2b890429 100644
--- a/src/bun.zig
+++ b/src/bun.zig
@@ -3782,6 +3782,14 @@ pub fn contains(item: anytype, list: *const std.ArrayListUnmanaged(@TypeOf(item)
 
 pub const safety = @import("./safety.zig");
 
+// Export function to check if --use-system-ca flag is set
+pub fn getUseSystemCA(globalObject: *jsc.JSGlobalObject, callFrame: *jsc.CallFrame) error{ JSError, OutOfMemory }!jsc.JSValue {
+    _ = globalObject;
+    _ = callFrame;
+    const Arguments = @import("./cli/Arguments.zig");
+    return jsc.JSValue.jsBoolean(Arguments.Bun__Node__UseSystemCA);
+}
+
 const CopyFile = @import("./copy_file.zig");
 const builtin = @import("builtin");
 const std = @import("std");
diff --git a/src/cli/Arguments.zig b/src/cli/Arguments.zig
index 9e3bb61b73..f0884f20df 100644
--- a/src/cli/Arguments.zig
+++ b/src/cli/Arguments.zig
@@ -104,6 +104,9 @@ pub const runtime_params_ = [_]ParamType{
     clap.parseParam("--throw-deprecation          Determine whether or not deprecation warnings result in errors.") catch unreachable,
     clap.parseParam("--title <STR>                Set the process title") catch unreachable,
     clap.parseParam("--zero-fill-buffers          Boolean to force Buffer.allocUnsafe(size) to be zero-filled.") catch unreachable,
+    clap.parseParam("--use-system-ca              Use the system's trusted certificate authorities") catch unreachable,
+    clap.parseParam("--use-openssl-ca             Use OpenSSL's default CA store") catch unreachable,
+    clap.parseParam("--use-bundled-ca             Use bundled CA store") catch unreachable,
     clap.parseParam("--redis-preconnect           Preconnect to $REDIS_URL at startup") catch unreachable,
     clap.parseParam("--sql-preconnect             Preconnect to PostgreSQL at startup") catch unreachable,
     clap.parseParam("--no-addons                  Throw an error if process.dlopen is called, and disable export condition \"node-addons\"") catch unreachable,
@@ -750,6 +753,33 @@ pub fn parse(allocator: std.mem.Allocator, ctx: Command.Context, comptime cmd: C
         if (args.flag("--zero-fill-buffers")) {
             Bun__Node__ZeroFillBuffers = true;
         }
+        const use_system_ca = args.flag("--use-system-ca");
+        const use_openssl_ca = args.flag("--use-openssl-ca");
+        const use_bundled_ca = args.flag("--use-bundled-ca");
+
+        // Disallow any combination > 1
+        if (@as(u8, @intFromBool(use_system_ca)) + @as(u8, @intFromBool(use_openssl_ca)) + @as(u8, @intFromBool(use_bundled_ca)) > 1) {
+            Output.prettyErrorln("error: choose exactly one of --use-system-ca, --use-openssl-ca, or --use-bundled-ca", .{});
+            Global.exit(1);
+        }
+
+        // CLI overrides env var (NODE_USE_SYSTEM_CA)
+        if (use_bundled_ca) {
+            Bun__Node__CAStore = .bundled;
+        } else if (use_openssl_ca) {
+            Bun__Node__CAStore = .openssl;
+        } else if (use_system_ca) {
+            Bun__Node__CAStore = .system;
+        } else {
+            if (bun.getenvZ("NODE_USE_SYSTEM_CA")) |val| {
+                if (val.len > 0 and val[0] == '1') {
+                    Bun__Node__CAStore = .system;
+                }
+            }
+        }
+
+        // Back-compat boolean used by native code until fully migrated
+        Bun__Node__UseSystemCA = (Bun__Node__CAStore == .system);
     }
 
     if (opts.port != null and opts.origin == null) {
@@ -1255,6 +1285,10 @@
 export var Bun__Node__ZeroFillBuffers = false;
 export var Bun__Node__ProcessNoDeprecation = false;
 export var Bun__Node__ProcessThrowDeprecation = false;
 
+pub const BunCAStore = enum(u8) { bundled, openssl, system };
+pub export var Bun__Node__CAStore: BunCAStore = .bundled;
+pub export var Bun__Node__UseSystemCA = false;
+
 const string = []const u8;
 
 const builtin = @import("builtin");
diff --git a/src/js/node/tls.ts b/src/js/node/tls.ts
index df0f37fcdc..25bbc96d8e 100644
--- a/src/js/node/tls.ts
+++ b/src/js/node/tls.ts
@@ -11,6 +11,7 @@ const { Server: NetServer, Socket: NetSocket } = net;
 
 const getBundledRootCertificates = $newCppFunction("NodeTLS.cpp", "getBundledRootCertificates", 1);
 const getExtraCACertificates = $newCppFunction("NodeTLS.cpp", "getExtraCACertificates", 1);
+const getSystemCACertificates = $newCppFunction("NodeTLS.cpp", "getSystemCACertificates", 1);
 const canonicalizeIP = $newCppFunction("NodeTLS.cpp", "Bun__canonicalizeIP", 1);
 
 const getTLSDefaultCiphers = $newCppFunction("NodeTLS.cpp", "getDefaultCiphers", 0);
@@ -930,6 +931,8 @@ function cacheBundledRootCertificates(): string[] {
   bundledRootCertificates ||= getBundledRootCertificates() as string[];
   return bundledRootCertificates;
 }
+const getUseSystemCA = $newZigFunction("bun.zig", "getUseSystemCA", 0);
+
 let defaultCACertificates: string[] | undefined;
 function cacheDefaultCACertificates() {
   if (defaultCACertificates) return defaultCACertificates;
@@ -940,6 +943,14 @@ function cacheDefaultCACertificates() {
     ArrayPrototypePush.$call(defaultCACertificates, bundled[i]);
   }
 
+  // Include system certificates when --use-system-ca is set or NODE_USE_SYSTEM_CA=1
+  if (getUseSystemCA() || process.env.NODE_USE_SYSTEM_CA === "1") {
+    const system = cacheSystemCACertificates();
+    for (let i = 0; i < system.length; ++i) {
+      ArrayPrototypePush.$call(defaultCACertificates, system[i]);
+    }
+  }
+
   if (process.env.NODE_EXTRA_CA_CERTS) {
     const extra = cacheExtraCACertificates();
     for (let i = 0; i < extra.length; ++i) {
@@ -951,8 +962,10 @@ function cacheDefaultCACertificates() {
   return defaultCACertificates;
 }
 
+let systemCACertificates: string[] | undefined;
 function cacheSystemCACertificates(): string[] {
-  throw new Error("getCACertificates('system') is not yet implemented in Bun");
+  systemCACertificates ||= getSystemCACertificates() as string[];
+  return systemCACertificates;
 }
 
 let extraCACertificates: string[] | undefined;
diff --git a/test/js/bun/fetch/node-use-system-ca-complete.test.ts b/test/js/bun/fetch/node-use-system-ca-complete.test.ts
new file mode 100644
index 0000000000..be07b50e3c
--- /dev/null
+++ b/test/js/bun/fetch/node-use-system-ca-complete.test.ts
@@ -0,0 +1,238 @@
+import { describe, expect, test } from "bun:test";
+import { promises as fs } from "fs";
+import { bunEnv, bunExe, tempDirWithFiles } from "harness";
+import { platform } from "os";
+import { join } from "path";
+
+describe("NODE_USE_SYSTEM_CA Complete Implementation", () => {
+  test("should work with standard HTTPS sites", async () => {
+    const testDir = tempDirWithFiles("node-use-system-ca-basic", {});
+
+    const testScript = `
+async function testHttpsRequest() {
+  try {
+    const response = await fetch('https://httpbin.org/user-agent');
+    console.log('SUCCESS: HTTPS request completed with status', response.status);
+    process.exit(0);
+  } catch (error) {
+    console.log('ERROR: HTTPS request failed:', error.message);
+    process.exit(1);
+  }
+}
+
+testHttpsRequest();
+`;
+
+    await fs.writeFile(join(testDir, "test.js"), testScript);
+
+    // Test with NODE_USE_SYSTEM_CA=1
+    const proc1 = Bun.spawn({
+      cmd: [bunExe(), "test.js"],
+      env: {
+        ...bunEnv,
+        NODE_USE_SYSTEM_CA: "1",
+      },
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout1, stderr1, exitCode1] = await Promise.all([proc1.stdout.text(), proc1.stderr.text(), proc1.exited]);
+
+    expect(exitCode1).toBe(0);
+    expect(stdout1).toContain("SUCCESS");
+
+    // Test without NODE_USE_SYSTEM_CA
+    const proc2 = Bun.spawn({
+      cmd: [bunExe(), "test.js"],
+      env: bunEnv,
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout2, stderr2, exitCode2] = await Promise.all([proc2.stdout.text(), proc2.stderr.text(), proc2.exited]);
+
+    expect(exitCode2).toBe(0);
+    expect(stdout2).toContain("SUCCESS");
+  });
+
+  test("should properly parse NODE_USE_SYSTEM_CA environment variable", async () => {
+    const testDir = tempDirWithFiles("node-use-system-ca-env-parsing", {});
+
+    const testScript = `
+const testCases = [
+  { env: '1', description: 'string "1"' },
+  { env: 'true', description: 'string "true"' },
+  { env: '0', description: 'string "0"' },
+  { env: 'false', description: 'string "false"' },
+  { env: undefined, description: 'undefined' }
+];
+
+console.log('Testing NODE_USE_SYSTEM_CA environment variable parsing:');
+
+for (const testCase of testCases) {
+  if (testCase.env !== undefined) {
+    process.env.NODE_USE_SYSTEM_CA = testCase.env;
+  } else {
+    delete process.env.NODE_USE_SYSTEM_CA;
+  }
+
+  const actual = process.env.NODE_USE_SYSTEM_CA;
+  console.log(\`  \${testCase.description}: \${actual || 'undefined'}\`);
+}
+
+console.log('Environment variable parsing test completed successfully');
+process.exit(0);
+`;
+
+    await fs.writeFile(join(testDir, "test-env.js"), testScript);
+
+    const proc = Bun.spawn({
+      cmd: [bunExe(), "test-env.js"],
+      env: bunEnv,
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]);
+
+    expect(exitCode).toBe(0);
+    expect(stdout).toContain("Environment variable parsing test completed successfully");
+  });
+
+  test("should handle platform-specific behavior correctly", async () => {
+    const testDir = tempDirWithFiles("node-use-system-ca-platform", {});
+
+    const testScript = `
+const { platform } = require('os');
+
+console.log(\`Platform: \${platform()}\`);
+console.log(\`NODE_USE_SYSTEM_CA: \${process.env.NODE_USE_SYSTEM_CA}\`);
+
+async function testPlatformBehavior() {
+  try {
+    // Test a reliable HTTPS endpoint
+    const response = await fetch('https://httpbin.org/user-agent');
+    const data = await response.json();
+
+    console.log('SUCCESS: Platform-specific certificate loading working');
+    console.log('User-Agent:', data['user-agent']);
+
+    if (platform() === 'darwin' && process.env.NODE_USE_SYSTEM_CA === '1') {
+      console.log('SUCCESS: macOS Security framework integration should be active');
+    } else if (platform() === 'linux' && process.env.NODE_USE_SYSTEM_CA === '1') {
+      console.log('SUCCESS: Linux system certificate loading should be active');
+    } else if (platform() === 'win32' && process.env.NODE_USE_SYSTEM_CA === '1') {
+      console.log('SUCCESS: Windows certificate store integration should be active');
+    } else {
+      console.log('SUCCESS: Using bundled certificates');
+    }
+
+    process.exit(0);
+  } catch (error) {
+    console.error('FAILED: Platform test failed:', error.message);
+    process.exit(1);
+  }
+}
+
+testPlatformBehavior();
+`;
+
+    await fs.writeFile(join(testDir, "test-platform.js"), testScript);
+
+    const proc = Bun.spawn({
+      cmd: [bunExe(), "test-platform.js"],
+      env: {
+        ...bunEnv,
+        NODE_USE_SYSTEM_CA: "1",
+      },
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]);
+
+    console.log("Platform test output:", stdout);
+    console.log("Platform test errors:", stderr);
+
+    expect(exitCode).toBe(0);
+    expect(stdout).toContain("SUCCESS: Platform-specific certificate loading working");
+
+    if (platform() === "darwin") {
+      expect(stdout).toContain("macOS Security framework integration should be active");
+    } else if (platform() === "linux") {
+      expect(stdout).toContain("Linux system certificate loading should be active");
+    }
+  });
+
+  test("should work with TLS connections", async () => {
+    const testDir = tempDirWithFiles("node-use-system-ca-tls", {});
+
+    const testScript = `
+const tls = require('tls');
+
+async function testTLSConnection() {
+  return new Promise((resolve, reject) => {
+    const options = {
+      host: 'www.google.com',
+      port: 443,
+      rejectUnauthorized: true,
+    };
+
+    const socket = tls.connect(options, () => {
+      console.log('SUCCESS: TLS connection established');
+      console.log('Certificate authorized:', socket.authorized);
+
+      socket.destroy();
+      resolve();
+    });
+
+    socket.on('error', (error) => {
+      console.error('FAILED: TLS connection failed:', error.message);
+      reject(error);
+    });
+
+    socket.setTimeout(10000, () => {
+      console.error('FAILED: Connection timeout');
+      socket.destroy();
+      reject(new Error('Timeout'));
+    });
+  });
+}
+
+testTLSConnection()
+  .then(() => {
+    console.log('TLS test completed successfully');
+    process.exit(0);
+  })
+  .catch((error) => {
+    console.error('TLS test failed:', error.message);
+    process.exit(1);
+  });
+`;
+
+    await fs.writeFile(join(testDir, "test-tls.js"), testScript);
+
+    const proc = Bun.spawn({
+      cmd: [bunExe(), "test-tls.js"],
+      env: {
+        ...bunEnv,
+        NODE_USE_SYSTEM_CA: "1",
+      },
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]);
+
+    console.log("TLS test output:", stdout);
+
+    expect(exitCode).toBe(0);
+    expect(stdout).toContain("SUCCESS: TLS connection established");
+    expect(stdout).toContain("TLS test completed successfully");
+  });
+});
diff --git a/test/js/bun/fetch/node-use-system-ca.test.ts b/test/js/bun/fetch/node-use-system-ca.test.ts
new file mode 100644
index 0000000000..b960372a0a
--- /dev/null
+++ b/test/js/bun/fetch/node-use-system-ca.test.ts
@@ -0,0 +1,255 @@
+import { describe, expect, test } from "bun:test";
+import { promises as fs } from "fs";
+import { bunEnv, bunExe, tempDirWithFiles } from "harness";
+import { join } from "path";
+
+// Gate network tests behind environment variable to avoid CI flakes
+// TODO: Replace with hermetic local TLS fixtures in a follow-up
+const networkTest = process.env.BUN_TEST_ALLOW_NET === "1" ? test : test.skip;
+
+describe("NODE_USE_SYSTEM_CA", () => {
+  networkTest("should use system CA when NODE_USE_SYSTEM_CA=1", async () => {
+    const testDir = tempDirWithFiles("node-use-system-ca", {});
+
+    // Create a simple test script that tries to make an HTTPS request
+    const testScript = `
+const https = require('https');
+
+async function testHttpsRequest() {
+  try {
+    const response = await fetch('https://httpbin.org/get');
+    console.log('SUCCESS: HTTPS request completed');
+    process.exit(0);
+  } catch (error) {
+    console.log('ERROR: HTTPS request failed:', error.message);
+    process.exit(1);
+  }
+}
+
+testHttpsRequest();
+`;
+
+    await fs.writeFile(join(testDir, "test-system-ca.js"), testScript);
+
+    // Test with NODE_USE_SYSTEM_CA=1
+    const proc1 = Bun.spawn({
+      cmd: [bunExe(), "test-system-ca.js"],
+      env: {
+        ...bunEnv,
+        NODE_USE_SYSTEM_CA: "1",
+      },
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout1, stderr1, exitCode1] = await Promise.all([proc1.stdout.text(), proc1.stderr.text(), proc1.exited]);
+
+    console.log("With NODE_USE_SYSTEM_CA=1:");
+    console.log("stdout:", stdout1);
+    console.log("stderr:", stderr1);
+    console.log("exitCode:", exitCode1);
+
+    // Test without NODE_USE_SYSTEM_CA (should still work with bundled certs)
+    const proc2 = Bun.spawn({
+      cmd: [bunExe(), "test-system-ca.js"],
+      env: {
+        ...bunEnv,
+        NODE_USE_SYSTEM_CA: undefined,
+      },
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout2, stderr2, exitCode2] = await Promise.all([proc2.stdout.text(), proc2.stderr.text(), proc2.exited]);
+
+    console.log("\nWithout NODE_USE_SYSTEM_CA:");
+    console.log("stdout:", stdout2);
+    console.log("stderr:", stderr2);
+    console.log("exitCode:", exitCode2);
+
+    // Both should succeed (system CA and bundled should work for common sites)
+    expect(exitCode1).toBe(0);
+    expect(exitCode2).toBe(0);
+    expect(stdout1).toContain("SUCCESS");
+    expect(stdout2).toContain("SUCCESS");
+  });
+
+  test("should validate NODE_USE_SYSTEM_CA environment variable parsing", async () => {
+    const testDir = tempDirWithFiles("node-use-system-ca-env", {});
+
+    const testScript = `
+// Test that the environment variable is read correctly
+const testCases = [
+  { env: '1', expected: true },
+  { env: 'true', expected: true },
+  { env: '0', expected: false },
+  { env: 'false', expected: false },
+  { env: undefined, expected: false }
+];
+
+let allPassed = true;
+
+for (const testCase of testCases) {
+  if (testCase.env !== undefined) {
+    process.env.NODE_USE_SYSTEM_CA = testCase.env;
+  } else {
+    delete process.env.NODE_USE_SYSTEM_CA;
+  }
+
+  // Here we would test the internal function if it was exposed
+  // For now, we just test that the environment variable is set correctly
+  const actual = process.env.NODE_USE_SYSTEM_CA;
+  const passes = (testCase.env === undefined && !actual) || (actual === testCase.env);
+
+  console.log(\`Testing NODE_USE_SYSTEM_CA=\${testCase.env}: \${passes ? 'PASS' : 'FAIL'}\`);
+
+  if (!passes) {
+    allPassed = false;
+  }
+}
+
+process.exit(allPassed ? 0 : 1);
+`;
+
+    await fs.writeFile(join(testDir, "test-env-parsing.js"), testScript);
+
+    const proc = Bun.spawn({
+      cmd: [bunExe(), "test-env-parsing.js"],
+      env: bunEnv,
+      cwd: testDir,
+      stdout: "pipe",
+      stderr: "pipe",
+    });
+
+    const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]);
+
+    console.log("Environment variable parsing test:");
+    console.log("stdout:", stdout);
+    console.log("stderr:", stderr);
+
+    expect(exitCode).toBe(0);
+    expect(stdout).toContain("PASS");
+  });
+
+  networkTest(
+    "should work with Bun.serve and fetch using system certificates",
+    async () => {
+      const testDir = tempDirWithFiles("node-use-system-ca-serve", {});
+
+      const serverScript = `
+const server = Bun.serve({
+  port: 0,
+  fetch(req) {
+    return new Response('Hello from test server');
+  },
+});
+
+console.log(\`Server listening on port \${server.port}\`);
+
+// Keep server alive
+await new Promise(() => {}); // Never resolves
+`;
+
+      const clientScript = `
+const port = process.argv[2];
+
+async function testClient() {
+  try {
+    // Test local HTTP first (should work)
+    const response = await fetch(\`http://localhost:\${port}\`);
+    const text = await response.text();
+    console.log('Local HTTP request successful:', text);
+
+    // Test external HTTPS with system CA
+    const httpsResponse = await fetch('https://httpbin.org/get');
+    console.log('External HTTPS request successful');
+
+    process.exit(0);
+  } catch (error) {
+    console.error('Client request failed:', error.message);
+    process.exit(1);
+  }
+}
+
+testClient();
+`;
+
+      await fs.writeFile(join(testDir, "server.js"), serverScript);
+      await fs.writeFile(join(testDir, "client.js"), clientScript);
+
+      // Start server
+      const serverProc = Bun.spawn({
+        cmd: [bunExe(), "server.js"],
+        env: {
+          ...bunEnv,
+          NODE_USE_SYSTEM_CA: "1",
+        },
+        cwd: testDir,
+        stdout: "pipe",
+        stderr: "pipe",
+      });
+
+      // Wait for server to start and get port
+      let serverPort;
+      const serverOutput = [];
+      const reader = serverProc.stdout.getReader();
+
+      const timeout = setTimeout(() => {
+        serverProc.kill();
+      }, 10000);
+
+      try {
+        while (true) {
+          const { done, value } = await reader.read();
+          if (done) break;
+
+          const chunk = new TextDecoder().decode(value);
+          serverOutput.push(chunk);
+
+          const match = chunk.match(/Server listening on port (\d+)/);
+          if (match) {
+            serverPort = match[1];
+            break;
+          }
+        }
+      } finally {
+        reader.releaseLock();
+      }
+
+      expect(serverPort).toBeDefined();
+      console.log("Server started on port:", serverPort);
+
+      // Test client
+      const clientProc = Bun.spawn({
+        cmd: [bunExe(), "client.js", serverPort],
+        env: {
+          ...bunEnv,
+          NODE_USE_SYSTEM_CA: "1",
+        },
+        cwd: testDir,
+        stdout: "pipe",
+        stderr: "pipe",
+      });
+
+      const [clientStdout, clientStderr, clientExitCode] = await Promise.all([
+        clientProc.stdout.text(),
+        clientProc.stderr.text(),
+        clientProc.exited,
+      ]);
+
+      // Clean up server
+      clearTimeout(timeout);
+      serverProc.kill();
+
+      console.log("Client output:", clientStdout);
+      console.log("Client errors:", clientStderr);
+
+      expect(clientExitCode).toBe(0);
+      expect(clientStdout).toContain("Local HTTP request successful");
+      expect(clientStdout).toContain("External HTTPS request successful");
+    },
+    30000,
+  ); // 30 second timeout for this test
+});
diff --git a/test/js/node/test/parallel/test-tls-get-ca-certificates-node-use-system-ca.js b/test/js/node/test/parallel/test-tls-get-ca-certificates-node-use-system-ca.js
new file mode 100644
index 0000000000..a591f2e3ec
--- /dev/null
+++ b/test/js/node/test/parallel/test-tls-get-ca-certificates-node-use-system-ca.js
@@ -0,0 +1,29 @@
+'use strict';
+// This tests that the NODE_USE_SYSTEM_CA environment variable works the same
+// as the --use-system-ca flag by comparing certificate counts.
+
+const common = require('../common');
+if (!common.hasCrypto) common.skip('missing crypto');
+
+const tls = require('tls');
+const { spawnSyncAndExitWithoutError } = require('../common/child_process');
+
+const systemCerts = tls.getCACertificates('system');
+if (systemCerts.length === 0) {
+  common.skip('no system certificates available');
+}
+
+const { child: { stdout: expectedLength } } = spawnSyncAndExitWithoutError(process.execPath, [
+  '--use-system-ca',
+  '-p',
+  `tls.getCACertificates('default').length`,
+], {
+  env: { ...process.env, NODE_USE_SYSTEM_CA: '0' },
+});
+
+spawnSyncAndExitWithoutError(process.execPath, [
+  '-p',
+  `assert.strictEqual(tls.getCACertificates('default').length, ${expectedLength.toString()})`,
+], {
+  env: { ...process.env, NODE_USE_SYSTEM_CA: '1' },
+});
\ No newline at end of file
diff --git a/test/js/node/test/parallel/test-tls-get-ca-certificates-system.js b/test/js/node/test/parallel/test-tls-get-ca-certificates-system.js
new file mode 100644
index 0000000000..ab320183a1
--- /dev/null
+++ b/test/js/node/test/parallel/test-tls-get-ca-certificates-system.js
@@ -0,0 +1,32 @@
+'use strict';
+// Flags: --use-system-ca
+// This tests that tls.getCACertificates() returns the system
+// certificates correctly.
+
+const common = require('../common');
+if (!common.hasCrypto) common.skip('missing crypto');
+
+const assert = require('assert');
+const tls = require('tls');
+const { assertIsCAArray } = require('../common/tls');
+
+const systemCerts = tls.getCACertificates('system');
+// Windows usually comes with some certificates installed by default.
+// This can't be said about other systems; in that case, check that
+// at least systemCerts is an array (which may be empty).
+if (common.isWindows) {
+  assertIsCAArray(systemCerts);
+} else {
+  assert(Array.isArray(systemCerts));
+}
+
+// When --use-system-ca is true, default is a superset of system
+// certificates.
+const defaultCerts = tls.getCACertificates('default');
+assert(defaultCerts.length >= systemCerts.length);
+const defaultSet = new Set(defaultCerts);
+const systemSet = new Set(systemCerts);
+assert.deepStrictEqual(defaultSet.intersection(systemSet), systemSet);
+
+// It's cached on subsequent accesses.
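+// (assert.strictEqual compares identity, so getting back the very same array
+// instance on the second call is what demonstrates the caching.)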
+assert.strictEqual(systemCerts, tls.getCACertificates('system')); \ No newline at end of file diff --git a/test/js/node/tls/test-node-extra-ca-certs.test.ts b/test/js/node/tls/test-node-extra-ca-certs.test.ts new file mode 100644 index 0000000000..2a3ed201b0 --- /dev/null +++ b/test/js/node/tls/test-node-extra-ca-certs.test.ts @@ -0,0 +1,94 @@ +import { spawn } from "bun"; +import { describe, expect, test } from "bun:test"; +import { bunEnv, bunExe, tempDirWithFiles } from "harness"; +import { join } from "path"; + +describe("NODE_EXTRA_CA_CERTS", () => { + test("loads additional certificates from file", async () => { + // Create a test certificate file + const testCert = `-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAKLdQVPy90WjMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTgwNDEwMDgwNzQ4WhcNMjgwNDA3MDgwNzQ4WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAyOB7tY2Uo2lTNjJgGEhJAVZDWnHbLjbmTMP4pSXLlNMr9KdyaKE+J3xn +xAz7TbGPHUBH5dqMzlWqEkZxcY9u9GL19SJPpC7dl8K8V5dKBwvgOubcLp4qLvZU +-----END CERTIFICATE-----`; + + const dir = tempDirWithFiles("test-extra-ca", { + "extra-ca.pem": testCert, + "test.js": `console.log('OK');`, + }); + + const certPath = join(dir, "extra-ca.pem"); + + // Test that NODE_EXTRA_CA_CERTS loads the certificate + await using proc = spawn({ + cmd: [bunExe(), "test.js"], + env: { ...bunEnv, NODE_EXTRA_CA_CERTS: certPath }, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + }); + + test("handles missing certificate file gracefully", async () => { + const dir = tempDirWithFiles("test-missing-ca", { + "test.js": `console.log('OK');`, + }); + + const nonExistentPath = join(dir, "non-existent.pem"); + + // Test that missing file doesn't crash the process + await using proc = spawn({ + cmd: [bunExe(), "test.js"], + env: { ...bunEnv, NODE_EXTRA_CA_CERTS: nonExistentPath }, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + // Process should still run successfully even with missing cert file + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + // Bun may or may not warn about the missing file in stderr + // The important thing is that the process doesn't crash + }); + + test("works with both NODE_EXTRA_CA_CERTS and --use-system-ca", async () => { + const testCert = `-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAKLdQVPy90WjMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTgwNDEwMDgwNzQ4WhcNMjgwNDA3MDgwNzQ4WjBF +-----END CERTIFICATE-----`; + + const dir = tempDirWithFiles("test-extra-and-system", { + "extra-ca.pem": testCert, + "test.js": `console.log('OK');`, + }); + + const certPath = join(dir, "extra-ca.pem"); + + // Test that both work together + await using proc = spawn({ + cmd: [bunExe(), "--use-system-ca", "test.js"], + env: { ...bunEnv, NODE_EXTRA_CA_CERTS: certPath }, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + }); +}); 
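+
+// Manual spot-check (hypothetical paths, not run by this suite): with
+//   NODE_EXTRA_CA_CERTS=./extra-ca.pem bun --use-system-ca client.js
+// TLS verification should consult both the extra file and the OS trust store.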
diff --git a/test/js/node/tls/test-system-ca-https.test.ts b/test/js/node/tls/test-system-ca-https.test.ts new file mode 100644 index 0000000000..b6fb3a54a5 --- /dev/null +++ b/test/js/node/tls/test-system-ca-https.test.ts @@ -0,0 +1,149 @@ +import { spawn } from "bun"; +import { describe, expect, test } from "bun:test"; +import { readFileSync } from "fs"; +import { bunEnv, bunExe, tempDirWithFiles } from "harness"; + +// Gate network tests behind environment variable to avoid CI flakes +// TODO: Replace with hermetic local TLS fixtures in a follow-up +const networkTest = process.env.BUN_TEST_ALLOW_NET === "1" ? test : test.skip; + +describe("system CA with HTTPS", () => { + // Skip test if no system certificates are available + const skipIfNoSystemCerts = () => { + if (process.platform === "linux") { + // Check if common certificate paths exist on Linux + const certPaths = [ + "/etc/ssl/certs/ca-certificates.crt", + "/etc/pki/tls/certs/ca-bundle.crt", + "/etc/ssl/ca-bundle.pem", + "/etc/pki/tls/cacert.pem", + "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", + ]; + const hasSystemCerts = certPaths.some(path => { + try { + readFileSync(path); + return true; + } catch { + return false; + } + }); + if (!hasSystemCerts) { + return "no system certificates available on Linux"; + } + } + return null; + }; + + networkTest("HTTPS request with system CA", async () => { + const skipReason = skipIfNoSystemCerts(); + if (skipReason) { + test.skip(skipReason); + return; + } + + // Test that we can make HTTPS requests to well-known sites with system CA + const testCode = ` + const https = require('https'); + + // Test against a well-known HTTPS endpoint + https.get('https://www.google.com', (res) => { + console.log('STATUS:', res.statusCode); + process.exit(res.statusCode === 200 || res.statusCode === 301 || res.statusCode === 302 ? 
0 : 1); + }).on('error', (err) => { + console.error('ERROR:', err.message); + process.exit(1); + }); + `; + + const dir = tempDirWithFiles("test-system-ca", { + "test.js": testCode, + }); + + // Test with --use-system-ca flag + await using proc1 = spawn({ + cmd: [bunExe(), "--use-system-ca", "test.js"], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout1, stderr1, exitCode1] = await Promise.all([proc1.stdout.text(), proc1.stderr.text(), proc1.exited]); + + expect(exitCode1).toBe(0); + expect(stdout1).toContain("STATUS:"); + + // Test with NODE_USE_SYSTEM_CA=1 + await using proc2 = spawn({ + cmd: [bunExe(), "test.js"], + env: { ...bunEnv, NODE_USE_SYSTEM_CA: "1" }, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout2, stderr2, exitCode2] = await Promise.all([proc2.stdout.text(), proc2.stderr.text(), proc2.exited]); + + expect(exitCode2).toBe(0); + expect(stdout2).toContain("STATUS:"); + }); + + networkTest("HTTPS fails without system CA for custom root cert", async () => { + // This test verifies that without system CA, connections to sites + // with certificates not in the bundled list will fail + const testCode = ` + const https = require('https'); + + // Test against a site that typically uses a custom or less common CA + // Using a government site as they often have their own CAs + https.get('https://www.irs.gov', (res) => { + console.log('SUCCESS'); + process.exit(0); + }).on('error', (err) => { + if (err.code === 'UNABLE_TO_VERIFY_LEAF_SIGNATURE' || + err.code === 'CERT_HAS_EXPIRED' || + err.code === 'SELF_SIGNED_CERT_IN_CHAIN' || + err.message.includes('certificate')) { + console.log('CERT_ERROR'); + process.exit(1); + } + // Other errors (network, DNS, etc) + console.error('OTHER_ERROR:', err.code); + process.exit(2); + }); + `; + + const dir = tempDirWithFiles("test-no-system-ca", { + "test.js": testCode, + }); + + // Test WITHOUT system CA - might fail for some sites + await using proc1 = spawn({ + cmd: [bunExe(), "test.js"], + env: { ...bunEnv, NODE_USE_SYSTEM_CA: "0" }, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout1, stderr1, exitCode1] = await Promise.all([proc1.stdout.text(), proc1.stderr.text(), proc1.exited]); + + // This might succeed or fail depending on whether the site's CA is bundled + // We just verify the test runs without crashing + expect([0, 1, 2]).toContain(exitCode1); + + // Test WITH system CA - should have better success rate + await using proc2 = spawn({ + cmd: [bunExe(), "--use-system-ca", "test.js"], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout2, stderr2, exitCode2] = await Promise.all([proc2.stdout.text(), proc2.stderr.text(), proc2.exited]); + + // With system CA, we expect either success or non-cert errors + expect([0, 2]).toContain(exitCode2); + }); +}); diff --git a/test/js/node/tls/test-use-system-ca.test.ts b/test/js/node/tls/test-use-system-ca.test.ts new file mode 100644 index 0000000000..52fed35e21 --- /dev/null +++ b/test/js/node/tls/test-use-system-ca.test.ts @@ -0,0 +1,69 @@ +import { spawn } from "bun"; +import { describe, expect, test } from "bun:test"; +import { bunEnv, bunExe } from "harness"; + +describe("--use-system-ca", () => { + test("flag loads system certificates", async () => { + // Test that --use-system-ca loads system certificates + await using proc = spawn({ + cmd: [bunExe(), "--use-system-ca", "-e", "console.log('OK')"], + env: bunEnv, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, 
exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + expect(stderr).toBe(""); + }); + + test("NODE_USE_SYSTEM_CA=1 loads system certificates", async () => { + // Test that NODE_USE_SYSTEM_CA environment variable works + await using proc = spawn({ + cmd: [bunExe(), "-e", "console.log('OK')"], + env: { ...bunEnv, NODE_USE_SYSTEM_CA: "1" }, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + expect(stderr).toBe(""); + }); + + test("NODE_USE_SYSTEM_CA=0 doesn't load system certificates", async () => { + // Test that NODE_USE_SYSTEM_CA=0 doesn't load system certificates + await using proc = spawn({ + cmd: [bunExe(), "-e", "console.log('OK')"], + env: { ...bunEnv, NODE_USE_SYSTEM_CA: "0" }, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + expect(stderr).toBe(""); + }); + + test("--use-system-ca overrides NODE_USE_SYSTEM_CA=0", async () => { + // Test that CLI flag takes precedence over environment variable + await using proc = spawn({ + cmd: [bunExe(), "--use-system-ca", "-e", "console.log('OK')"], + env: { ...bunEnv, NODE_USE_SYSTEM_CA: "0" }, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stdout.trim()).toBe("OK"); + expect(stderr).toBe(""); + }); +}); diff --git a/test/no-validate-exceptions.txt b/test/no-validate-exceptions.txt index c6a0f3ede8..dff3f68a5c 100644 --- a/test/no-validate-exceptions.txt +++ b/test/no-validate-exceptions.txt @@ -23,6 +23,7 @@ test/js/node/test/parallel/test-require-dot.js test/js/node/test/parallel/test-util-promisify-custom-names.mjs test/js/node/test/parallel/test-whatwg-readablestream.mjs test/js/node/test/parallel/test-worker.mjs +test/js/node/test/system-ca/test-native-root-certs.test.mjs test/js/node/events/event-emitter.test.ts test/js/node/module/node-module-module.test.js test/js/node/process/call-constructor.test.js From 6c381b0e03fa02a78b921215a6617b03d677d2f2 Mon Sep 17 00:00:00 2001 From: robobun Date: Thu, 25 Sep 2025 00:37:10 -0700 Subject: [PATCH 03/43] Fix double slash in error stack traces when root_path has trailing slash (#22951) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fixes double slashes appearing in error stack traces when `root_path` ends with a trailing slash - Followup to #22469 which added dimmed cwd prefixes to error messages ## Changes - Use `strings.withoutTrailingSlash()` to strip any trailing separator from `root_path` before adding the path separator - This prevents paths like `/workspace//file.js` from appearing in error messages 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude Bot Co-authored-by: Claude --- src/bun.js/bindings/ZigStackFrame.zig | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bun.js/bindings/ZigStackFrame.zig b/src/bun.js/bindings/ZigStackFrame.zig index e9a950033d..3d9bc442ae 100644 --- a/src/bun.js/bindings/ZigStackFrame.zig +++ b/src/bun.js/bindings/ZigStackFrame.zig @@ -69,9 +69,10 @@ pub const ZigStackFrame = 
extern struct { if (this.enable_color) { const not_root = if (comptime bun.Environment.isWindows) this.root_path.len > "C:\\".len else this.root_path.len > "/".len; if (not_root and strings.startsWith(source_slice, this.root_path)) { + const root_path = strings.withoutTrailingSlash(this.root_path); const relative_path = strings.withoutLeadingPathSeparator(source_slice[this.root_path.len..]); try writer.writeAll(comptime Output.prettyFmt("", true)); - try writer.writeAll(this.root_path); + try writer.writeAll(root_path); try writer.writeByte(std.fs.path.sep); try writer.writeAll(comptime Output.prettyFmt("", true)); try writer.writeAll(relative_path); From 0ea4ce1bb430c67624b576f464d00dc15fa2f9be Mon Sep 17 00:00:00 2001 From: pfg Date: Thu, 25 Sep 2025 03:52:18 -0700 Subject: [PATCH 04/43] Synchronous concurrent test fix (#22928) ```ts beforeEach(() => { console.log("beforeEach"); }); afterEach(() => { console.log("afterEach"); }); test.concurrent("test 1", () => { console.log("start test 1"); }); test.concurrent("test 2", async () => { console.log("start test 2"); }); test.concurrent("test 3", () => { console.log("start test 3"); }); ``` ``` $> bun-before test synchronous-concurrent beforeEach beforeEach beforeEach start test 1 start test 2 start test 3 afterEach afterEach afterEach $> bun-after test synchronous-concurrent beforeEach start test 1 afterEach beforeEach start test 2 afterEach beforeEach start test 3 afterEach ``` --------- Co-authored-by: Jarred Sumner --- src/bun.js/event_loop.zig | 8 ++ src/bun.js/test/Collection.zig | 11 +- src/bun.js/test/Execution.zig | 30 +++-- src/bun.js/test/bun_test.zig | 124 ++++++++++++------ .../bun/test/concurrent_immediate.fixture.ts | 15 +++ test/js/bun/test/concurrent_immediate.test.ts | 77 +++++++++++ .../concurrent_immediate_error.fixture.ts | 15 +++ ...current_immediate_error_promise.fixture.ts | 15 +++ .../concurrent_immediate_promise.fixture.ts | 15 +++ test/js/node/fs/fs.test.ts | 14 +- 10 files changed, 261 insertions(+), 63 deletions(-) create mode 100644 test/js/bun/test/concurrent_immediate.fixture.ts create mode 100644 test/js/bun/test/concurrent_immediate.test.ts create mode 100644 test/js/bun/test/concurrent_immediate_error.fixture.ts create mode 100644 test/js/bun/test/concurrent_immediate_error_promise.fixture.ts create mode 100644 test/js/bun/test/concurrent_immediate_promise.fixture.ts diff --git a/src/bun.js/event_loop.zig b/src/bun.js/event_loop.zig index 58b149db96..dd4451e81f 100644 --- a/src/bun.js/event_loop.zig +++ b/src/bun.js/event_loop.zig @@ -182,6 +182,14 @@ comptime { @export(&externRunCallback3, .{ .name = "Bun__EventLoop__runCallback3" }); } +/// Prefer `runCallbackWithResult` unless you really need to make sure that microtasks are drained. 
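+/// (Draining means any microtasks queued by `callback`, including
+/// synchronously-settled promise reactions, have run before this returns.)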
+pub fn runCallbackWithResultAndForcefullyDrainMicrotasks(this: *EventLoop, callback: jsc.JSValue, globalObject: *jsc.JSGlobalObject, thisValue: jsc.JSValue, arguments: []const jsc.JSValue) !jsc.JSValue { + const result = try callback.call(globalObject, thisValue, arguments); + result.ensureStillAlive(); + try this.drainMicrotasksWithGlobal(globalObject, globalObject.bunVM().jsc_vm); + return result; +} + pub fn runCallbackWithResult(this: *EventLoop, callback: jsc.JSValue, globalObject: *jsc.JSGlobalObject, thisValue: jsc.JSValue, arguments: []const jsc.JSValue) jsc.JSValue { this.enter(); defer this.exit(); diff --git a/src/bun.js/test/Collection.zig b/src/bun.js/test/Collection.zig index a59f6b87e4..b4e2405436 100644 --- a/src/bun.js/test/Collection.zig +++ b/src/bun.js/test/Collection.zig @@ -137,11 +137,12 @@ pub fn step(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject this.active_scope = new_scope; group.log("collection:runOne set scope to {s}", .{this.active_scope.base.name orelse "undefined"}); - BunTest.runTestCallback(buntest_strong, globalThis, callback.get(), false, .{ - .collection = .{ - .active_scope = previous_scope, - }, - }, .epoch); + if (BunTest.runTestCallback(buntest_strong, globalThis, callback.get(), false, .{ + .collection = .{ .active_scope = previous_scope }, + }, &.epoch)) |cfg_data| { + // the result is available immediately; queue + buntest.addResult(cfg_data); + } return .{ .waiting = .{} }; } diff --git a/src/bun.js/test/Execution.zig b/src/bun.js/test/Execution.zig index 3819fad8dc..c57ae6414c 100644 --- a/src/bun.js/test/Execution.zig +++ b/src/bun.js/test/Execution.zig @@ -222,10 +222,11 @@ pub fn step(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject defer groupLog.end(); const buntest = buntest_strong.get(); const this = &buntest.execution; + var now = bun.timespec.now(); switch (data) { .start => { - return try stepGroup(buntest_strong, globalThis, bun.timespec.now()); + return try stepGroup(buntest_strong, globalThis, &now); }, else => { // determine the active sequence,group @@ -242,21 +243,20 @@ pub fn step(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject bun.assert(sequence.active_index < sequence.entries(this).len); this.advanceSequence(sequence, group); - const now = bun.timespec.now(); - const sequence_result = try stepSequence(buntest_strong, globalThis, sequence, group, sequence_index, now); + const sequence_result = try stepSequence(buntest_strong, globalThis, sequence, group, sequence_index, &now); switch (sequence_result) { .done => {}, .execute => |exec| return .{ .waiting = .{ .timeout = exec.timeout } }, } if (group.remaining_incomplete_entries == 0) { - return try stepGroup(buntest_strong, globalThis, now); + return try stepGroup(buntest_strong, globalThis, &now); } return .{ .waiting = .{} }; }, } } -pub fn stepGroup(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, now: bun.timespec) bun.JSError!bun_test.StepResult { +pub fn stepGroup(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, now: *bun.timespec) bun.JSError!bun_test.StepResult { groupLog.begin(@src()); defer groupLog.end(); const buntest = buntest_strong.get(); @@ -295,7 +295,7 @@ pub fn stepGroup(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalO } } const AdvanceStatus = union(enum) { done, execute: struct { timeout: bun.timespec = .epoch } }; -fn stepGroupOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, group: *ConcurrentGroup, now: 
bun.timespec) !AdvanceStatus { +fn stepGroupOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, group: *ConcurrentGroup, now: *bun.timespec) !AdvanceStatus { const buntest = buntest_strong.get(); const this = &buntest.execution; var final_status: AdvanceStatus = .done; @@ -320,13 +320,13 @@ const AdvanceSequenceStatus = union(enum) { timeout: bun.timespec = .epoch, }, }; -fn stepSequence(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, sequence: *ExecutionSequence, group: *ConcurrentGroup, sequence_index: usize, now: bun.timespec) !AdvanceSequenceStatus { +fn stepSequence(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, sequence: *ExecutionSequence, group: *ConcurrentGroup, sequence_index: usize, now: *bun.timespec) !AdvanceSequenceStatus { while (true) { return try stepSequenceOne(buntest_strong, globalThis, sequence, group, sequence_index, now) orelse continue; } } /// returns null if the while loop should continue -fn stepSequenceOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, sequence: *ExecutionSequence, group: *ConcurrentGroup, sequence_index: usize, now: bun.timespec) !?AdvanceSequenceStatus { +fn stepSequenceOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, sequence: *ExecutionSequence, group: *ConcurrentGroup, sequence_index: usize, now: *bun.timespec) !?AdvanceSequenceStatus { groupLog.begin(@src()); defer groupLog.end(); const buntest = buntest_strong.get(); @@ -337,10 +337,7 @@ fn stepSequenceOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGloba bun.debugAssert(false); // sequence is executing with no active entry return .{ .execute = .{} }; }; - if (!active_entry.timespec.eql(&.epoch) and active_entry.timespec.order(&now) == .lt) { - // timed out - sequence.result = if (active_entry == sequence.test_entry) if (active_entry.has_done_parameter) .fail_because_timeout_with_done_callback else .fail_because_timeout else if (active_entry.has_done_parameter) .fail_because_hook_timeout_with_done_callback else .fail_because_hook_timeout; - sequence.maybe_skip = true; + if (active_entry.evaluateTimeout(sequence, now)) { this.advanceSequence(sequence, group); return null; // run again } @@ -374,7 +371,14 @@ fn stepSequenceOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGloba }; groupLog.log("runSequence queued callback: {}", .{callback_data}); - BunTest.runTestCallback(buntest_strong, globalThis, cb.get(), next_item.has_done_parameter, callback_data, next_item.timespec); + if (BunTest.runTestCallback(buntest_strong, globalThis, cb.get(), next_item.has_done_parameter, callback_data, &next_item.timespec) != null) { + now.* = bun.timespec.now(); + _ = next_item.evaluateTimeout(sequence, now); + + // the result is available immediately; advance the sequence and run again. 
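+            // (evaluateTimeout above has already recorded a timeout failure on the
+            // sequence if the synchronous callback overran its deadline.)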
+ this.advanceSequence(sequence, group); + return null; // run again + } return .{ .execute = .{ .timeout = next_item.timespec } }; } else { switch (next_item.base.mode) { diff --git a/src/bun.js/test/bun_test.zig b/src/bun.js/test/bun_test.zig index bcde3fe7d4..2fa96ba79f 100644 --- a/src/bun.js/test/bun_test.zig +++ b/src/bun.js/test/bun_test.zig @@ -331,6 +331,7 @@ pub const BunTest = struct { errdefer group.log("ended in error", .{}); const result, const this_ptr = callframe.argumentsAsArray(2); + if (this_ptr.isEmptyOrUndefinedOrNull()) return; const refdata: *RefData = this_ptr.asPromisePtr(RefData); defer refdata.deref(); @@ -472,21 +473,21 @@ pub const BunTest = struct { } } - this.updateMinTimeout(globalThis, min_timeout); + this.updateMinTimeout(globalThis, &min_timeout); } - fn updateMinTimeout(this: *BunTest, globalThis: *jsc.JSGlobalObject, min_timeout: bun.timespec) void { + fn updateMinTimeout(this: *BunTest, globalThis: *jsc.JSGlobalObject, min_timeout: *const bun.timespec) void { group.begin(@src()); defer group.end(); // only set the timer if the new timeout is sooner than the current timeout. this unfortunately means that we can't unset an unnecessary timer. - group.log("-> timeout: {} {}, {s}", .{ min_timeout, this.timer.next, @tagName(min_timeout.orderIgnoreEpoch(this.timer.next)) }); + group.log("-> timeout: {} {}, {s}", .{ min_timeout.*, this.timer.next, @tagName(min_timeout.orderIgnoreEpoch(this.timer.next)) }); if (min_timeout.orderIgnoreEpoch(this.timer.next) == .lt) { - group.log("-> setting timer to {}", .{min_timeout}); + group.log("-> setting timer to {}", .{min_timeout.*}); if (!this.timer.next.eql(&.epoch)) { group.log("-> removing existing timer", .{}); globalThis.bunVM().timer.remove(&this.timer); } - this.timer.next = min_timeout; + this.timer.next = min_timeout.*; if (!this.timer.next.eql(&.epoch)) { group.log("-> inserting timer", .{}); globalThis.bunVM().timer.insert(&this.timer); @@ -534,48 +535,55 @@ pub const BunTest = struct { } } - fn drain(globalThis: *jsc.JSGlobalObject) void { - const bun_vm = globalThis.bunVM(); - bun_vm.drainMicrotasks(); - var count = bun_vm.unhandled_error_counter; - bun_vm.global.handleRejectedPromises(); - while (bun_vm.unhandled_error_counter > count) { - count = bun_vm.unhandled_error_counter; - bun_vm.drainMicrotasks(); - bun_vm.global.handleRejectedPromises(); - } - } - - /// if sync, the result is queued and appended later - pub fn runTestCallback(this_strong: BunTestPtr, globalThis: *jsc.JSGlobalObject, cfg_callback: jsc.JSValue, cfg_done_parameter: bool, cfg_data: BunTest.RefDataValue, timeout: bun.timespec) void { + /// if sync, the result is returned. if async, null is returned. + pub fn runTestCallback(this_strong: BunTestPtr, globalThis: *jsc.JSGlobalObject, cfg_callback: jsc.JSValue, cfg_done_parameter: bool, cfg_data: BunTest.RefDataValue, timeout: *const bun.timespec) ?RefDataValue { group.begin(@src()); defer group.end(); const this = this_strong.get(); + const vm = globalThis.bunVM(); - var done_arg: ?jsc.JSValue = null; + // Don't use ?jsc.JSValue to make it harder for the conservative stack + // scanner to miss it. 
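+        // (.zero doubles as the "no value" sentinel; see the `!= .zero` checks below.)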
+ var done_arg: jsc.JSValue = .zero; + var done_callback: jsc.JSValue = .zero; - var done_callback: ?jsc.JSValue = null; if (cfg_done_parameter) { group.log("callTestCallback -> appending done callback param: data {}", .{cfg_data}); done_callback = DoneCallback.createUnbound(globalThis); - done_arg = DoneCallback.bind(done_callback.?, globalThis) catch |e| blk: { + done_arg = DoneCallback.bind(done_callback, globalThis) catch |e| blk: { this.onUncaughtException(globalThis, globalThis.takeException(e), false, cfg_data); - break :blk jsc.JSValue.js_undefined; // failed to bind done callback + break :blk .zero; // failed to bind done callback }; } this.updateMinTimeout(globalThis, timeout); - const result: ?jsc.JSValue = cfg_callback.call(globalThis, .js_undefined, if (done_arg) |done| &.{done} else &.{}) catch blk: { + const result: jsc.JSValue = vm.eventLoop().runCallbackWithResultAndForcefullyDrainMicrotasks(cfg_callback, globalThis, .js_undefined, if (done_arg != .zero) &.{done_arg} else &.{}) catch blk: { globalThis.clearTerminationException(); this.onUncaughtException(globalThis, globalThis.tryTakeException(), false, cfg_data); group.log("callTestCallback -> error", .{}); - break :blk null; + break :blk .zero; }; + done_callback.ensureStillAlive(); + + // Drain unhandled promise rejections. + while (true) { + // Prevent the user's Promise rejection from going into the uncaught promise rejection queue. + if (result != .zero) + if (result.asPromise()) |promise| + if (promise.status(globalThis.vm()) == .rejected) + promise.setHandled(globalThis.vm()); + + const prev_unhandled_count = vm.unhandled_error_counter; + globalThis.handleRejectedPromises(); + if (vm.unhandled_error_counter == prev_unhandled_count) + break; + } + var dcb_ref: ?*RefData = null; - if (done_callback) |dcb| { - if (DoneCallback.fromJS(dcb)) |dcb_data| { - if (dcb_data.called or result == null) { + if (done_callback != .zero and result != .zero) { + if (DoneCallback.fromJS(done_callback)) |dcb_data| { + if (dcb_data.called) { // done callback already called or the callback errored; add result immediately } else { dcb_ref = ref(this_strong, cfg_data); @@ -584,25 +592,43 @@ pub const BunTest = struct { } else bun.debugAssert(false); // this should be unreachable, we create DoneCallback above } - if (result != null and result.?.asPromise() != null) { - group.log("callTestCallback -> promise: data {}", .{cfg_data}); - const this_ref: *RefData = if (dcb_ref) |dcb_ref_value| dcb_ref_value.dupe() else ref(this_strong, cfg_data); - result.?.then(globalThis, this_ref, bunTestThen, bunTestCatch); - drain(globalThis); - return; + if (result != .zero) { + if (result.asPromise()) |promise| { + defer result.ensureStillAlive(); // because sometimes we use promise without result + + group.log("callTestCallback -> promise: data {}", .{cfg_data}); + + switch (promise.status(globalThis.vm())) { + .pending => { + // not immediately resolved; register 'then' to handle the result when it becomes available + const this_ref: *RefData = if (dcb_ref) |dcb_ref_value| dcb_ref_value.dupe() else ref(this_strong, cfg_data); + result.then(globalThis, this_ref, bunTestThen, bunTestCatch); + return null; + }, + .fulfilled => { + // Do not register a then callback when it's already fulfilled. + return cfg_data; + }, + .rejected => { + const value = promise.result(globalThis.vm()); + this.onUncaughtException(globalThis, value, true, cfg_data); + + // We previously marked it as handled above. 
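+                        // (That setHandled call is what keeps this rejection out of
+                        // the unhandled-rejection queue, so it is reported only once.)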
+ + return cfg_data; + }, + } + } } if (dcb_ref) |_| { // completed asynchronously group.log("callTestCallback -> wait for done callback", .{}); - drain(globalThis); - return; + return null; } group.log("callTestCallback -> sync", .{}); - drain(globalThis); - this.addResult(cfg_data); - return; + return cfg_data; } /// called from the uncaught exception handler, or if a test callback rejects or throws an error @@ -843,6 +869,26 @@ pub const ExecutionEntry = struct { } return entry; } + + pub fn evaluateTimeout(this: *ExecutionEntry, sequence: *Execution.ExecutionSequence, now: *const bun.timespec) bool { + if (!this.timespec.eql(&.epoch) and this.timespec.order(now) == .lt) { + // timed out + sequence.result = if (this == sequence.test_entry) + if (this.has_done_parameter) + .fail_because_timeout_with_done_callback + else + .fail_because_timeout + else if (this.has_done_parameter) + .fail_because_hook_timeout_with_done_callback + else + .fail_because_hook_timeout; + sequence.maybe_skip = true; + return true; + } + + return false; + } + pub fn destroy(this: *ExecutionEntry, gpa: std.mem.Allocator) void { if (this.callback) |*c| c.deinit(); this.base.deinit(gpa); diff --git a/test/js/bun/test/concurrent_immediate.fixture.ts b/test/js/bun/test/concurrent_immediate.fixture.ts new file mode 100644 index 0000000000..0c9dc496a6 --- /dev/null +++ b/test/js/bun/test/concurrent_immediate.fixture.ts @@ -0,0 +1,15 @@ +beforeEach(() => { + console.log("beforeEach"); +}); +afterEach(() => { + console.log("afterEach"); +}); +test.concurrent("test 1", () => { + console.log("start test 1"); +}); +test.concurrent("test 2", () => { + console.log("start test 2"); +}); +test.concurrent("test 3", () => { + console.log("start test 3"); +}); diff --git a/test/js/bun/test/concurrent_immediate.test.ts b/test/js/bun/test/concurrent_immediate.test.ts new file mode 100644 index 0000000000..40256fb48f --- /dev/null +++ b/test/js/bun/test/concurrent_immediate.test.ts @@ -0,0 +1,77 @@ +import { expect, test } from "bun:test"; +import { bunEnv, bunExe, normalizeBunSnapshot } from "harness"; + +test("concurrent immediate", async () => { + const result = await Bun.spawn({ + cmd: [bunExe(), "test", import.meta.dir + "/concurrent_immediate.fixture.ts"], + stdout: "pipe", + stderr: "pipe", + env: bunEnv, + }); + const exitCode = await result.exited; + const stdout = await result.stdout.text(); + const stderr = await result.stderr.text(); + expect(exitCode).toBe(0); + expect(normalizeBunSnapshot(stdout)).toMatchInlineSnapshot(` + "bun test () + beforeEach + start test 1 + afterEach + beforeEach + start test 2 + afterEach + beforeEach + start test 3 + afterEach" + `); + + const result2 = await Bun.spawn({ + cmd: [bunExe(), "test", import.meta.dir + "/concurrent_immediate_promise.fixture.ts"], + stdout: "pipe", + stderr: "pipe", + env: bunEnv, + }); + const exitCode2 = await result2.exited; + const stdout2 = await result2.stdout.text(); + const stderr2 = await result2.stderr.text(); + expect(exitCode2).toBe(0); + expect(normalizeBunSnapshot(stdout2)).toBe(normalizeBunSnapshot(stdout)); + expect(normalizeBunSnapshot(stderr2).replaceAll("_promise.", ".")).toBe(normalizeBunSnapshot(stderr)); +}); + +function filterImportantLines(stderr: string) { + return normalizeBunSnapshot(stderr) + .split("\n") + .filter(l => l.startsWith("(pass)") || l.startsWith("(fail)") || l.startsWith("error:")) + .join("\n"); +} + +test("concurrent immediate error", async () => { + const result = await Bun.spawn({ + cmd: [bunExe(), "test", import.meta.dir 
+ "/concurrent_immediate_error.fixture.ts"], + stdout: "pipe", + stderr: "pipe", + env: bunEnv, + }); + const exitCode = await result.exited; + const stdout = await result.stdout.text(); + const stderr = await result.stderr.text(); + expect(exitCode).toBe(1); + expect(filterImportantLines(stderr)).toMatchInlineSnapshot(` + "(pass) test 1 + error: test 2 error + (fail) test 2 + (pass) test 3" + `); + + const result2 = await Bun.spawn({ + cmd: [bunExe(), "test", import.meta.dir + "/concurrent_immediate_error_promise.fixture.ts"], + stdout: "pipe", + stderr: "pipe", + env: bunEnv, + }); + const exitCode2 = await result2.exited; + const stdout2 = await result2.stdout.text(); + const stderr2 = await result2.stderr.text(); + expect(filterImportantLines(stderr2)).toBe(filterImportantLines(stderr)); +}); diff --git a/test/js/bun/test/concurrent_immediate_error.fixture.ts b/test/js/bun/test/concurrent_immediate_error.fixture.ts new file mode 100644 index 0000000000..a318c76260 --- /dev/null +++ b/test/js/bun/test/concurrent_immediate_error.fixture.ts @@ -0,0 +1,15 @@ +beforeEach(() => { + console.log("beforeEach"); +}); +afterEach(() => { + console.log("afterEach"); +}); +test.concurrent("test 1", () => { + console.log("start test 1"); +}); +test.concurrent("test 2", () => { + throw new Error("test 2 error"); +}); +test.concurrent("test 3", () => { + console.log("start test 3"); +}); diff --git a/test/js/bun/test/concurrent_immediate_error_promise.fixture.ts b/test/js/bun/test/concurrent_immediate_error_promise.fixture.ts new file mode 100644 index 0000000000..a757ec4fe5 --- /dev/null +++ b/test/js/bun/test/concurrent_immediate_error_promise.fixture.ts @@ -0,0 +1,15 @@ +beforeEach(async () => { + console.log("beforeEach"); +}); +afterEach(async () => { + console.log("afterEach"); +}); +test.concurrent("test 1", async () => { + console.log("start test 1"); +}); +test.concurrent("test 2", async () => { + throw new Error("test 2 error"); +}); +test.concurrent("test 3", async () => { + console.log("start test 3"); +}); diff --git a/test/js/bun/test/concurrent_immediate_promise.fixture.ts b/test/js/bun/test/concurrent_immediate_promise.fixture.ts new file mode 100644 index 0000000000..709ce5d42a --- /dev/null +++ b/test/js/bun/test/concurrent_immediate_promise.fixture.ts @@ -0,0 +1,15 @@ +beforeEach(async () => { + console.log("beforeEach"); +}); +afterEach(async () => { + console.log("afterEach"); +}); +test.concurrent("test 1", async () => { + console.log("start test 1"); +}); +test.concurrent("test 2", async () => { + console.log("start test 2"); +}); +test.concurrent("test 3", async () => { + console.log("start test 3"); +}); diff --git a/test/js/node/fs/fs.test.ts b/test/js/node/fs/fs.test.ts index c56483180a..f4196f282f 100644 --- a/test/js/node/fs/fs.test.ts +++ b/test/js/node/fs/fs.test.ts @@ -13,6 +13,7 @@ import fs, { fdatasync, fdatasyncSync, fstatSync, + ftruncateSync, lstatSync, mkdirSync, mkdtemp, @@ -2714,14 +2715,15 @@ it("fstat on a large file", () => { try { dest = `${tmpdir()}/fs.test.ts/${Math.trunc(Math.random() * 10000000000).toString(32)}.stat.txt`; mkdirSync(dirname(dest), { recursive: true }); - const bigBuffer = new Uint8Array(1024 * 1024 * 1024); fd = openSync(dest, "w"); - let offset = 0; - while (offset < 5 * 1024 * 1024 * 1024) { - offset += writeSync(fd, bigBuffer, 0, bigBuffer.length, offset); - } + + // Instead of writing the actual bytes, we can use ftruncate to make a + // hole-y file and extend it to the desired size This should generally avoid + // the ENOSPC issue 
and avoid timeouts. + ftruncateSync(fd, 5 * 1024 * 1024 * 1024); fdatasyncSync(fd); - expect(fstatSync(fd).size).toEqual(offset); + const stats = fstatSync(fd); + expect(stats.size).toEqual(5 * 1024 * 1024 * 1024); } catch (error) { // TODO: Once `fs.statfsSync` is implemented, make sure that the buffer size // is small enough not to cause: ENOSPC: No space left on device. From be15f6c80c7843c76f7eb8be06a7981964509eb9 Mon Sep 17 00:00:00 2001 From: robobun Date: Thu, 25 Sep 2025 14:20:47 -0700 Subject: [PATCH 05/43] feat(test): add --randomize flag to run tests in random order (#22945) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary This PR adds a `--randomize` flag to `bun test` that shuffles test execution order. This helps developers catch test interdependencies and identify flaky tests that may depend on execution order. ## Changes - ✨ Added `--randomize` CLI flag to test command - 🔀 Implemented test shuffling using `bun.fastRandom()` as PRNG seed - 🧪 Added comprehensive tests to verify randomization behavior - 📝 Tests are shuffled at the scheduling phase, properly handling describe blocks and hooks ## Usage ```bash # Run tests in random order bun test --randomize # Works with other test flags bun test --randomize --bail bun test mytest.test.ts --randomize ``` ## Implementation Details The randomization happens in `Order.zig`'s `generateOrderDescribe` function, which shuffles the `current.entries.items` array when the randomize flag is set. This ensures: - All tests still run (just in different order) - Hooks (beforeAll, afterAll, beforeEach, afterEach) maintain proper relationships - Describe blocks and their children are shuffled independently - Each run uses a different random seed for varied execution orders ## Test Coverage Added tests in `test/cli/test/test-randomize.test.ts` that verify: - Tests run in random order with the flag - All tests execute (none are skipped) - Without the flag, tests run in consistent order - Randomization works with describe blocks ## Example Output ```bash # Without --randomize (consistent order) $ bun test mytest.js Running test 1 Running test 2 Running test 3 Running test 4 Running test 5 # With --randomize (different order each run) $ bun test mytest.js --randomize Running test 3 Running test 5 Running test 1 Running test 4 Running test 2 $ bun test mytest.js --randomize Running test 2 Running test 4 Running test 5 Running test 1 Running test 3 ``` 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --------- Co-authored-by: Claude Bot Co-authored-by: Claude Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: pfg --- src/bun.js/test/Order.zig | 8 + src/bun.js/test/bun_test.zig | 3 +- src/bun.js/test/jest.zig | 1 + src/cli.zig | 1 + src/cli/Arguments.zig | 2 + src/cli/test_command.zig | 1 + test/cli/test/test-randomize.test.ts | 256 +++++++++++++++++++++++++++ 7 files changed, 271 insertions(+), 1 deletion(-) create mode 100644 test/cli/test/test-randomize.test.ts diff --git a/src/bun.js/test/Order.zig b/src/bun.js/test/Order.zig index ce51ae80f6..3adae8f64b 100644 --- a/src/bun.js/test/Order.zig +++ b/src/bun.js/test/Order.zig @@ -38,6 +38,7 @@ pub const AllOrderResult = struct { }; pub const Config = struct { always_use_hooks: bool = false, + randomize: bool = false, }; pub fn generateAllOrder(this: *Order, entries: []const *ExecutionEntry, _: Config) bun.JSError!AllOrderResult { const start = this.groups.items.len; @@ 
-61,6 +62,13 @@ pub fn generateOrderDescribe(this: *Order, current: *DescribeScope, cfg: Config) // gather beforeAll const beforeall_order: AllOrderResult = if (use_hooks) try generateAllOrder(this, current.beforeAll.items, cfg) else .empty; + // shuffle entries if randomize flag is set + if (cfg.randomize) { + var prng = std.Random.DefaultPrng.init(bun.fastRandom()); + const random = prng.random(); + random.shuffle(TestScheduleEntry, current.entries.items); + } + // gather children for (current.entries.items) |entry| { if (current.base.only == .contains and entry.base().only == .no) continue; diff --git a/src/bun.js/test/bun_test.zig b/src/bun.js/test/bun_test.zig index 2fa96ba79f..1150c8c916 100644 --- a/src/bun.js/test/bun_test.zig +++ b/src/bun.js/test/bun_test.zig @@ -514,7 +514,8 @@ pub const BunTest = struct { defer order.deinit(); const has_filter = if (this.reporter) |reporter| if (reporter.jest.filter_regex) |_| true else false else false; - const cfg: Order.Config = .{ .always_use_hooks = this.collection.root_scope.base.only == .no and !has_filter }; + const should_randomize = if (this.reporter) |reporter| reporter.jest.randomize else false; + const cfg: Order.Config = .{ .always_use_hooks = this.collection.root_scope.base.only == .no and !has_filter, .randomize = should_randomize }; const beforeall_order: Order.AllOrderResult = if (cfg.always_use_hooks or this.collection.root_scope.base.has_callback) try order.generateAllOrder(this.buntest.hook_scope.beforeAll.items, cfg) else .empty; try order.generateOrderDescribe(this.collection.root_scope, cfg); beforeall_order.setFailureSkipTo(&order); diff --git a/src/bun.js/test/jest.zig b/src/bun.js/test/jest.zig index 1a592dc041..013bbcb5f9 100644 --- a/src/bun.js/test/jest.zig +++ b/src/bun.js/test/jest.zig @@ -55,6 +55,7 @@ pub const TestRunner = struct { only: bool = false, run_todo: bool = false, concurrent: bool = false, + randomize: bool = false, concurrent_test_glob: ?[]const []const u8 = null, last_file: u64 = 0, bail: u32 = 0, diff --git a/src/cli.zig b/src/cli.zig index ea1129320e..f192591414 100644 --- a/src/cli.zig +++ b/src/cli.zig @@ -339,6 +339,7 @@ pub const Command = struct { run_todo: bool = false, only: bool = false, concurrent: bool = false, + randomize: bool = false, concurrent_test_glob: ?[]const []const u8 = null, bail: u32 = 0, coverage: TestCommand.CodeCoverageOptions = .{}, diff --git a/src/cli/Arguments.zig b/src/cli/Arguments.zig index f0884f20df..1828060979 100644 --- a/src/cli/Arguments.zig +++ b/src/cli/Arguments.zig @@ -197,6 +197,7 @@ pub const test_only_params = [_]ParamType{ clap.parseParam("--rerun-each Re-run each test file times, helps catch certain bugs") catch unreachable, clap.parseParam("--todo Include tests that are marked with \"test.todo()\"") catch unreachable, clap.parseParam("--concurrent Treat all tests as `test.concurrent()` tests") catch unreachable, + clap.parseParam("--randomize Run tests in random order") catch unreachable, clap.parseParam("--coverage Generate a coverage profile") catch unreachable, clap.parseParam("--coverage-reporter ... Report coverage in 'text' and/or 'lcov'. Defaults to 'text'.") catch unreachable, clap.parseParam("--coverage-dir Directory for coverage files. 
Defaults to 'coverage'.") catch unreachable, @@ -495,6 +496,7 @@ pub fn parse(allocator: std.mem.Allocator, ctx: Command.Context, comptime cmd: C ctx.test_options.update_snapshots = args.flag("--update-snapshots"); ctx.test_options.run_todo = args.flag("--todo"); ctx.test_options.concurrent = args.flag("--concurrent"); + ctx.test_options.randomize = args.flag("--randomize"); } ctx.args.absolute_working_dir = cwd; diff --git a/src/cli/test_command.zig b/src/cli/test_command.zig index a0f70cd1ee..924c415445 100644 --- a/src/cli/test_command.zig +++ b/src/cli/test_command.zig @@ -1301,6 +1301,7 @@ pub const TestCommand = struct { .allocator = ctx.allocator, .default_timeout_ms = ctx.test_options.default_timeout_ms, .concurrent = ctx.test_options.concurrent, + .randomize = ctx.test_options.randomize, .concurrent_test_glob = ctx.test_options.concurrent_test_glob, .run_todo = ctx.test_options.run_todo, .only = ctx.test_options.only, diff --git a/test/cli/test/test-randomize.test.ts b/test/cli/test/test-randomize.test.ts new file mode 100644 index 0000000000..b53dfc2804 --- /dev/null +++ b/test/cli/test/test-randomize.test.ts @@ -0,0 +1,256 @@ +import { expect, test } from "bun:test"; +import { bunEnv, bunExe, tempDir } from "harness"; +import { join } from "path"; + +test("--randomize flag randomizes test execution order", async () => { + // Create a test file with multiple tests that output their names + using dir = tempDir("test-randomize", {}); + const testFile = join(String(dir), "order.test.js"); + + await Bun.write( + testFile, + ` + import { test } from "bun:test"; + + test("test-01", () => { + console.log("test-01"); + }); + + test("test-02", () => { + console.log("test-02"); + }); + + test("test-03", () => { + console.log("test-03"); + }); + + test("test-04", () => { + console.log("test-04"); + }); + + test("test-05", () => { + console.log("test-05"); + }); + + test("test-06", () => { + console.log("test-06"); + }); + + test("test-07", () => { + console.log("test-07"); + }); + + test("test-08", () => { + console.log("test-08"); + }); + `, + ); + + // Run without --randomize to get the default order + await using defaultProc = Bun.spawn({ + cmd: [bunExe(), "test", testFile], + env: bunEnv, + stdout: "pipe", + stderr: "pipe", + cwd: String(dir), + }); + + const [defaultOut, defaultErr, defaultExit] = await Promise.all([ + defaultProc.stdout.text(), + defaultProc.stderr.text(), + defaultProc.exited, + ]); + + expect(defaultExit).toBe(0); + + // Extract test execution order from output + const defaultTests = defaultOut.match(/test-\d+/g) || []; + expect(defaultTests.length).toBe(8); + + // Run multiple times WITH --randomize to find a different order + let foundDifferentOrder = false; + const maxAttempts = 20; // Increase attempts since randomization might occasionally match + + for (let i = 0; i < maxAttempts; i++) { + await using randomProc = Bun.spawn({ + cmd: [bunExe(), "test", testFile, "--randomize"], + env: bunEnv, + stdout: "pipe", + stderr: "pipe", + cwd: String(dir), + }); + + const [randomOut, randomErr, randomExit] = await Promise.all([ + randomProc.stdout.text(), + randomProc.stderr.text(), + randomProc.exited, + ]); + + expect(randomExit).toBe(0); + + const randomTests = randomOut.match(/test-\d+/g) || []; + expect(randomTests.length).toBe(8); + + // Check if all tests ran (just different order) + const sortedRandom = [...randomTests].sort(); + const sortedDefault = [...defaultTests].sort(); + expect(sortedRandom).toEqual(sortedDefault); + + // Check if order is different + 
const orderIsDifferent = randomTests.some((test, index) => test !== defaultTests[index]); + if (orderIsDifferent) { + foundDifferentOrder = true; + break; + } + } + + // With 8 tests and 20 attempts, the probability of not finding a different order + // by pure chance is (1/8!)^20 which is astronomically small + expect(foundDifferentOrder).toBe(true); +}, 30000); // 30 second timeout for this test + +test("--randomize flag works with describe blocks", async () => { + using dir = tempDir("test-randomize-describe", {}); + const testFile = join(String(dir), "describe.test.js"); + + await Bun.write( + testFile, + ` + import { test, describe } from "bun:test"; + + describe("Suite-A", () => { + test("A1", () => { + console.log("A1"); + }); + + test("A2", () => { + console.log("A2"); + }); + + test("A3", () => { + console.log("A3"); + }); + }); + + describe("Suite-B", () => { + test("B1", () => { + console.log("B1"); + }); + + test("B2", () => { + console.log("B2"); + }); + }); + + describe("Suite-C", () => { + test("C1", () => { + console.log("C1"); + }); + + test("C2", () => { + console.log("C2"); + }); + }); + `, + ); + + // Run without --randomize + await using defaultProc = Bun.spawn({ + cmd: [bunExe(), "test", testFile], + env: bunEnv, + stdout: "pipe", + stderr: "pipe", + cwd: String(dir), + }); + + const [defaultOut, defaultErr, defaultExit] = await Promise.all([ + defaultProc.stdout.text(), + defaultProc.stderr.text(), + defaultProc.exited, + ]); + + expect(defaultExit).toBe(0); + + const defaultTests = defaultOut.match(/[ABC]\d/g) || []; + expect(defaultTests.length).toBe(7); + + // Run with --randomize multiple times + let foundDifferentOrder = false; + + for (let i = 0; i < 20; i++) { + await using randomProc = Bun.spawn({ + cmd: [bunExe(), "test", testFile, "--randomize"], + env: bunEnv, + stdout: "pipe", + stderr: "pipe", + cwd: String(dir), + }); + + const [randomOut, randomErr, randomExit] = await Promise.all([ + randomProc.stdout.text(), + randomProc.stderr.text(), + randomProc.exited, + ]); + + expect(randomExit).toBe(0); + + const randomTests = randomOut.match(/[ABC]\d/g) || []; + expect(randomTests.length).toBe(7); + + // Verify all tests ran + expect([...randomTests].sort()).toEqual([...defaultTests].sort()); + + // Check if order is different + const orderIsDifferent = randomTests.some((test, index) => test !== defaultTests[index]); + if (orderIsDifferent) { + foundDifferentOrder = true; + break; + } + } + + expect(foundDifferentOrder).toBe(true); +}, 30000); + +test("without --randomize flag tests run in consistent order", async () => { + using dir = tempDir("test-consistent", {}); + const testFile = join(String(dir), "consistent.test.js"); + + await Bun.write( + testFile, + ` + import { test } from "bun:test"; + + test("test-1", () => { console.log("1"); }); + test("test-2", () => { console.log("2"); }); + test("test-3", () => { console.log("3"); }); + test("test-4", () => { console.log("4"); }); + test("test-5", () => { console.log("5"); }); + `, + ); + + const runs = []; + + // Run 5 times without --randomize + for (let i = 0; i < 5; i++) { + await using proc = Bun.spawn({ + cmd: [bunExe(), "test", testFile], + env: bunEnv, + stdout: "pipe", + stderr: "pipe", + cwd: String(dir), + }); + + const [out, err, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + + const order = out.match(/\d/g) || []; + runs.push(order.join("")); + } + + // All runs should have the same order + const firstRun = runs[0]; + for 
(const run of runs) { + expect(run).toBe(firstRun); + } +}, 20000); From 20854fb285a31f7c8940ceced0d3e3b2f5ffdea8 Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Thu, 25 Sep 2025 14:28:42 -0800 Subject: [PATCH 06/43] node:crypto: add blake2s256 hasher (#22958) --- docs/api/hashing.md | 1 + packages/bun-types/bun.d.ts | 1 + src/bun.js/api/crypto/CryptoHasher.zig | 1 + src/bun.js/api/crypto/EVP.zig | 2 ++ test/js/bun/util/bun-cryptohasher.test.ts | 1 + test/js/node/crypto/crypto-hmac-algorithm.test.ts | 1 + test/js/node/crypto/crypto-oneshot.test.ts | 1 + test/js/node/crypto/crypto.test.ts | 3 +++ test/js/node/crypto/node-crypto.test.js | 1 - 9 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/api/hashing.md b/docs/api/hashing.md index 79d3c54e86..384a20243c 100644 --- a/docs/api/hashing.md +++ b/docs/api/hashing.md @@ -184,6 +184,7 @@ Bun.hash.rapidhash("data", 1234); - `"blake2b256"` - `"blake2b512"` +- `"blake2s256"` - `"md4"` - `"md5"` - `"ripemd160"` diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index 89f318c0cc..406e3d8466 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -5047,6 +5047,7 @@ declare module "bun" { type SupportedCryptoAlgorithms = | "blake2b256" | "blake2b512" + | "blake2s256" | "md4" | "md5" | "ripemd160" diff --git a/src/bun.js/api/crypto/CryptoHasher.zig b/src/bun.js/api/crypto/CryptoHasher.zig index e46b30f6b4..f40c324f0b 100644 --- a/src/bun.js/api/crypto/CryptoHasher.zig +++ b/src/bun.js/api/crypto/CryptoHasher.zig @@ -488,6 +488,7 @@ const CryptoHasherZig = struct { .{ "sha3-512", std.crypto.hash.sha3.Sha3_512 }, .{ "shake128", std.crypto.hash.sha3.Shake128 }, .{ "shake256", std.crypto.hash.sha3.Shake256 }, + .{ "blake2s256", std.crypto.hash.blake2.Blake2s256 }, }; inline fn digestLength(Algorithm: type) comptime_int { diff --git a/src/bun.js/api/crypto/EVP.zig b/src/bun.js/api/crypto/EVP.zig index 14a976997f..2d3f3b088f 100644 --- a/src/bun.js/api/crypto/EVP.zig +++ b/src/bun.js/api/crypto/EVP.zig @@ -21,6 +21,7 @@ pub const Algorithm = enum { // @"ecdsa-with-SHA1", blake2b256, blake2b512, + blake2s256, md4, md5, ripemd160, @@ -69,6 +70,7 @@ pub const Algorithm = enum { pub const map = bun.ComptimeStringMap(Algorithm, .{ .{ "blake2b256", .blake2b256 }, .{ "blake2b512", .blake2b512 }, + .{ "blake2s256", .blake2s256 }, .{ "ripemd160", .ripemd160 }, .{ "rmd160", .ripemd160 }, .{ "md4", .md4 }, diff --git a/test/js/bun/util/bun-cryptohasher.test.ts b/test/js/bun/util/bun-cryptohasher.test.ts index aa159a65e6..545c3f884e 100644 --- a/test/js/bun/util/bun-cryptohasher.test.ts +++ b/test/js/bun/util/bun-cryptohasher.test.ts @@ -185,6 +185,7 @@ describe("CryptoHasher", () => { const algorithms = [ "blake2b256", "blake2b512", + "blake2s256", "ripemd160", "rmd160", "md4", diff --git a/test/js/node/crypto/crypto-hmac-algorithm.test.ts b/test/js/node/crypto/crypto-hmac-algorithm.test.ts index 9125077470..88d7f38c73 100644 --- a/test/js/node/crypto/crypto-hmac-algorithm.test.ts +++ b/test/js/node/crypto/crypto-hmac-algorithm.test.ts @@ -21,6 +21,7 @@ test("createHmac works with various algorithm names", () => { const toRemove = [ "blake2b256", "blake2b512", + "blake2s256", "md4", "sha512-224", "sha512-256", diff --git a/test/js/node/crypto/crypto-oneshot.test.ts b/test/js/node/crypto/crypto-oneshot.test.ts index eeedbde20d..99529c3de3 100644 --- a/test/js/node/crypto/crypto-oneshot.test.ts +++ b/test/js/node/crypto/crypto-oneshot.test.ts @@ -39,6 +39,7 @@ describe("crypto.hash", () => { [ "blake2b256", 
"blake2b512", + "blake2s256", "ripemd160", "rmd160", "md4", diff --git a/test/js/node/crypto/crypto.test.ts b/test/js/node/crypto/crypto.test.ts index 9384543b93..16336cdcda 100644 --- a/test/js/node/crypto/crypto.test.ts +++ b/test/js/node/crypto/crypto.test.ts @@ -11,6 +11,7 @@ describe("CryptoHasher", () => { expect(CryptoHasher.algorithms).toEqual([ "blake2b256", "blake2b512", + "blake2s256", "md4", "md5", "ripemd160", @@ -34,6 +35,7 @@ describe("CryptoHasher", () => { const expected = { blake2b256: "256c83b297114d201b30179f3f0ef0cace9783622da5974326b436178aeef610", blake2b512: "021ced8799296ceca557832ab941a50b4a11f83478cf141f51f933f653ab9fbcc05a037cddbed06e309bf334942c4e58cdf1a46e237911ccd7fcf9787cbc7fd0", + blake2s256: "9aec6806794561107e594b1f6a8a6b0c92a0cba9acf5e5e93cca06f781813b0b", md4: "aa010fbc1d14c795d86ef98c95479d17", md5: "5eb63bbbe01eeed093cb22bb8f5acdc3", ripemd160: "98c615784ccb5fe5936fbc0cbe9dfdb408d92f0f", @@ -55,6 +57,7 @@ describe("CryptoHasher", () => { const expectedBitLength = { blake2b256: 256, blake2b512: 512, + blake2s256: 256, md4: 128, md5: 128, ripemd160: 160, diff --git a/test/js/node/crypto/node-crypto.test.js b/test/js/node/crypto/node-crypto.test.js index e4b8874a40..2316d0ba5e 100644 --- a/test/js/node/crypto/node-crypto.test.js +++ b/test/js/node/crypto/node-crypto.test.js @@ -255,7 +255,6 @@ describe("createHash", () => { }; const unsupported = [ - "blake2s256", "id-rsassa-pkcs1-v1_5-with-sha3-224", "id-rsassa-pkcs1-v1_5-with-sha3-256", "id-rsassa-pkcs1-v1_5-with-sha3-384", From 4dfd87a3023c5a5af3cd97764d48f2a7cebcff99 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Thu, 25 Sep 2025 16:08:06 -0700 Subject: [PATCH 07/43] Fix aborting fetch() calls while the socket is connecting. Fix a thread-safety issue involving redirects and AbortSignal. (#22842) ### What does this PR do? When we added "happy eyeballs" support to fetch(), it meant that `onOpen` would not be called potentially for awhile. If the AbortSignal is aborted between `connect()` and the socket becoming readable/writable, then we would delay closing the connection until the connection opens. Fixing that fixes #18536. Separately, the `isHTTPS()` function used in abort and in request body streams was not thread safe. This caused a crash when many redirects happen simultaneously while either AbortSignal or request body messages are in-flight. This PR fixes https://github.com/oven-sh/bun/issues/14137 ### How did you verify your code works? 
There are tests --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Claude Bot Co-authored-by: Ciro Spaciari --- src/bun.js/webcore.zig | 1 + src/bun.js/webcore/ResumableSink.zig | 159 +++++++++--------- src/bun.js/webcore/fetch.zig | 99 +++++++---- src/http.zig | 72 +++++--- src/http/HTTPContext.zig | 61 +++---- src/http/HTTPThread.zig | 140 +++++++++------ src/http/InternalState.zig | 4 + src/js/builtins/ReadableStreamInternals.ts | 132 +++++++++------ src/s3/client.zig | 3 +- src/s3/multipart.zig | 17 +- .../io/fetch/fetch-abort-slow-connect.test.ts | 59 +++++++ test/js/bun/s3/s3-stream-leak-fixture.js | 2 +- test/js/bun/s3/s3.leak.test.ts | 3 +- test/js/web/fetch/fetch.test.ts | 35 ++++ 14 files changed, 497 insertions(+), 290 deletions(-) create mode 100644 test/js/bun/io/fetch/fetch-abort-slow-connect.test.ts diff --git a/src/bun.js/webcore.zig b/src/bun.js/webcore.zig index 9fcfebfc29..5bf55087b4 100644 --- a/src/bun.js/webcore.zig +++ b/src/bun.js/webcore.zig @@ -28,6 +28,7 @@ pub const Blob = @import("./webcore/Blob.zig"); pub const S3Stat = @import("./webcore/S3Stat.zig").S3Stat; pub const ResumableFetchSink = @import("./webcore/ResumableSink.zig").ResumableFetchSink; pub const ResumableS3UploadSink = @import("./webcore/ResumableSink.zig").ResumableS3UploadSink; +pub const ResumableSinkBackpressure = @import("./webcore/ResumableSink.zig").ResumableSinkBackpressure; pub const S3Client = @import("./webcore/S3Client.zig").S3Client; pub const Request = @import("./webcore/Request.zig"); pub const Body = @import("./webcore/Body.zig"); diff --git a/src/bun.js/webcore/ResumableSink.zig b/src/bun.js/webcore/ResumableSink.zig index ddbe325c40..3812777eb8 100644 --- a/src/bun.js/webcore/ResumableSink.zig +++ b/src/bun.js/webcore/ResumableSink.zig @@ -6,7 +6,7 @@ pub fn ResumableSink( comptime js: type, comptime Context: type, - comptime onWrite: fn (context: *Context, chunk: []const u8) bool, + comptime onWrite: fn (context: *Context, chunk: []const u8) ResumableSinkBackpressure, comptime onEnd: fn (context: *Context, err: ?jsc.JSValue) void, ) type { return struct { @@ -15,6 +15,8 @@ pub fn ResumableSink( pub const fromJS = js.fromJS; pub const fromJSDirect = js.fromJSDirect; + const ThisSink = @This(); + pub const new = bun.TrivialNew(@This()); const RefCount = bun.ptr.RefCount(@This(), "ref_count", deinit, .{}); pub const ref = RefCount.ref; @@ -26,7 +28,7 @@ pub fn ResumableSink( const setStream = js.streamSetCached; const getStream = js.streamGetCached; ref_count: RefCount, - self: jsc.Strong.Optional = jsc.Strong.Optional.empty, + #js_this: jsc.JSRef = .empty(), // We can have a detached self, and still have a strong reference to the stream stream: jsc.WebCore.ReadableStream.Strong = .{}, globalThis: *jsc.JSGlobalObject, @@ -41,16 +43,16 @@ pub fn ResumableSink( done, }; - pub fn constructor(globalThis: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!*@This() { + pub fn constructor(globalThis: *jsc.JSGlobalObject, _: *jsc.CallFrame) bun.JSError!*ThisSink { return globalThis.throwInvalidArguments("ResumableSink is not constructable", .{}); } - pub fn init(globalThis: *jsc.JSGlobalObject, stream: jsc.WebCore.ReadableStream, context: *Context) *@This() { + pub fn init(globalThis: *jsc.JSGlobalObject, stream: jsc.WebCore.ReadableStream, context: *Context) *ThisSink { return initExactRefs(globalThis, stream, context, 1); } - pub fn initExactRefs(globalThis: *jsc.JSGlobalObject, stream: jsc.WebCore.ReadableStream, context: 
*Context, ref_count: u32) *@This() { - const this = @This().new(.{ + pub fn initExactRefs(globalThis: *jsc.JSGlobalObject, stream: jsc.WebCore.ReadableStream, context: *Context, ref_count: u32) *ThisSink { + const this: *ThisSink = ThisSink.new(.{ .globalThis = globalThis, .context = context, .ref_count = RefCount.initExactRefs(ref_count), @@ -123,13 +125,15 @@ pub fn ResumableSink( self.ensureStillAlive(); const js_stream = stream.toJS(); js_stream.ensureStillAlive(); - _ = Bun__assignStreamIntoResumableSink(globalThis, js_stream, self); - this.self = jsc.Strong.Optional.create(self, globalThis); + this.#js_this.setStrong(self, globalThis); setStream(self, globalThis, js_stream); + + _ = Bun__assignStreamIntoResumableSink(globalThis, js_stream, self); + return this; } - pub fn jsSetHandlers(_: *@This(), globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame, this_value: jsc.JSValue) bun.JSError!jsc.JSValue { + pub fn jsSetHandlers(_: *ThisSink, globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame, this_value: jsc.JSValue) bun.JSError!jsc.JSValue { jsc.markBinding(@src()); const args = callframe.arguments(); @@ -149,7 +153,7 @@ pub fn ResumableSink( return .js_undefined; } - pub fn jsStart(this: *@This(), globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + pub fn jsStart(this: *ThisSink, globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { jsc.markBinding(@src()); const args = callframe.arguments(); if (args.len > 0 and args[0].isObject()) { @@ -161,38 +165,43 @@ pub fn ResumableSink( return .js_undefined; } - pub fn jsWrite(this: *@This(), globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + pub fn jsWrite(this: *ThisSink, globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { jsc.markBinding(@src()); const args = callframe.arguments(); // ignore any call if detached - if (!this.self.has() or this.status == .done) return .js_undefined; + if (this.isDetached()) return .js_undefined; if (args.len < 1) { return globalThis.throwInvalidArguments("ResumableSink.write requires at least 1 argument", .{}); } const buffer = args[0]; - buffer.ensureStillAlive(); - if (try jsc.Node.StringOrBuffer.fromJS(globalThis, bun.default_allocator, buffer)) |sb| { - defer sb.deinit(); - const bytes = sb.slice(); - log("jsWrite {}", .{bytes.len}); - const should_continue = onWrite(this.context, bytes); - if (!should_continue) { + const sb = try jsc.Node.StringOrBuffer.fromJS(globalThis, bun.default_allocator, buffer) orelse { + return globalThis.throwInvalidArguments("ResumableSink.write requires a string or buffer", .{}); + }; + + defer sb.deinit(); + const bytes = sb.slice(); + log("jsWrite {}", .{bytes.len}); + switch (onWrite(this.context, bytes)) { + .backpressure => { log("paused", .{}); this.status = .paused; - } - return .jsBoolean(should_continue); + }, + .done => {}, + .want_more => { + this.status = .started; + }, } - return globalThis.throwInvalidArguments("ResumableSink.write requires a string or buffer", .{}); + return .jsBoolean(this.status != .paused); } - pub fn jsEnd(this: *@This(), _: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { + pub fn jsEnd(this: *ThisSink, _: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!jsc.JSValue { jsc.markBinding(@src()); const args = callframe.arguments(); // ignore any call if detached - if (!this.self.has() or this.status == .done) return .js_undefined; + if (this.isDetached()) 
return .js_undefined; this.detachJS(); log("jsEnd {}", .{args.len}); this.status = .done; @@ -201,86 +210,73 @@ pub fn ResumableSink( return .js_undefined; } - pub fn drain(this: *@This()) void { + pub fn drain(this: *ThisSink) void { log("drain", .{}); if (this.status != .paused) { return; } - if (this.self.get()) |js_this| { + if (this.#js_this.tryGet()) |js_this| { const globalObject = this.globalThis; - const vm = globalObject.bunVM(); - vm.eventLoop().enter(); - defer vm.eventLoop().exit(); + if (getDrain(js_this)) |ondrain| { - if (ondrain.isCallable()) { - this.status = .started; - _ = ondrain.call(globalObject, .js_undefined, &.{.js_undefined}) catch |err| { - // should never happen - bun.debugAssert(false); - _ = globalObject.takeError(err); - }; - } + this.status = .started; + globalObject.bunVM().eventLoop().runCallback(ondrain, globalObject, .js_undefined, &.{ .js_undefined, .js_undefined }); } } } - pub fn cancel(this: *@This(), reason: jsc.JSValue) void { + pub fn cancel(this: *ThisSink, reason: jsc.JSValue) void { if (this.status == .piped) { reason.ensureStillAlive(); this.endPipe(reason); return; } - if (this.self.get()) |js_this| { + if (this.#js_this.tryGet()) |js_this| { this.status = .done; js_this.ensureStillAlive(); + const onCancelCallback = getCancel(js_this); const globalObject = this.globalThis; - const vm = globalObject.bunVM(); - vm.eventLoop().enter(); - defer vm.eventLoop().exit(); - if (getCancel(js_this)) |oncancel| { - oncancel.ensureStillAlive(); - // detach first so if cancel calls end will be a no-op - this.detachJS(); - // call onEnd to indicate the native side that the stream errored - onEnd(this.context, reason); - if (oncancel.isCallable()) { - _ = oncancel.call(globalObject, .js_undefined, &.{ .js_undefined, reason }) catch |err| { - // should never happen - bun.debugAssert(false); - _ = globalObject.takeError(err); - }; - } - } else { - // should never happen but lets call onEnd to indicate the native side that the stream errored - this.detachJS(); - onEnd(this.context, reason); + // detach first so if cancel calls end will be a no-op + this.detachJS(); + + // call onEnd to indicate the native side that the stream errored + onEnd(this.context, reason); + + js_this.ensureStillAlive(); + if (onCancelCallback) |callback| { + const event_loop = globalObject.bunVM().eventLoop(); + event_loop.runCallback(callback, globalObject, .js_undefined, &.{ .js_undefined, reason }); } } } - fn detachJS(this: *@This()) void { - if (this.self.trySwap()) |js_this| { + pub fn isDetached(this: *const ThisSink) bool { + return this.#js_this != .strong or this.status == .done; + } + + fn detachJS(this: *ThisSink) void { + if (this.#js_this.tryGet()) |js_this| { setDrain(js_this, this.globalThis, .zero); setCancel(js_this, this.globalThis, .zero); setStream(js_this, this.globalThis, .zero); - this.self.deinit(); - this.self = jsc.Strong.Optional.empty; + this.#js_this.downgrade(); } } - pub fn deinit(this: *@This()) void { + pub fn deinit(this: *ThisSink) void { this.detachJS(); this.stream.deinit(); bun.destroy(this); } - pub fn finalize(this: *@This()) void { + pub fn finalize(this: *ThisSink) void { + this.#js_this.finalize(); this.deref(); } fn onStreamPipe( - this: *@This(), + this: *ThisSink, stream: bun.webcore.streams.Result, allocator: std.mem.Allocator, ) void { @@ -298,7 +294,10 @@ pub fn ResumableSink( } const chunk = stream.slice(); log("onWrite {}", .{chunk.len}); - const stopStream = !onWrite(this.context, chunk); + + // TODO: should the "done" state also 
trigger `endPipe`? + _ = onWrite(this.context, chunk); + const is_done = stream.isDone(); if (is_done) { @@ -313,34 +312,31 @@ pub fn ResumableSink( break :brk_err null; }; this.endPipe(err); - } else if (stopStream) { - // dont make sense pausing the stream here - // it will be buffered in the pipe anyways } } - fn endPipe(this: *@This(), err: ?jsc.JSValue) void { + fn endPipe(this: *ThisSink, err: ?jsc.JSValue) void { log("endPipe", .{}); if (this.status != .piped) return; this.status = .done; - if (this.stream.get(this.globalThis)) |stream_| { + const globalObject = this.globalThis; + if (this.stream.get(globalObject)) |stream_| { if (stream_.ptr == .Bytes) { stream_.ptr.Bytes.pipe = .{}; } if (err != null) { - stream_.cancel(this.globalThis); + stream_.cancel(globalObject); } else { - stream_.done(this.globalThis); + stream_.done(globalObject); } var stream = this.stream; this.stream = .{}; stream.deinit(); } - // We ref when we attach the stream so we deref when we detach the stream - this.deref(); onEnd(this.context, err); - if (this.self.has()) { + + if (this.#js_this == .strong) { // JS owns the stream, so we need to detach the JS and let finalize handle the deref // this should not happen but lets handle it anyways this.detachJS(); @@ -348,10 +344,17 @@ pub fn ResumableSink( // no js attached, so we can just deref this.deref(); } + + // We ref when we attach the stream so we deref when we detach the stream + this.deref(); } }; } - +pub const ResumableSinkBackpressure = enum { + want_more, + backpressure, + done, +}; pub const ResumableFetchSink = ResumableSink(jsc.Codegen.JSResumableFetchSink, FetchTasklet, FetchTasklet.writeRequestData, FetchTasklet.writeEndRequest); pub const ResumableS3UploadSink = ResumableSink(jsc.Codegen.JSResumableS3UploadSink, S3UploadStreamWrapper, S3UploadStreamWrapper.writeRequestData, S3UploadStreamWrapper.writeEndRequest); diff --git a/src/bun.js/webcore/fetch.zig b/src/bun.js/webcore/fetch.zig index 633f17d16c..b48c2f406b 100644 --- a/src/bun.js/webcore/fetch.zig +++ b/src/bun.js/webcore/fetch.zig @@ -302,6 +302,8 @@ pub const FetchTasklet = struct { this.abort_reason.deinit(); this.check_server_identity.deinit(); this.clearAbortSignal(); + // Clear the sink only after the request has ended, otherwise we would potentially lose the last chunk + this.clearSink(); } pub fn deinit(this: *FetchTasklet) void { @@ -343,6 +345,13 @@ pub const FetchTasklet = struct { this.is_waiting_request_stream_start = false; bun.assert(this.request_body == .ReadableStream); if (this.request_body.ReadableStream.get(this.global_this)) |stream| { + if (this.signal) |signal| { + if (signal.aborted()) { + stream.abort(this.global_this); + return; + } + } + const globalThis = this.global_this; this.ref(); // lets only unref when sink is done // +1 because the task refs the sink @@ -1176,53 +1185,63 @@ pub const FetchTasklet = struct { pub fn resumeRequestDataStream(this: *FetchTasklet) void { // deref when done because we ref inside onWriteRequestDataDrain defer this.deref(); + log("resumeRequestDataStream", .{}); if (this.sink) |sink| { + if (this.signal) |signal| { + if (signal.aborted()) { + // already aborted; nothing to drain + return; + } + } sink.drain(); } } - pub fn writeRequestData(this: *FetchTasklet, data: []const u8) bool { + pub fn writeRequestData(this: *FetchTasklet, data: []const u8) ResumableSinkBackpressure { log("writeRequestData {}", .{data.len}); - if (this.request_body_streaming_buffer) |buffer| { - const highWaterMark = if (this.sink) |sink|
sink.highWaterMark else 16384; - const stream_buffer = buffer.acquire(); - var needs_schedule = false; - defer if (needs_schedule) { - // wakeup the http thread to write the data - http.http_thread.scheduleRequestWrite(this.http.?, .data); - }; - defer buffer.release(); - - // dont have backpressure so we will schedule the data to be written - // if we have backpressure the onWritable will drain the buffer - needs_schedule = stream_buffer.isEmpty(); - if (this.upgraded_connection) { - bun.handleOom(stream_buffer.write(data)); - } else { - //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n - var formated_size_buffer: [18]u8 = undefined; - const formated_size = std.fmt.bufPrint( - formated_size_buffer[0..], - "{x}\r\n", - .{data.len}, - ) catch |err| switch (err) { - error.NoSpaceLeft => unreachable, - }; - bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2)); - stream_buffer.writeAssumeCapacity(formated_size); - stream_buffer.writeAssumeCapacity(data); - stream_buffer.writeAssumeCapacity("\r\n"); + if (this.signal) |signal| { + if (signal.aborted()) { + return .done; } - - // pause the stream if we hit the high water mark - return stream_buffer.size() >= highWaterMark; } - return false; + const thread_safe_stream_buffer = this.request_body_streaming_buffer orelse return .done; + const stream_buffer = thread_safe_stream_buffer.acquire(); + defer thread_safe_stream_buffer.release(); + const highWaterMark = if (this.sink) |sink| sink.highWaterMark else 16384; + + var needs_schedule = false; + defer if (needs_schedule) { + // wakeup the http thread to write the data + http.http_thread.scheduleRequestWrite(this.http.?, .data); + }; + + // dont have backpressure so we will schedule the data to be written + // if we have backpressure the onWritable will drain the buffer + needs_schedule = stream_buffer.isEmpty(); + if (this.upgraded_connection) { + bun.handleOom(stream_buffer.write(data)); + } else { + //16 is the max size of a hex number size that represents 64 bits + 2 for the \r\n + var formated_size_buffer: [18]u8 = undefined; + const formated_size = std.fmt.bufPrint( + formated_size_buffer[0..], + "{x}\r\n", + .{data.len}, + ) catch |err| switch (err) { + error.NoSpaceLeft => unreachable, + }; + bun.handleOom(stream_buffer.ensureUnusedCapacity(formated_size.len + data.len + 2)); + stream_buffer.writeAssumeCapacity(formated_size); + stream_buffer.writeAssumeCapacity(data); + stream_buffer.writeAssumeCapacity("\r\n"); + } + + // pause the stream if we hit the high water mark + return if (stream_buffer.size() >= highWaterMark) .backpressure else .want_more; } pub fn writeEndRequest(this: *FetchTasklet, err: ?jsc.JSValue) void { log("writeEndRequest hasError? 
{}", .{err != null}); - this.clearSink(); defer this.deref(); if (err) |jsError| { if (this.signal_store.aborted.load(.monotonic) or this.abort_reason.has()) { @@ -1233,9 +1252,16 @@ pub const FetchTasklet = struct { } this.abortTask(); } else { + if (!this.upgraded_connection) { + // If is not upgraded we need to send the terminating chunk + const thread_safe_stream_buffer = this.request_body_streaming_buffer orelse return; + const stream_buffer = thread_safe_stream_buffer.acquire(); + defer thread_safe_stream_buffer.release(); + bun.handleOom(stream_buffer.write(http.end_of_chunked_http1_1_encoding_response_body)); + } if (this.http) |http_| { // just tell to write the end of the chunked encoding aka 0\r\n\r\n - http.http_thread.scheduleRequestWrite(http_, .endChunked); + http.http_thread.scheduleRequestWrite(http_, .end); } } } @@ -2743,6 +2769,7 @@ const JSType = jsc.C.JSType; const Body = jsc.WebCore.Body; const Request = jsc.WebCore.Request; const Response = jsc.WebCore.Response; +const ResumableSinkBackpressure = jsc.WebCore.ResumableSinkBackpressure; const Blob = jsc.WebCore.Blob; const AnyBlob = jsc.WebCore.Blob.Any; diff --git a/src/http.zig b/src/http.zig index b081db2ecc..c5cea683eb 100644 --- a/src/http.zig +++ b/src/http.zig @@ -6,7 +6,7 @@ pub var default_arena: Arena = undefined; pub var http_thread: HTTPThread = undefined; //TODO: this needs to be freed when Worker Threads are implemented -pub var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.InternalSocket).init(bun.default_allocator); +pub var socket_async_http_abort_tracker = std.AutoArrayHashMap(u32, uws.AnySocket).init(bun.default_allocator); pub var async_http_id_monotonic: std.atomic.Value(u32) = std.atomic.Value(u32).init(0); const MAX_REDIRECT_URL_LENGTH = 128 * 1024; @@ -107,7 +107,10 @@ pub fn registerAbortTracker( socket: NewHTTPContext(is_ssl).HTTPSocket, ) void { if (client.signals.aborted != null) { - socket_async_http_abort_tracker.put(client.async_http_id, socket.socket) catch unreachable; + switch (is_ssl) { + true => socket_async_http_abort_tracker.put(client.async_http_id, .{ .SocketTLS = socket }) catch unreachable, + false => socket_async_http_abort_tracker.put(client.async_http_id, .{ .SocketTCP = socket }) catch unreachable, + } } } @@ -139,6 +142,9 @@ pub fn onOpen( return error.ClientAborted; } + if (client.state.request_stage == .pending) + client.state.request_stage = .opened; + if (comptime is_ssl) { var ssl_ptr: *BoringSSL.SSL = @ptrCast(socket.getNativeHandle()); if (!ssl_ptr.isInitFinished()) { @@ -181,8 +187,11 @@ pub fn firstCall( } } - if (client.state.request_stage == .pending) { - client.onWritable(true, comptime is_ssl, socket); + switch (client.state.request_stage) { + .opened, .pending => { + client.onWritable(true, comptime is_ssl, socket); + }, + else => {}, } } pub fn onClose( @@ -724,10 +733,7 @@ pub fn doRedirect( log("close the tunnel in redirect", .{}); this.proxy_tunnel = null; tunnel.detachAndDeref(); - if (!socket.isClosed()) { - log("close socket in redirect", .{}); - NewHTTPContext(is_ssl).closeSocket(socket); - } + NewHTTPContext(is_ssl).closeSocket(socket); } else { // we need to clean the client reference before closing the socket because we are going to reuse the same ref in a another request if (this.isKeepAlivePossible()) { @@ -762,6 +768,8 @@ pub fn doRedirect( return this.start(.{ .bytes = request_body }, body_out_str); } + +/// **Not thread safe while request is in-flight** pub fn isHTTPS(this: *HTTPClient) bool { if (this.http_proxy) |proxy| { 
if (proxy.isHTTPS()) { @@ -774,6 +782,7 @@ pub fn isHTTPS(this: *HTTPClient) bool { } return false; } + pub fn start(this: *HTTPClient, body: HTTPRequestBody, body_out_str: *MutableString) void { body_out_str.reset(); @@ -788,6 +797,8 @@ pub fn start(this: *HTTPClient, body: HTTPRequestBody, body_out_str: *MutableStr } fn start_(this: *HTTPClient, comptime is_ssl: bool) void { + this.unregisterAbortTracker(); + // mark that we are connecting this.flags.defer_fail_until_connecting_is_complete = true; // this will call .fail() if the connection fails in the middle of the function avoiding UAF with can happen when the connection is aborted @@ -819,6 +830,18 @@ fn start_(this: *HTTPClient, comptime is_ssl: bool) void { this.fail(error.ConnectionClosed); return; } + + // If we haven't already called onOpen(), then that means we need to + // register the abort tracker. We need to do this in cases where the + // connection takes a long time to complete, such as when the address is + // not routable. + // See test/js/bun/io/fetch/fetch-abort-slow-connect.test.ts. + // + // We have to be careful here because if .connect() had finished + // synchronously, then this socket is no longer valid and the pointer points + // to invalid memory. + if (this.state.request_stage == .pending) { + this.registerAbortTracker(is_ssl, socket); + } } pub const HTTPResponseMetadata = struct { @@ -1003,8 +1026,8 @@ fn writeToSocketWithBufferFallback(comptime is_ssl: bool, socket: NewHTTPContext /// Write buffered data to the socket returning true if there is backpressure fn writeToStreamUsingBuffer(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, buffer: *bun.io.StreamBuffer, data: []const u8) !bool { - if (buffer.isNotEmpty()) { - const to_send = buffer.slice(); + const to_send = buffer.slice(); + if (to_send.len > 0) { const amount = try writeToSocket(is_ssl, socket, to_send); this.state.request_sent_len += amount; buffer.cursor += amount; @@ -1020,6 +1043,7 @@ fn writeToStreamUsingBuffer(this: *HTTPClient, comptime is_ssl: bool, socket: Ne buffer.reset(); } } + // ok we flushed all pending data so we can reset the backpressure if (data.len > 0) { // no backpressure everything was sended so we can just try to send @@ -1109,7 +1133,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s } switch (this.state.request_stage) { - .pending, .headers => { + .pending, .headers, .opened => { log("sendInitialRequestPayload", .{}); this.setTimeout(socket, 5); const result = sendInitialRequestPayload(this, is_first_call, is_ssl, socket) catch |err| { @@ -1164,13 +1188,15 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s switch (this.state.original_request_body) { .bytes => { const to_send = this.state.request_body; - const sent = writeToSocket(is_ssl, socket, to_send) catch |err| { - this.closeAndFail(err, is_ssl, socket); - return; - }; + if (to_send.len > 0) { + const sent = writeToSocket(is_ssl, socket, to_send) catch |err| { + this.closeAndFail(err, is_ssl, socket); + return; + }; - this.state.request_sent_len += sent; - this.state.request_body = this.state.request_body[sent..]; + this.state.request_sent_len += sent; + this.state.request_body = this.state.request_body[sent..]; + } if (this.state.request_body.len == 0) { this.state.request_stage = .done; @@ -1312,9 +1338,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s pub fn closeAndFail(this: *HTTPClient, err: anyerror, comptime is_ssl: bool, socket:
NewHTTPContext(is_ssl).HTTPSocket) void { log("closeAndFail: {s}", .{@errorName(err)}); - if (!socket.isClosed()) { - NewHTTPContext(is_ssl).terminateSocket(socket); - } + NewHTTPContext(is_ssl).terminateSocket(socket); this.fail(err); } @@ -1684,10 +1708,7 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon this.proxy_tunnel = null; tunnel.shutdown(); tunnel.detachAndDeref(); - if (!socket.isClosed()) { - log("close socket", .{}); - NewHTTPContext(is_ssl).closeSocket(socket); - } + NewHTTPContext(is_ssl).closeSocket(socket); } else { if (this.isKeepAlivePossible() and !socket.isClosedOrHasError()) { log("release socket", .{}); @@ -1697,8 +1718,7 @@ pub fn progressUpdate(this: *HTTPClient, comptime is_ssl: bool, ctx: *NewHTTPCon this.connected_url.hostname, this.connected_url.getPortAuto(), ); - } else if (!socket.isClosed()) { - log("close socket", .{}); + } else { NewHTTPContext(is_ssl).closeSocket(socket); } } diff --git a/src/http/HTTPContext.zig b/src/http/HTTPContext.zig index bc36f9abef..8cf733037e 100644 --- a/src/http/HTTPContext.zig +++ b/src/http/HTTPContext.zig @@ -10,10 +10,18 @@ pub fn NewHTTPContext(comptime ssl: bool) type { did_have_handshaking_error_while_reject_unauthorized_is_false: bool = false, }; - pub fn markSocketAsDead(socket: HTTPSocket) void { - if (socket.ext(**anyopaque)) |ctx| { - ctx.* = bun.cast(**anyopaque, ActiveSocket.init(&dead_socket).ptr()); + pub fn markTaggedSocketAsDead(socket: HTTPSocket, tagged: ActiveSocket) void { + if (tagged.is(PooledSocket)) { + Handler.addMemoryBackToPool(tagged.as(PooledSocket)); } + + if (socket.ext(**anyopaque)) |ctx| { + ctx.* = bun.cast(**anyopaque, ActiveSocket.init(dead_socket).ptr()); + } + } + + pub fn markSocketAsDead(socket: HTTPSocket) void { + markTaggedSocketAsDead(socket, getTaggedFromSocket(socket)); } pub fn terminateSocket(socket: HTTPSocket) void { @@ -34,7 +42,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type { if (socket.ext(anyopaque)) |ctx| { return getTagged(ctx); } - return ActiveSocket.init(&dead_socket); + return ActiveSocket.init(dead_socket); } pub const PooledSocketHiveAllocator = bun.HiveArray(PooledSocket, pool_size); @@ -54,7 +62,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type { } const ActiveSocket = TaggedPointerUnion(.{ - *DeadSocket, + DeadSocket, HTTPClient, PooledSocket, }); @@ -208,11 +216,6 @@ pub fn NewHTTPContext(comptime ssl: bool) type { } } - if (active.get(PooledSocket)) |pooled| { - addMemoryBackToPool(pooled); - return; - } - log("Unexpected open on unknown socket", .{}); terminateSocket(socket); } @@ -268,9 +271,6 @@ pub fn NewHTTPContext(comptime ssl: bool) type { if (socket.isClosed()) { markSocketAsDead(socket); - if (active.get(PooledSocket)) |pooled| { - addMemoryBackToPool(pooled); - } return; } @@ -284,10 +284,6 @@ pub fn NewHTTPContext(comptime ssl: bool) type { } } - if (active.get(PooledSocket)) |pooled| { - addMemoryBackToPool(pooled); - } - terminateSocket(socket); } pub fn onClose( @@ -302,12 +298,6 @@ pub fn NewHTTPContext(comptime ssl: bool) type { if (tagged.get(HTTPClient)) |client| { return client.onClose(comptime ssl, socket); } - - if (tagged.get(PooledSocket)) |pooled| { - addMemoryBackToPool(pooled); - } - - return; } fn addMemoryBackToPool(pooled: *PooledSocket) void { @@ -366,10 +356,6 @@ pub fn NewHTTPContext(comptime ssl: bool) type { const tagged = getTagged(ptr); if (tagged.get(HTTPClient)) |client| { return client.onTimeout(comptime ssl, socket); - } else if (tagged.get(PooledSocket)) |pooled| { - // If a 
socket has been sitting around for 5 minutes - // Let's close it and remove it from the pool. - addMemoryBackToPool(pooled); } terminateSocket(socket); @@ -380,16 +366,14 @@ pub fn NewHTTPContext(comptime ssl: bool) type { _: c_int, ) void { const tagged = getTagged(ptr); - markSocketAsDead(socket); + markTaggedSocketAsDead(socket, tagged); if (tagged.get(HTTPClient)) |client| { client.onConnectError(); - } else if (tagged.get(PooledSocket)) |pooled| { - addMemoryBackToPool(pooled); } // us_connecting_socket_close is always called internally by uSockets } pub fn onEnd( - _: *anyopaque, + ptr: *anyopaque, socket: HTTPSocket, ) void { // TCP fin must be closed, but we must keep the original tagged @@ -399,7 +383,14 @@ pub fn NewHTTPContext(comptime ssl: bool) type { // 1. HTTP Keep-Alive socket: it must be removed from the pool // 2. HTTP Client socket: it might need to be retried // 3. Dead socket: it is already marked as dead + const tagged = getTagged(ptr); + markTaggedSocketAsDead(socket, tagged); socket.close(.failure); + + if (tagged.get(HTTPClient)) |client| { + client.onClose(comptime ssl, socket); + return; + } } }; @@ -489,8 +480,12 @@ pub fn NewHTTPContext(comptime ssl: bool) type { }; } -const DeadSocket = opaque {}; -var dead_socket = @as(*DeadSocket, @ptrFromInt(1)); +const DeadSocket = struct { + garbage: u8 = 0, + pub var dead_socket: DeadSocket = .{}; +}; + +var dead_socket = &DeadSocket.dead_socket; const log = bun.Output.scoped(.HTTPContext, .hidden); const HTTPCertError = @import("./HTTPCertError.zig"); diff --git a/src/http/HTTPThread.zig b/src/http/HTTPThread.zig index 23f0514621..c65dd26ec8 100644 --- a/src/http/HTTPThread.zig +++ b/src/http/HTTPThread.zig @@ -13,8 +13,7 @@ queued_writes: std.ArrayListUnmanaged(WriteMessage) = std.ArrayListUnmanaged(Wri queued_shutdowns_lock: bun.Mutex = .{}, queued_writes_lock: bun.Mutex = .{}, - -queued_proxy_deref: std.ArrayListUnmanaged(*ProxyTunnel) = std.ArrayListUnmanaged(*ProxyTunnel){}, +queued_threadlocal_proxy_derefs: std.ArrayListUnmanaged(*ProxyTunnel) = std.ArrayListUnmanaged(*ProxyTunnel){}, has_awoken: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), timer: std.time.Timer, @@ -82,21 +81,15 @@ pub const RequestBodyBuffer = union(enum) { const threadlog = Output.scoped(.HTTPThread, .hidden); const WriteMessage = struct { async_http_id: u32, - flags: packed struct(u8) { - is_tls: bool, - type: Type, - _: u5 = 0, - }, + message_type: Type, pub const Type = enum(u2) { data = 0, end = 1, - endChunked = 2, }; }; const ShutdownMessage = struct { async_http_id: u32, - is_tls: bool, }; pub const LibdeflateState = struct { @@ -285,62 +278,96 @@ pub fn context(this: *@This(), comptime is_ssl: bool) *NewHTTPContext(is_ssl) { return if (is_ssl) &this.https_context else &this.http_context; } -fn drainEvents(this: *@This()) void { - { - this.queued_shutdowns_lock.lock(); - defer this.queued_shutdowns_lock.unlock(); - for (this.queued_shutdowns.items) |http| { +fn drainQueuedShutdowns(this: *@This()) void { + while (true) { + // socket.close() can potentially be slow + // Let's not block other threads while this runs. 
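+ // We grab the lock only long enough to steal the whole queued list, + // then run the actual close() calls on the stolen, now thread-local, + // list below without holding the lock.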
+ var queued_shutdowns = brk: { + this.queued_shutdowns_lock.lock(); + defer this.queued_shutdowns_lock.unlock(); + const shutdowns = this.queued_shutdowns; + this.queued_shutdowns = .{}; + break :brk shutdowns; + }; + defer queued_shutdowns.deinit(bun.default_allocator); + + for (queued_shutdowns.items) |http| { if (bun.http.socket_async_http_abort_tracker.fetchSwapRemove(http.async_http_id)) |socket_ptr| { - if (http.is_tls) { - const socket = uws.SocketTLS.fromAny(socket_ptr.value); - // do a fast shutdown here since we are aborting and we dont want to wait for the close_notify from the other side - socket.close(.failure); - } else { - const socket = uws.SocketTCP.fromAny(socket_ptr.value); - socket.close(.failure); + switch (socket_ptr.value) { + inline .SocketTLS, .SocketTCP => |socket, tag| { + const is_tls = tag == .SocketTLS; + const HTTPContext = HTTPThread.NewHTTPContext(comptime is_tls); + const tagged = HTTPContext.getTaggedFromSocket(socket); + if (tagged.get(HTTPClient)) |client| { + // If we only call socket.close(), then it won't + // call `onClose` if this happens before `onOpen` is + // called. + // + client.closeAndAbort(comptime is_tls, socket); + continue; + } + socket.close(.failure); + }, } } } - this.queued_shutdowns.clearRetainingCapacity(); + if (queued_shutdowns.items.len == 0) { + break; + } + threadlog("drained {d} queued shutdowns", .{queued_shutdowns.items.len}); } - { - this.queued_writes_lock.lock(); - defer this.queued_writes_lock.unlock(); - for (this.queued_writes.items) |write| { - const flags = write.flags; - const messageType = flags.type; - const ended = messageType == .end or messageType == .endChunked; +} + +fn drainQueuedWrites(this: *@This()) void { + while (true) { + var queued_writes = brk: { + this.queued_writes_lock.lock(); + defer this.queued_writes_lock.unlock(); + const writes = this.queued_writes; + this.queued_writes = .{}; + break :brk writes; + }; + defer queued_writes.deinit(bun.default_allocator); + for (queued_writes.items) |write| { + const message = write.message_type; + const ended = message == .end; if (bun.http.socket_async_http_abort_tracker.get(write.async_http_id)) |socket_ptr| { - switch (flags.is_tls) { - inline true, false => |is_tls| { - const socket = uws.NewSocketHandler(is_tls).fromAny(socket_ptr); + switch (socket_ptr) { + inline .SocketTLS, .SocketTCP => |socket, tag| { + const is_tls = tag == .SocketTLS; if (socket.isClosed() or socket.isShutdown()) { continue; } - const tagged = NewHTTPContext(is_tls).getTaggedFromSocket(socket); + const tagged = NewHTTPContext(comptime is_tls).getTaggedFromSocket(socket); if (tagged.get(HTTPClient)) |client| { if (client.state.original_request_body == .stream) { var stream = &client.state.original_request_body.stream; stream.ended = ended; - if (messageType == .endChunked and client.flags.upgrade_state != .upgraded) { - // only send the 0-length chunk if the request body is chunked and not upgraded - client.writeToStream(is_tls, socket, bun.http.end_of_chunked_http1_1_encoding_response_body); - } else { - client.flushStream(is_tls, socket); - } + + client.flushStream(is_tls, socket); } } }, } } } - this.queued_writes.clearRetainingCapacity(); + if (queued_writes.items.len == 0) { + break; + } + threadlog("drained {d} queued writes", .{queued_writes.items.len}); } +} - while (this.queued_proxy_deref.pop()) |http| { +fn drainEvents(this: *@This()) void { + // Process any pending writes **before** aborting. 
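+ // A request can queue its final chunk and an abort in the same tick; + // draining writes first gives that last chunk a chance to reach the + // socket before the shutdown closes it.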
+ this.drainQueuedWrites(); + this.drainQueuedShutdowns(); + + for (this.queued_threadlocal_proxy_derefs.items) |http| { http.deref(); } + this.queued_threadlocal_proxy_derefs.clearRetainingCapacity(); var count: usize = 0; var active = AsyncHTTP.active_requests_count.load(.monotonic); @@ -379,6 +406,14 @@ fn processEvents(this: *@This()) noreturn { while (true) { this.drainEvents(); + if (comptime Environment.isDebug and bun.asan.enabled) { + for (bun.http.socket_async_http_abort_tracker.keys(), bun.http.socket_async_http_abort_tracker.values()) |http_id, socket| { + if (socket.socket().get()) |usocket| { + _ = http_id; + bun.asan.assertUnpoisoned(usocket); + } + } + } var start_time: i128 = 0; if (comptime Environment.isDebug) { @@ -390,6 +425,15 @@ fn processEvents(this: *@This()) noreturn { this.loop.loop.tick(); this.loop.loop.dec(); + if (comptime Environment.isDebug and bun.asan.enabled) { + for (bun.http.socket_async_http_abort_tracker.keys(), bun.http.socket_async_http_abort_tracker.values()) |http_id, socket| { + if (socket.socket().get()) |usocket| { + _ = http_id; + bun.asan.assertUnpoisoned(usocket); + } + } + } + // this.loop.run(); if (comptime Environment.isDebug) { const end = std.time.nanoTimestamp(); @@ -400,12 +444,12 @@ fn processEvents(this: *@This()) noreturn { } pub fn scheduleShutdown(this: *@This(), http: *AsyncHTTP) void { + threadlog("scheduleShutdown {d}", .{http.async_http_id}); { this.queued_shutdowns_lock.lock(); defer this.queued_shutdowns_lock.unlock(); this.queued_shutdowns.append(bun.default_allocator, .{ .async_http_id = http.async_http_id, - .is_tls = http.client.isHTTPS(), }) catch |err| bun.handleOom(err); } if (this.has_awoken.load(.monotonic)) @@ -418,10 +462,7 @@ pub fn scheduleRequestWrite(this: *@This(), http: *AsyncHTTP, messageType: Write defer this.queued_writes_lock.unlock(); this.queued_writes.append(bun.default_allocator, .{ .async_http_id = http.async_http_id, - .flags = .{ - .is_tls = http.client.isHTTPS(), - .type = messageType, - }, + .message_type = messageType, }) catch |err| bun.handleOom(err); } if (this.has_awoken.load(.monotonic)) @@ -429,10 +470,8 @@ pub fn scheduleRequestWrite(this: *@This(), http: *AsyncHTTP, messageType: Write } pub fn scheduleProxyDeref(this: *@This(), proxy: *ProxyTunnel) void { - // this is always called on the http thread - { - bun.handleOom(this.queued_proxy_deref.append(bun.default_allocator, proxy)); - } + // this is always called on the http thread, + bun.handleOom(this.queued_threadlocal_proxy_derefs.append(bun.default_allocator, proxy)); if (this.has_awoken.load(.monotonic)) this.loop.loop.wakeup(); } @@ -473,7 +512,6 @@ const Global = bun.Global; const Output = bun.Output; const jsc = bun.jsc; const strings = bun.strings; -const uws = bun.uws; const Arena = bun.allocators.MimallocArena; const Batch = bun.ThreadPool.Batch; const UnboundedQueue = bun.threading.UnboundedQueue; diff --git a/src/http/InternalState.zig b/src/http/InternalState.zig index dab63e9053..30fed9fd3a 100644 --- a/src/http/InternalState.zig +++ b/src/http/InternalState.zig @@ -221,6 +221,10 @@ const log = Output.scoped(.HTTPInternalState, .hidden); const HTTPStage = enum { pending, + + /// The `onOpen` callback has been called for the first time. 
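+ /// + /// While the stage is still `.pending`, the connection may not have + /// completed yet, so start_() registers the abort tracker to make a slow + /// connect() cancellable.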
+ opened, + headers, body, body_chunk, diff --git a/src/js/builtins/ReadableStreamInternals.ts b/src/js/builtins/ReadableStreamInternals.ts index abb4873b22..ff1e9130c4 100644 --- a/src/js/builtins/ReadableStreamInternals.ts +++ b/src/js/builtins/ReadableStreamInternals.ts @@ -1204,7 +1204,10 @@ export function onCloseDirectStream(reason) { stream = undefined; return thisResult; }; - } else if (this._pendingRead) { + // We will close after the next $pull is called, otherwise we would lose the last chunk + return; + } + if (this._pendingRead) { var read = this._pendingRead; this._pendingRead = undefined; $putByIdDirectPrivate(this, "pull", $noopDoneFunction); @@ -1796,6 +1799,66 @@ export function readableStreamFromAsyncIterator(target, fn) { throw new TypeError("Expected an async generator"); } + var runningAsyncIteratorPromise; + async function runAsyncIterator(controller) { + var closingError: Error | undefined, value, done, immediateTask; + + try { + while (!cancelled && !done) { + const promise = iter.next(controller); + + if (cancelled) { + return; + } + + if ($isPromise(promise) && $isPromiseFulfilled(promise)) { + clearImmediate(immediateTask); + ({ value, done } = $getPromiseInternalField(promise, $promiseFieldReactionsOrResult)); + $assert(!$isPromise(value), "Expected a value, not a promise"); + } else { + immediateTask = setImmediate(() => immediateTask && controller?.flush?.(true)); + ({ value, done } = await promise); + + if (cancelled) { + return; + } + } + + if (!$isUndefinedOrNull(value)) { + controller.write(value); + } + } + } catch (e) { + closingError = e; + } finally { + clearImmediate(immediateTask); + immediateTask = undefined; + // "iter" will be undefined if the stream was closed above. + + // Stream was closed before we tried writing to it. + if (closingError?.code === "ERR_INVALID_THIS") { + await iter?.return?.(); + return; + } + + if (closingError) { + try { + await iter.throw?.(closingError); + } finally { + iter = undefined; + // eslint-disable-next-line no-throw-literal + throw closingError; + } + } else { + await controller.end(); + if (iter) { + await iter.return?.(); + } + } + iter = undefined; + } + } + return new ReadableStream({ type: "direct", @@ -1826,62 +1889,23 @@ export function readableStreamFromAsyncIterator(target, fn) { }, async pull(controller) { - var closingError: Error | undefined, value, done, immediateTask; - - try { - while (!cancelled && !done) { - const promise = iter.next(controller); - - if (cancelled) { - return; - } - - if ($isPromise(promise) && $isPromiseFulfilled(promise)) { - clearImmediate(immediateTask); - ({ value, done } = $getPromiseInternalField(promise, $promiseFieldReactionsOrResult)); - $assert(!$isPromise(value), "Expected a value, not a promise"); - } else { - immediateTask = setImmediate(() => immediateTask && controller?.flush?.(true)); - ({ value, done } = await promise); - - if (cancelled) { - return; - } - } - - if (!$isUndefinedOrNull(value)) { - controller.write(value); + // pull() may be called multiple times before a single call completes. + // + // But we only call into the stream once while a previous pull is still in progress.
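+ // Re-entrant callers simply await the same in-flight promise, which is + // also what gets returned at the bottom of this function.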
+ if (!runningAsyncIteratorPromise) { + const asyncIteratorPromise = runAsyncIterator(controller); + runningAsyncIteratorPromise = asyncIteratorPromise; + try { + const result = await asyncIteratorPromise; + return result; + } finally { + if (runningAsyncIteratorPromise === asyncIteratorPromise) { + runningAsyncIteratorPromise = undefined; } } - } catch (e) { - closingError = e; - } finally { - clearImmediate(immediateTask); - immediateTask = undefined; - // "iter" will be undefined if the stream was closed above. - - // Stream was closed before we tried writing to it. - if (closingError?.code === "ERR_INVALID_THIS") { - await iter?.return?.(); - return; - } - - if (closingError) { - try { - await iter.throw?.(closingError); - } finally { - iter = undefined; - // eslint-disable-next-line no-throw-literal - throw closingError; - } - } else { - await controller.end(); - if (iter) { - await iter.return?.(); - } - } - iter = undefined; } + + return runningAsyncIteratorPromise; }, }); } diff --git a/src/s3/client.zig b/src/s3/client.zig index 1117409b8a..d08e774070 100644 --- a/src/s3/client.zig +++ b/src/s3/client.zig @@ -359,7 +359,7 @@ pub const S3UploadStreamWrapper = struct { } } - pub fn writeRequestData(this: *@This(), data: []const u8) bool { + pub fn writeRequestData(this: *@This(), data: []const u8) ResumableSinkBackpressure { log("writeRequestData {}", .{data.len}); return bun.handleOom(this.task.writeBytes(data, false)); } @@ -685,3 +685,4 @@ const std = @import("std"); const bun = @import("bun"); const jsc = bun.jsc; const picohttp = bun.picohttp; +const ResumableSinkBackpressure = jsc.WebCore.ResumableSinkBackpressure; diff --git a/src/s3/multipart.zig b/src/s3/multipart.zig index acc303469b..66b012a195 100644 --- a/src/s3/multipart.zig +++ b/src/s3/multipart.zig @@ -704,8 +704,8 @@ pub const MultiPartUpload = struct { utf16, }; - fn write(this: *@This(), chunk: []const u8, is_last: bool, comptime encoding: WriteEncoding) bun.OOM!bool { - if (this.ended) return true; // no backpressure since we are done + fn write(this: *@This(), chunk: []const u8, is_last: bool, comptime encoding: WriteEncoding) bun.OOM!ResumableSinkBackpressure { + if (this.ended) return .done; // no backpressure since we are done // we may call done inside processBuffered so we ensure that we keep a ref until we are done this.ref(); defer this.deref(); @@ -715,7 +715,7 @@ pub const MultiPartUpload = struct { if (this.buffered.size() > 0) { this.processBuffered(this.partSizeInBytes()); } - return !this.hasBackpressure(); + return if (this.hasBackpressure()) .backpressure else .want_more; } if (is_last) { this.ended = true; @@ -729,7 +729,7 @@ pub const MultiPartUpload = struct { this.processBuffered(this.partSizeInBytes()); } else { // still have more data and receive empty, nothing todo here - if (chunk.len == 0) return this.hasBackpressure(); + if (chunk.len == 0) return if (this.hasBackpressure()) .backpressure else .want_more; switch (encoding) { .bytes => try this.buffered.write(chunk), .latin1 => try this.buffered.writeLatin1(chunk, true), @@ -743,18 +743,18 @@ pub const MultiPartUpload = struct { // wait for more } - return !this.hasBackpressure(); + return if (this.hasBackpressure()) .backpressure else .want_more; } - pub fn writeLatin1(this: *@This(), chunk: []const u8, is_last: bool) bun.OOM!bool { + pub fn writeLatin1(this: *@This(), chunk: []const u8, is_last: bool) bun.OOM!ResumableSinkBackpressure { return try this.write(chunk, is_last, .latin1); } - pub fn writeUTF16(this: *@This(), chunk: []const 
u8, is_last: bool) bun.OOM!bool { + pub fn writeUTF16(this: *@This(), chunk: []const u8, is_last: bool) bun.OOM!ResumableSinkBackpressure { return try this.write(chunk, is_last, .utf16); } - pub fn writeBytes(this: *@This(), chunk: []const u8, is_last: bool) bun.OOM!bool { + pub fn writeBytes(this: *@This(), chunk: []const u8, is_last: bool) bun.OOM!ResumableSinkBackpressure { return try this.write(chunk, is_last, .bytes); } }; @@ -772,3 +772,4 @@ const executeSimpleS3Request = S3SimpleRequest.executeSimpleS3Request; const bun = @import("bun"); const jsc = bun.jsc; const strings = bun.strings; +const ResumableSinkBackpressure = jsc.WebCore.ResumableSinkBackpressure; diff --git a/test/js/bun/io/fetch/fetch-abort-slow-connect.test.ts b/test/js/bun/io/fetch/fetch-abort-slow-connect.test.ts new file mode 100644 index 0000000000..ef4f31e199 --- /dev/null +++ b/test/js/bun/io/fetch/fetch-abort-slow-connect.test.ts @@ -0,0 +1,59 @@ +import { expect, test } from "bun:test"; + +test.concurrent("fetch aborts when connect() returns EINPROGRESS but never completes", async () => { + // Use TEST-NET-1 (192.0.2.0/24) from RFC 5737 + // These IPs are reserved for documentation and testing. + // Connecting to them will cause connect() to return EINPROGRESS + // but the connection will never complete because there's no route. + const nonRoutableIP = "192.0.2.1"; + const port = 80; + + const start = performance.now(); + try { + await fetch(`http://${nonRoutableIP}:${port}/`, { + signal: AbortSignal.timeout(50), + }); + expect.unreachable("Fetch should have aborted"); + } catch (e: any) { + const elapsed = performance.now() - start; + expect(e.name).toBe("TimeoutError"); + expect(elapsed).toBeLessThan(1000); // But not more than 1000ms + } +}); + +test.concurrent("fetch aborts immediately during EINPROGRESS connect", async () => { + const nonRoutableIP = "192.0.2.1"; + const port = 80; + + // Start the fetch + const fetchPromise = fetch(`http://${nonRoutableIP}:${port}/`, { + signal: AbortSignal.timeout(1), + }); + + const start = performance.now(); + try { + await fetchPromise; + expect.unreachable("Fetch should have aborted"); + } catch (e: any) { + const elapsed = performance.now() - start; + expect(e.name).toBe("TimeoutError"); + expect(elapsed).toBeLessThan(1000); // Should reject very quickly after abort + } +}); + +test.concurrent("pre-aborted signal prevents connection attempt", async () => { + const nonRoutableIP = "192.0.2.1"; + const port = 80; + + const start = performance.now(); + try { + await fetch(`http://${nonRoutableIP}:${port}/`, { + signal: AbortSignal.abort(), + }); + expect.unreachable("Fetch should have aborted"); + } catch (e: any) { + const elapsed = performance.now() - start; + expect(e.name).toBe("AbortError"); + expect(elapsed).toBeLessThan(10); // Should fail immediately + } +}); diff --git a/test/js/bun/s3/s3-stream-leak-fixture.js b/test/js/bun/s3/s3-stream-leak-fixture.js index b5052d5d86..b440f8334b 100644 --- a/test/js/bun/s3/s3-stream-leak-fixture.js +++ b/test/js/bun/s3/s3-stream-leak-fixture.js @@ -33,7 +33,7 @@ async function run(inputType) { const rss = (process.memoryUsage.rss() / 1024 / 1024) | 0; if (rss > MAX_ALLOWED_MEMORY_USAGE) { await s3file.unlink(); - throw new Error("Memory usage is too high"); + throw new Error("RSS reached " + rss + "MB"); } } await run(new Buffer(1024 * 1024 * 1, "A".charCodeAt(0)).toString("utf-8")); diff --git a/test/js/bun/s3/s3.leak.test.ts b/test/js/bun/s3/s3.leak.test.ts index ac3688996c..65a655ad0e 100644 --- 
a/test/js/bun/s3/s3.leak.test.ts +++ b/test/js/bun/s3/s3.leak.test.ts @@ -33,13 +33,12 @@ describe.skipIf(!s3Options.accessKeyId)("s3", () => { AWS_ENDPOINT: s3Options.endpoint, AWS_BUCKET: S3Bucket, }, - stderr: "pipe", + stderr: "inherit", stdout: "inherit", stdin: "ignore", }, ); expect(exitCode).toBe(0); - expect(stderr.toString()).toBe(""); }, 30 * 1000, ); diff --git a/test/js/web/fetch/fetch.test.ts b/test/js/web/fetch/fetch.test.ts index 90e7728673..c307515cb3 100644 --- a/test/js/web/fetch/fetch.test.ts +++ b/test/js/web/fetch/fetch.test.ts @@ -295,10 +295,45 @@ describe("AbortSignal", () => { method: "POST", body: new ReadableStream({ pull(event_controller) { + console.count("pull"); event_controller.enqueue(new Uint8Array([1, 2, 3, 4])); //this will abort immediately should abort before connected controller.abort(); }, + cancel(reason) { + console.log("cancel", reason); + }, + }), + signal: controller.signal, + }); + expect.unreachable(); + } catch (ex: any) { + expect(ex?.message).toEqual("The operation was aborted."); + expect(ex?.name).toEqual("AbortError"); + expect(ex?.constructor.name).toEqual("DOMException"); + } + }); + + it("abort while uploading prevents pull() from being called", async () => { + const controller = new AbortController(); + await fetch(`http://localhost:${server.port}`, { + method: "POST", + body: new Blob(["a"]), + }); + + try { + await fetch(`http://localhost:${server.port}`, { + method: "POST", + body: new ReadableStream({ + async pull(event_controller) { + expect(controller.signal.aborted).toBeFalse(); + const chunk = Buffer.alloc(256 * 1024, "abc"); + for (let i = 0; i < 64; i++) { + event_controller.enqueue(chunk); + } + //this will abort immediately should abort before connected + controller.abort(); + }, }), signal: controller.signal, }); From 9746d03ccb0f714685ed2961c9905f1d2f27a803 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Thu, 25 Sep 2025 16:24:22 -0700 Subject: [PATCH 08/43] Delete slop test --- .../fetch/node-use-system-ca-complete.test.ts | 238 ------------------ 1 file changed, 238 deletions(-) delete mode 100644 test/js/bun/fetch/node-use-system-ca-complete.test.ts diff --git a/test/js/bun/fetch/node-use-system-ca-complete.test.ts b/test/js/bun/fetch/node-use-system-ca-complete.test.ts deleted file mode 100644 index be07b50e3c..0000000000 --- a/test/js/bun/fetch/node-use-system-ca-complete.test.ts +++ /dev/null @@ -1,238 +0,0 @@ -import { describe, expect, test } from "bun:test"; -import { promises as fs } from "fs"; -import { bunEnv, bunExe, tempDirWithFiles } from "harness"; -import { platform } from "os"; -import { join } from "path"; - -describe("NODE_USE_SYSTEM_CA Complete Implementation", () => { - test("should work with standard HTTPS sites", async () => { - const testDir = tempDirWithFiles("node-use-system-ca-basic", {}); - - const testScript = ` -async function testHttpsRequest() { - try { - const response = await fetch('https://httpbin.org/user-agent'); - console.log('SUCCESS: GitHub request completed with status', response.status); - process.exit(0); - } catch (error) { - console.log('ERROR: HTTPS request failed:', error.message); - process.exit(1); - } -} - -testHttpsRequest(); -`; - - await fs.writeFile(join(testDir, "test.js"), testScript); - - // Test with NODE_USE_SYSTEM_CA=1 - const proc1 = Bun.spawn({ - cmd: [bunExe(), "test.js"], - env: { - ...bunEnv, - NODE_USE_SYSTEM_CA: "1", - }, - cwd: testDir, - stdout: "pipe", - stderr: "pipe", - }); - - const [stdout1, stderr1, exitCode1] = await 
Promise.all([proc1.stdout.text(), proc1.stderr.text(), proc1.exited]); - - expect(exitCode1).toBe(0); - expect(stdout1).toContain("SUCCESS"); - - // Test without NODE_USE_SYSTEM_CA - const proc2 = Bun.spawn({ - cmd: [bunExe(), "test.js"], - env: bunEnv, - cwd: testDir, - stdout: "pipe", - stderr: "pipe", - }); - - const [stdout2, stderr2, exitCode2] = await Promise.all([proc2.stdout.text(), proc2.stderr.text(), proc2.exited]); - - expect(exitCode2).toBe(0); - expect(stdout2).toContain("SUCCESS"); - }); - - test("should properly parse NODE_USE_SYSTEM_CA environment variable", async () => { - const testDir = tempDirWithFiles("node-use-system-ca-env-parsing", {}); - - const testScript = ` -const testCases = [ - { env: '1', description: 'string "1"' }, - { env: 'true', description: 'string "true"' }, - { env: '0', description: 'string "0"' }, - { env: 'false', description: 'string "false"' }, - { env: undefined, description: 'undefined' } -]; - -console.log('Testing NODE_USE_SYSTEM_CA environment variable parsing:'); - -for (const testCase of testCases) { - if (testCase.env !== undefined) { - process.env.NODE_USE_SYSTEM_CA = testCase.env; - } else { - delete process.env.NODE_USE_SYSTEM_CA; - } - - const actual = process.env.NODE_USE_SYSTEM_CA; - console.log(\` \${testCase.description}: \${actual || 'undefined'}\`); -} - -console.log('Environment variable parsing test completed successfully'); -process.exit(0); -`; - - await fs.writeFile(join(testDir, "test-env.js"), testScript); - - const proc = Bun.spawn({ - cmd: [bunExe(), "test-env.js"], - env: bunEnv, - cwd: testDir, - stdout: "pipe", - stderr: "pipe", - }); - - const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); - - expect(exitCode).toBe(0); - expect(stdout).toContain("Environment variable parsing test completed successfully"); - }); - - test("should handle platform-specific behavior correctly", async () => { - const testDir = tempDirWithFiles("node-use-system-ca-platform", {}); - - const testScript = ` -const { platform } = require('os'); - -console.log(\`Platform: \${platform()}\`); -console.log(\`NODE_USE_SYSTEM_CA: \${process.env.NODE_USE_SYSTEM_CA}\`); - -async function testPlatformBehavior() { - try { - // Test a reliable HTTPS endpoint - const response = await fetch('https://httpbin.org/user-agent'); - const data = await response.json(); - - console.log('SUCCESS: Platform-specific certificate loading working'); - console.log('User-Agent:', data['user-agent']); - - if (platform() === 'darwin' && process.env.NODE_USE_SYSTEM_CA === '1') { - console.log('SUCCESS: macOS Security framework integration should be active'); - } else if (platform() === 'linux' && process.env.NODE_USE_SYSTEM_CA === '1') { - console.log('SUCCESS: Linux system certificate loading should be active'); - } else if (platform() === 'win32' && process.env.NODE_USE_SYSTEM_CA === '1') { - console.log('SUCCESS: Windows certificate store integration should be active'); - } else { - console.log('SUCCESS: Using bundled certificates'); - } - - process.exit(0); - } catch (error) { - console.error('FAILED: Platform test failed:', error.message); - process.exit(1); - } -} - -testPlatformBehavior(); -`; - - await fs.writeFile(join(testDir, "test-platform.js"), testScript); - - const proc = Bun.spawn({ - cmd: [bunExe(), "test-platform.js"], - env: { - ...bunEnv, - NODE_USE_SYSTEM_CA: "1", - }, - cwd: testDir, - stdout: "pipe", - stderr: "pipe", - }); - - const [stdout, stderr, exitCode] = await 
Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); - - console.log("Platform test output:", stdout); - console.log("Platform test errors:", stderr); - - expect(exitCode).toBe(0); - expect(stdout).toContain("SUCCESS: Platform-specific certificate loading working"); - - if (platform() === "darwin") { - expect(stdout).toContain("macOS Security framework integration should be active"); - } else if (platform() === "linux") { - expect(stdout).toContain("Linux system certificate loading should be active"); - } - }); - - test("should work with TLS connections", async () => { - const testDir = tempDirWithFiles("node-use-system-ca-tls", {}); - - const testScript = ` -const tls = require('tls'); - -async function testTLSConnection() { - return new Promise((resolve, reject) => { - const options = { - host: 'www.google.com', - port: 443, - rejectUnauthorized: true, - }; - - const socket = tls.connect(options, () => { - console.log('SUCCESS: TLS connection established'); - console.log('Certificate authorized:', socket.authorized); - - socket.destroy(); - resolve(); - }); - - socket.on('error', (error) => { - console.error('FAILED: TLS connection failed:', error.message); - reject(error); - }); - - socket.setTimeout(10000, () => { - console.error('FAILED: Connection timeout'); - socket.destroy(); - reject(new Error('Timeout')); - }); - }); -} - -testTLSConnection() - .then(() => { - console.log('TLS test completed successfully'); - process.exit(0); - }) - .catch((error) => { - console.error('TLS test failed:', error.message); - process.exit(1); - }); -`; - - await fs.writeFile(join(testDir, "test-tls.js"), testScript); - - const proc = Bun.spawn({ - cmd: [bunExe(), "test-tls.js"], - env: { - ...bunEnv, - NODE_USE_SYSTEM_CA: "1", - }, - cwd: testDir, - stdout: "pipe", - stderr: "pipe", - }); - - const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); - - console.log("TLS test output:", stdout); - - expect(exitCode).toBe(0); - expect(stdout).toContain("SUCCESS: TLS connection established"); - expect(stdout).toContain("TLS test completed successfully"); - }); -}); From 749ad8a1ffd1f2999b3c62db48e0f48aa56b62ea Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Thu, 25 Sep 2025 16:53:21 -0700 Subject: [PATCH 09/43] fix(build): Minor Linux Build Fixes (#22972) ### What does this PR do? ### How did you verify your code works? 
--- CONTRIBUTING.md | 2 +- cmake/tools/SetupLLVM.cmake | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f35200c5f2..c39fd4463a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,7 +21,7 @@ $ sudo pacman -S base-devel ccache cmake git go libiconv libtool make ninja pkg- ``` ```bash#Fedora -$ sudo dnf install cargo ccache cmake git golang libtool ninja-build pkg-config rustc ruby libatomic-static libstdc++-static sed unzip which libicu-devel 'perl(Math::BigInt)' +$ sudo dnf install cargo clang19 llvm19 lld19 ccache cmake git golang libtool ninja-build pkg-config rustc ruby libatomic-static libstdc++-static sed unzip which libicu-devel 'perl(Math::BigInt)' ``` ```bash#openSUSE Tumbleweed diff --git a/cmake/tools/SetupLLVM.cmake b/cmake/tools/SetupLLVM.cmake index a06a51b23e..a250342018 100644 --- a/cmake/tools/SetupLLVM.cmake +++ b/cmake/tools/SetupLLVM.cmake @@ -131,6 +131,9 @@ else() find_llvm_command(CMAKE_RANLIB llvm-ranlib) if(LINUX) find_llvm_command(LLD_PROGRAM ld.lld) + # Ensure vendor dependencies use lld instead of ld + list(APPEND CMAKE_ARGS -DCMAKE_EXE_LINKER_FLAGS=--ld-path=${LLD_PROGRAM}) + list(APPEND CMAKE_ARGS -DCMAKE_SHARED_LINKER_FLAGS=--ld-path=${LLD_PROGRAM}) endif() if(APPLE) find_llvm_command(CMAKE_DSYMUTIL dsymutil) From 0b9a2fce2de9ca0758961bc5f5160d0829318c88 Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Thu, 25 Sep 2025 17:06:23 -0700 Subject: [PATCH 10/43] update no-validate-leaksan.txt --- test/no-validate-leaksan.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/no-validate-leaksan.txt b/test/no-validate-leaksan.txt index d6cd8d6407..bd927cf137 100644 --- a/test/no-validate-leaksan.txt +++ b/test/no-validate-leaksan.txt @@ -401,3 +401,4 @@ test/js/third_party/rollup-v4/rollup-v4.test.ts test/js/web/abort/abort.test.ts test/js/third_party/resvg/bbox.test.js test/regression/issue/10139.test.ts +test/js/bun/udp/udp_socket.test.ts From 58782ceef2ed874d331b7e1001ad81aad2119bc1 Mon Sep 17 00:00:00 2001 From: robobun Date: Thu, 25 Sep 2025 17:23:45 -0700 Subject: [PATCH 11/43] Fix bun_dependency_versions.h regenerating on every CMake run (#22985) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fixes unnecessary regeneration of `bun_dependency_versions.h` on every CMake run - Only writes the header file when content actually changes ## Test plan Tested locally by running CMake configuration multiple times: 1. First run generates the file (shows "Updated dependency versions header") 2. Subsequent runs skip writing (shows "Dependency versions header unchanged") 3. File modification timestamp remains unchanged when content is the same 4. 
File is properly regenerated when deleted or when content changes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude Bot Co-authored-by: Claude --- cmake/tools/GenerateDependencyVersions.cmake | 21 +++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/cmake/tools/GenerateDependencyVersions.cmake b/cmake/tools/GenerateDependencyVersions.cmake index 65c051e285..b7942307be 100644 --- a/cmake/tools/GenerateDependencyVersions.cmake +++ b/cmake/tools/GenerateDependencyVersions.cmake @@ -181,12 +181,23 @@ function(generate_dependency_versions_header) string(APPEND HEADER_CONTENT "}\n") string(APPEND HEADER_CONTENT "#endif\n\n") string(APPEND HEADER_CONTENT "#endif // BUN_DEPENDENCY_VERSIONS_H\n") - - # Write the header file + + # Write the header file only if content has changed set(OUTPUT_FILE "${CMAKE_BINARY_DIR}/bun_dependency_versions.h") - file(WRITE "${OUTPUT_FILE}" "${HEADER_CONTENT}") - - message(STATUS "Generated dependency versions header: ${OUTPUT_FILE}") + + # Read existing content if file exists + set(EXISTING_CONTENT "") + if(EXISTS "${OUTPUT_FILE}") + file(READ "${OUTPUT_FILE}" EXISTING_CONTENT) + endif() + + # Only write if content is different + if(NOT "${EXISTING_CONTENT}" STREQUAL "${HEADER_CONTENT}") + file(WRITE "${OUTPUT_FILE}" "${HEADER_CONTENT}") + message(STATUS "Updated dependency versions header: ${OUTPUT_FILE}") + else() + message(STATUS "Dependency versions header unchanged: ${OUTPUT_FILE}") + endif() # Also create a more detailed version for debugging set(DEBUG_OUTPUT_FILE "${CMAKE_BINARY_DIR}/bun_dependency_versions_debug.txt") From d3061de1bff1c7623e088cfc1c0e32103a9c3b9c Mon Sep 17 00:00:00 2001 From: robobun Date: Thu, 25 Sep 2025 18:03:27 -0700 Subject: [PATCH 12/43] feat(windows): implement authenticode stripping for --compile (#22960) ## Summary Implements authenticode signature stripping for Windows PE files when using `bun build --compile`, ensuring that generated executables can be properly signed with external tools after Bun embeds its data section. 
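For context, "signed" is visible directly in the PE headers: the Authenticode certificate table lives in data directory #4 of the optional header, and (unusually) its first field is a plain file offset rather than an RVA. The following is a minimal, hypothetical TypeScript sketch — illustrative only, not code from this PR, and `app.exe` is a placeholder path — showing how that directory can be read to tell whether an executable still carries a signature:

```ts
// Sketch: detect an Authenticode signature by reading the Security data
// directory (index 4). Assumes a well-formed PE per the PE/COFF spec.
const dv = new DataView(await Bun.file("app.exe").arrayBuffer());

if (dv.getUint16(0, true) !== 0x5a4d) throw new Error("not an MZ file");
const peOff = dv.getUint32(0x3c, true); // e_lfanew
if (dv.getUint32(peOff, true) !== 0x4550) throw new Error("not a PE file");

const optOff = peOff + 24; // PE signature (4) + COFF header (20)
const magic = dv.getUint16(optOff, true); // 0x20b = PE32+, 0x10b = PE32
const ddBase = optOff + (magic === 0x20b ? 112 : 96); // data directories
const certOff = dv.getUint32(ddBase + 4 * 8, true); // file offset, not an RVA
const certSize = dv.getUint32(ddBase + 4 * 8 + 4, true);

console.log(certOff !== 0 && certSize !== 0 ? "signed" : "unsigned");
```

After this patch, a check along these lines run against a `bun build --compile` output should report it as unsigned, leaving it ready to be signed with external tooling such as `signtool`.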
## What Changed ### Core Implementation - **Authenticode stripping**: Removes digital signatures from PE files before adding the .bun section - **Safe memory access**: Replaced all `@alignCast` operations with safe unaligned access helpers to prevent crashes - **Hardened PE parsing**: Added comprehensive bounds checking and validation throughout - **PE checksum recalculation**: Properly updates checksums after modifications ### Key Features - Always strips authenticode signatures when using `--compile` for Windows (uses `.strip_always` mode) - Validates PE file structure according to PE/COFF specification - Handles overlapping memory regions safely during certificate removal - Clears `IMAGE_DLLCHARACTERISTICS_FORCE_INTEGRITY` flag when stripping signatures - Ensures no unexpected overlay data remains after stripping ### Bug Fixes - Fixed memory corruption bug using `copyBackwards` for overlapping regions - Fixed checksum calculation skipping 6 bytes instead of 4 - Added integer overflow protection in payload size calculations - Fixed double alignment bug in `size_of_image` calculation ## Technical Details The implementation follows the Windows PE/COFF specification and includes: - `StripMode` enum to control when signatures are stripped (none/strip_if_signed/strip_always) - Safe unaligned memory access helpers (`viewAtConst`, `viewAtMut`) - Proper alignment helpers with overflow protection (`alignUpU32`, `alignUpUsize`) - Comprehensive error types for all failure cases ## Testing - Passes all existing PE tests in `test/regression/issue/pe-codesigning-integrity.test.ts` - Compiles successfully with `bun run zig:check-windows` - Properly integrated with StandaloneModuleGraph for Windows compilation ## Impact This ensures Windows users can: 1. Use `bun build --compile` to create standalone executables 2. Sign the resulting executables with their own certificates 3. Distribute properly signed Windows binaries Fixes issues where previously signed executables would have invalid signatures after Bun added its embedded data. 
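For reference, the checksum step follows the classic IMAGEHLP `CheckSumMappedFile` formulation: treat the 4-byte `CheckSum` field as zero, sum the file as little-endian 16-bit words while folding carries back into 16 bits, then add the file length. A hedged TypeScript sketch of that formulation (illustrative only, not the Zig code in this PR; `checksumFieldOff` is `e_lfanew + 24 + 64` for a PE32+ image):

```ts
// Sketch of the PE checksum: 16-bit ones'-complement-style sum with the
// CheckSum field zeroed, plus the total file length.
function peChecksum(file: Uint8Array, checksumFieldOff: number): number {
  const data = file.slice(); // work on a copy
  data.fill(0, checksumFieldOff, checksumFieldOff + 4); // zero CheckSum field
  let sum = 0;
  for (let i = 0; i + 1 < data.length; i += 2) {
    sum += data[i] | (data[i + 1] << 8); // little-endian 16-bit word
    sum = (sum & 0xffff) + (sum >>> 16); // fold the carry periodically
  }
  if (data.length & 1) sum += data[data.length - 1]; // odd trailing byte
  sum = (sum & 0xffff) + (sum >>> 16);
  sum = (sum & 0xffff) + (sum >>> 16);
  return (sum + data.length) >>> 0; // final sum plus file length
}
```

Recomputing this after every resize (both stripping and section insertion change the file length) keeps the header self-consistent for tooling that validates the field.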
--------- Co-authored-by: Claude Bot Co-authored-by: Claude Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jarred Sumner --- src/StandaloneModuleGraph.zig | 3 +- src/pe.zig | 573 +++++++++++++++++++++++++++------- 2 files changed, 458 insertions(+), 118 deletions(-) diff --git a/src/StandaloneModuleGraph.zig b/src/StandaloneModuleGraph.zig index 40af49d457..68b3006688 100644 --- a/src/StandaloneModuleGraph.zig +++ b/src/StandaloneModuleGraph.zig @@ -724,7 +724,8 @@ pub const StandaloneModuleGraph = struct { return bun.invalid_fd; }; defer pe_file.deinit(); - pe_file.addBunSection(bytes) catch |err| { + // Always strip authenticode when adding .bun section for --compile + pe_file.addBunSection(bytes, .strip_always) catch |err| { Output.prettyErrorln("Error adding Bun section to PE file: {}", .{err}); cleanup(zname, cloned_executable_fd); return bun.invalid_fd; diff --git a/src/pe.zig b/src/pe.zig index 57148eba47..af592787c0 100644 --- a/src/pe.zig +++ b/src/pe.zig @@ -1,6 +1,36 @@ // Windows PE sections use standard file alignment (typically 512 bytes) // No special 16KB alignment needed like macOS code signing +// New error types for PE manipulation +pub const Error = error{ + OutOfBounds, + BadAlignment, + Overflow, + InvalidPEFile, + InvalidDOSSignature, + InvalidPESignature, + UnsupportedPEFormat, + InsufficientHeaderSpace, + TooManySections, + SectionExists, + InputIsSigned, + InvalidSecurityDirectory, + SecurityDirInsideImage, + UnexpectedOverlayPresent, + InvalidSectionData, + BunSectionNotFound, + InvalidBunSection, + InsufficientSpace, + SizeOfImageMismatch, +}; + +// Enums for strip modes and options +pub const StripMode = enum { none, strip_if_signed, strip_always }; +pub const StripOpts = struct { + require_overlay: bool = true, + recompute_checksum: bool = true, +}; + /// Windows PE Binary manipulation for codesigning standalone executables pub const PEFile = struct { data: std.ArrayList(u8), @@ -11,6 +41,10 @@ pub const PEFile = struct { optional_header_offset: usize, section_headers_offset: usize, num_sections: u16, + // Cached values from init + first_raw: u32, + last_file_end: u32, + last_va_end: u32, const DOSHeader = extern struct { e_magic: u16, // Magic number @@ -107,91 +141,194 @@ pub const PEFile = struct { const IMAGE_SCN_MEM_WRITE = 0x80000000; const IMAGE_SCN_MEM_EXECUTE = 0x20000000; - // Helper methods to safely access headers - fn getDosHeader(self: *const PEFile) *DOSHeader { - return @ptrCast(@alignCast(self.data.items.ptr + self.dos_header_offset)); + // Directory indices and DLL characteristics + const IMAGE_DIRECTORY_ENTRY_SECURITY: usize = 4; + const IMAGE_DLLCHARACTERISTICS_FORCE_INTEGRITY: u16 = 0x0080; + + // Section name constant for exact comparison + const BUN_SECTION_NAME = [_]u8{ '.', 'b', 'u', 'n', 0, 0, 0, 0 }; + + // Safe access helpers for unaligned views + fn viewAtConst(comptime T: type, buf: []const u8, off: usize) !*align(1) const T { + if (off + @sizeOf(T) > buf.len) return error.OutOfBounds; + return @ptrCast(buf[off .. off + @sizeOf(T)].ptr); } - fn getPEHeader(self: *const PEFile) *PEHeader { - return @ptrCast(@alignCast(self.data.items.ptr + self.pe_header_offset)); + fn viewAtMut(comptime T: type, buf: []u8, off: usize) !*align(1) T { + if (off + @sizeOf(T) > buf.len) return error.OutOfBounds; + return @ptrCast(buf[off .. 
off + @sizeOf(T)].ptr); } - fn getOptionalHeader(self: *const PEFile) *OptionalHeader64 { - return @ptrCast(@alignCast(self.data.items.ptr + self.optional_header_offset)); + fn isPow2(x: u32) bool { + return x != 0 and (x & (x - 1)) == 0; } - fn getSectionHeaders(self: *const PEFile) []SectionHeader { - return @as([*]SectionHeader, @ptrCast(@alignCast(self.data.items.ptr + self.section_headers_offset)))[0..self.num_sections]; + fn alignUpU32(v: u32, a: u32) !u32 { + if (a == 0) return v; + if (!isPow2(a)) return error.BadAlignment; + const add = a - 1; + if (v > std.math.maxInt(u32) - add) return error.Overflow; + return (v + add) & ~add; + } + + fn alignUpUsize(v: usize, a: usize) !usize { + if (a == 0) return v; + if ((a & (a - 1)) != 0) return error.BadAlignment; + const add = a - 1; + if (v > std.math.maxInt(usize) - add) return error.Overflow; + return (v + add) & ~add; + } + + // Helper methods to safely access headers using unaligned pointers + fn getDosHeader(self: *const PEFile) !*align(1) const DOSHeader { + return viewAtConst(DOSHeader, self.data.items, self.dos_header_offset); + } + + fn getDosHeaderMut(self: *PEFile) !*align(1) DOSHeader { + return viewAtMut(DOSHeader, self.data.items, self.dos_header_offset); + } + + fn getPEHeader(self: *const PEFile) !*align(1) const PEHeader { + return viewAtConst(PEHeader, self.data.items, self.pe_header_offset); + } + + fn getPEHeaderMut(self: *PEFile) !*align(1) PEHeader { + return viewAtMut(PEHeader, self.data.items, self.pe_header_offset); + } + + fn getOptionalHeader(self: *const PEFile) !*align(1) const OptionalHeader64 { + return viewAtConst(OptionalHeader64, self.data.items, self.optional_header_offset); + } + + fn getOptionalHeaderMut(self: *PEFile) !*align(1) OptionalHeader64 { + return viewAtMut(OptionalHeader64, self.data.items, self.optional_header_offset); + } + + fn getSectionHeaders(self: *const PEFile) ![]align(1) const SectionHeader { + const start = self.section_headers_offset; + const size = @sizeOf(SectionHeader) * self.num_sections; + if (start + size > self.data.items.len) return error.OutOfBounds; + const ptr: [*]align(1) const SectionHeader = @ptrCast(self.data.items[start..].ptr); + return ptr[0..self.num_sections]; + } + + fn getSectionHeadersMut(self: *PEFile) ![]align(1) SectionHeader { + const start = self.section_headers_offset; + const size = @sizeOf(SectionHeader) * self.num_sections; + if (start + size > self.data.items.len) return error.OutOfBounds; + const ptr: [*]align(1) SectionHeader = @ptrCast(self.data.items[start..].ptr); + return ptr[0..self.num_sections]; } pub fn init(allocator: Allocator, pe_data: []const u8) !*PEFile { - // Reserve some extra space for adding sections, but no need for 16KB alignment + // 1. Reserve capacity as before var data = try std.ArrayList(u8).initCapacity(allocator, pe_data.len + 64 * 1024); try data.appendSlice(pe_data); const self = try allocator.create(PEFile); errdefer allocator.destroy(self); - // Parse DOS header + // 2. 
Validate DOS header if (data.items.len < @sizeOf(DOSHeader)) { return error.InvalidPEFile; } - const dos_header: *const DOSHeader = @ptrCast(@alignCast(data.items.ptr)); + const dos_header = try viewAtConst(DOSHeader, data.items, 0); if (dos_header.e_magic != DOS_SIGNATURE) { return error.InvalidDOSSignature; } - // Validate e_lfanew offset (should be reasonable) - if (dos_header.e_lfanew < @sizeOf(DOSHeader) or dos_header.e_lfanew > 0x1000) { + // Bound e_lfanew against file size, not 0x1000 + if (dos_header.e_lfanew < @sizeOf(DOSHeader)) { + return error.InvalidPEFile; + } + if (dos_header.e_lfanew > data.items.len -| @sizeOf(PEHeader)) { return error.InvalidPEFile; } - // Calculate offsets - const pe_header_offset = dos_header.e_lfanew; - const optional_header_offset = pe_header_offset + @sizeOf(PEHeader); - - // Parse PE header - if (data.items.len < pe_header_offset + @sizeOf(PEHeader)) { - return error.InvalidPEFile; - } - - const pe_header: *const PEHeader = @ptrCast(@alignCast(data.items.ptr + pe_header_offset)); + // 3. Read PE header via viewAtMut + const pe_off = dos_header.e_lfanew; + const pe_header = try viewAtMut(PEHeader, data.items, pe_off); if (pe_header.signature != PE_SIGNATURE) { return error.InvalidPESignature; } - // Parse optional header - if (data.items.len < optional_header_offset + @sizeOf(OptionalHeader64)) { + // 4. Compute optional_header_offset + const optional_header_offset = pe_off + @sizeOf(PEHeader); + if (data.items.len < optional_header_offset + pe_header.size_of_optional_header) { + return error.InvalidPEFile; + } + if (pe_header.size_of_optional_header < @sizeOf(OptionalHeader64)) { return error.InvalidPEFile; } - const optional_header: *const OptionalHeader64 = @ptrCast(@alignCast(data.items.ptr + optional_header_offset)); + // 5. Read optional header + const optional_header = try viewAtMut(OptionalHeader64, data.items, optional_header_offset); if (optional_header.magic != OPTIONAL_HEADER_MAGIC_64) { return error.UnsupportedPEFormat; } - // Parse section headers + // Validate file_alignment and section_alignment + if (!isPow2(optional_header.file_alignment) or !isPow2(optional_header.section_alignment)) { + return error.BadAlignment; + } + // If section_alignment < 4096, then file_alignment == section_alignment + if (optional_header.section_alignment < 4096) { + if (optional_header.file_alignment != optional_header.section_alignment) { + return error.InvalidPEFile; + } + } + + // 6. Compute section_headers_offset const section_headers_offset = optional_header_offset + pe_header.size_of_optional_header; - const section_headers_size = @sizeOf(SectionHeader) * pe_header.number_of_sections; + const num_sections = pe_header.number_of_sections; + if (num_sections > 96) { // PE limit + return error.TooManySections; + } + const section_headers_size = @sizeOf(SectionHeader) * num_sections; if (data.items.len < section_headers_offset + section_headers_size) { return error.InvalidPEFile; } - // Check if we have space for at least one more section header (for future addition) - const max_sections_space = section_headers_offset + @sizeOf(SectionHeader) * 96; // PE max sections - if (data.items.len < max_sections_space) { - // Not enough space to add sections - we'll need to handle this in addBunSection + // 7. 
Precompute first_raw, last_file_end, last_va_end + var first_raw: u32 = @intCast(data.items.len); + var last_file_end: u32 = 0; + var last_va_end: u32 = 0; + + if (num_sections > 0) { + const sections_ptr: [*]align(1) const SectionHeader = @ptrCast(data.items[section_headers_offset..].ptr); + const sections = sections_ptr[0..num_sections]; + + for (sections) |section| { + if (section.size_of_raw_data > 0) { + if (section.pointer_to_raw_data < first_raw) { + first_raw = section.pointer_to_raw_data; + } + const file_end = section.pointer_to_raw_data + section.size_of_raw_data; + if (file_end > last_file_end) { + last_file_end = file_end; + } + } + // Use effective virtual size (max of virtual_size and size_of_raw_data) + const vs_effective = @max(section.virtual_size, section.size_of_raw_data); + const va_end = section.virtual_address + (try alignUpU32(vs_effective, optional_header.section_alignment)); + if (va_end > last_va_end) { + last_va_end = va_end; + } + } } self.* = .{ .data = data, .allocator = allocator, .dos_header_offset = 0, - .pe_header_offset = pe_header_offset, + .pe_header_offset = pe_off, .optional_header_offset = optional_header_offset, .section_headers_offset = section_headers_offset, - .num_sections = pe_header.number_of_sections, + .num_sections = num_sections, + .first_raw = first_raw, + .last_file_end = last_file_end, + .last_va_end = last_va_end, }; return self; @@ -202,41 +339,188 @@ pub const PEFile = struct { self.allocator.destroy(self); } + /// Strip Authenticode signatures from the PE file + pub fn stripAuthenticode(self: *PEFile, opts: StripOpts) !void { + const data = self.data.items; + const opt = try viewAtMut(OptionalHeader64, data, self.optional_header_offset); + + // Read Security directory (index 4) + const dd_ptr: *align(1) DataDirectory = &opt.data_directories[IMAGE_DIRECTORY_ENTRY_SECURITY]; + const sec_off_u32 = dd_ptr.virtual_address; // file offset (not RVA) + const sec_size_u32 = dd_ptr.size; + + if (sec_off_u32 == 0 or sec_size_u32 == 0) return; // nothing to strip + + // Compute last_file_end from sections (reuse cached or recompute) + var last_raw_end: u32 = 0; + const sections = try self.getSectionHeaders(); + for (sections) |s| { + const end = s.pointer_to_raw_data + s.size_of_raw_data; + if (end > last_raw_end) last_raw_end = end; + } + + const file_len = data.len; + const sec_off = @as(usize, sec_off_u32); + const sec_size = @as(usize, sec_size_u32); + + if (sec_off >= file_len or sec_size == 0) return error.InvalidSecurityDirectory; + if (opts.require_overlay and sec_off < @as(usize, last_raw_end)) + return error.SecurityDirInsideImage; + + // Remove certificate plus 8-byte padding at tail + const end_raw = try alignUpUsize(sec_off + sec_size, 8); + if (end_raw > file_len) return error.InvalidSecurityDirectory; + + if (end_raw == file_len) { + try self.data.resize(sec_off); + } else { + const tail_len = file_len - end_raw; + // Use copyBackwards for potentially overlapping memory regions + std.mem.copyBackwards(u8, self.data.items[sec_off .. 
sec_off + tail_len], self.data.items[end_raw..file_len]); + try self.data.resize(sec_off + tail_len); + } + + // Re-get pointers after resize + const opt_after = try self.getOptionalHeaderMut(); + const dd_after: *align(1) DataDirectory = &opt_after.data_directories[IMAGE_DIRECTORY_ENTRY_SECURITY]; + + // Zero Security directory entry + dd_after.virtual_address = 0; + dd_after.size = 0; + + // Clear FORCE_INTEGRITY bit if set + if ((opt_after.dll_characteristics & IMAGE_DLLCHARACTERISTICS_FORCE_INTEGRITY) != 0) + opt_after.dll_characteristics &= ~IMAGE_DLLCHARACTERISTICS_FORCE_INTEGRITY; + + // Recompute checksum (recommended) + if (opts.recompute_checksum) try self.recomputePEChecksum(); + + // After strip, ensure no remaining overlay beyond last section + const after_strip_len = self.data.items.len; + if (@as(usize, last_raw_end) < after_strip_len) + return error.UnexpectedOverlayPresent; + } + + /// Recompute PE checksum according to Windows spec + fn recomputePEChecksum(self: *PEFile) !void { + const data = self.data.items; + const checksum_off = self.optional_header_offset + @offsetOf(OptionalHeader64, "checksum"); + + // Zero checksum field before summing + @memset(self.data.items[checksum_off .. checksum_off + 4], 0); + + var sum: u64 = 0; + var i: usize = 0; + + // Sum 16-bit words + while (i + 1 < data.len) : (i += 2) { + const w: u16 = @as(u16, data[i]) | (@as(u16, data[i + 1]) << 8); + sum += w; + sum = (sum & 0xffff) + (sum >> 16); // fold periodically + } + // Odd trailing byte + if ((data.len & 1) != 0) { + sum += data[data.len - 1]; + } + + // Final folds + add length + sum = (sum & 0xffff) + (sum >> 16); + sum = (sum & 0xffff) + (sum >> 16); + sum += @as(u64, @intCast(data.len)); + sum = (sum & 0xffff) + (sum >> 16); + const final_sum: u32 = @intCast((sum & 0xffff) + (sum >> 16)); + + const opt = try self.getOptionalHeaderMut(); + opt.checksum = final_sum; + } + /// Add a new section to the PE file for storing Bun module data - pub fn addBunSection(self: *PEFile, data_to_embed: []const u8) !void { - const section_name = ".bun\x00\x00\x00\x00"; - const optional_header = self.getOptionalHeader(); - const aligned_size = alignSize(@intCast(data_to_embed.len + @sizeOf(u32)), optional_header.file_alignment); + pub fn addBunSection(self: *PEFile, data_to_embed: []const u8, strip: StripMode) !void { + // 1. Optional strip (before any addition) + if (strip == .strip_always) { + try self.stripAuthenticode(.{ .require_overlay = true, .recompute_checksum = true }); + } else if (strip == .strip_if_signed) { + // Read Security directory to check if signed + const opt = try self.getOptionalHeader(); + const dd = opt.data_directories[IMAGE_DIRECTORY_ENTRY_SECURITY]; + if (dd.virtual_address != 0 or dd.size != 0) { + try self.stripAuthenticode(.{ .require_overlay = true, .recompute_checksum = true }); + } + } + + // 2. Re-read PE/Optional (pointers may have moved due to resize in strip) + const opt = try self.getOptionalHeaderMut(); + + // 3. 
Duplicate .bun guard - compare all 8 bytes exactly + const section_headers = try self.getSectionHeaders(); + for (section_headers) |section| { + if (std.mem.eql(u8, section.name[0..8], &BUN_SECTION_NAME)) { + return error.SectionExists; + } + } // Check if we can add another section - if (self.num_sections >= 95) { // PE limit is 96 sections + if (self.num_sections >= 96) { // PE limit return error.TooManySections; } - // Find the last section to determine where to place the new one - var last_section_end: u32 = 0; - var last_virtual_end: u32 = 0; + // 4. Compute header slack requirement + const new_headers_end = self.section_headers_offset + @sizeOf(SectionHeader) * (self.num_sections + 1); + const new_size_of_headers = try alignUpU32(@intCast(new_headers_end), opt.file_alignment); - const section_headers = self.getSectionHeaders(); + // Determine first_raw (min PointerToRawData among sections with raw data, else data.len) + var first_raw: u32 = @intCast(self.data.items.len); for (section_headers) |section| { - const section_file_end = section.pointer_to_raw_data + section.size_of_raw_data; - const section_virtual_end = section.virtual_address + alignSize(section.virtual_size, optional_header.section_alignment); - - if (section_file_end > last_section_end) { - last_section_end = section_file_end; - } - if (section_virtual_end > last_virtual_end) { - last_virtual_end = section_virtual_end; + if (section.size_of_raw_data > 0) { + if (section.pointer_to_raw_data < first_raw) { + first_raw = section.pointer_to_raw_data; + } } } - // Create new section header - const new_section = SectionHeader{ - .name = section_name.*, - .virtual_size = @intCast(data_to_embed.len + @sizeOf(u32)), - .virtual_address = alignSize(last_virtual_end, optional_header.section_alignment), - .size_of_raw_data = aligned_size, - .pointer_to_raw_data = alignSize(last_section_end, optional_header.file_alignment), + // Require new_size_of_headers <= first_raw + if (new_size_of_headers > first_raw) { + return error.InsufficientHeaderSpace; + } + + // 5. Placement calculations + // Recompute last_file_end and last_va_end after strip + var last_file_end: u32 = 0; + var last_va_end: u32 = 0; + for (section_headers) |section| { + const file_end = section.pointer_to_raw_data + section.size_of_raw_data; + if (file_end > last_file_end) { + last_file_end = file_end; + } + // Use effective virtual size (max of virtual_size and size_of_raw_data) + const vs_effective = @max(section.virtual_size, section.size_of_raw_data); + const va_end = section.virtual_address + (try alignUpU32(vs_effective, opt.section_alignment)); + if (va_end > last_va_end) { + last_va_end = va_end; + } + } + + // Check for overflow before adding 4 + if (data_to_embed.len > std.math.maxInt(u32) - 4) { + return error.Overflow; + } + const payload_len = @as(u32, @intCast(data_to_embed.len + 4)); // 4 for LE length prefix + const raw_size = try alignUpU32(payload_len, opt.file_alignment); + const new_va = try alignUpU32(last_va_end, opt.section_alignment); + const new_raw = try alignUpU32(last_file_end, opt.file_alignment); + + // 6. Resize & zero only the new section area + const new_file_size = @as(usize, new_raw) + @as(usize, raw_size); + try self.data.resize(new_file_size); + @memset(self.data.items[@intCast(new_raw)..new_file_size], 0); + + // 7. 
Write the new SectionHeader by byte copy + const sh = SectionHeader{ + .name = [_]u8{ '.', 'b', 'u', 'n', 0, 0, 0, 0 }, + .virtual_size = payload_len, + .virtual_address = new_va, + .size_of_raw_data = raw_size, + .pointer_to_raw_data = new_raw, .pointer_to_relocations = 0, .pointer_to_line_numbers = 0, .number_of_relocations = 0, @@ -244,45 +528,51 @@ pub const PEFile = struct { .characteristics = IMAGE_SCN_CNT_INITIALIZED_DATA | IMAGE_SCN_MEM_READ, }; - // Resize data to accommodate new section - const new_data_size = new_section.pointer_to_raw_data + new_section.size_of_raw_data; - try self.data.resize(new_data_size); - - // Zero out the new section data - @memset(self.data.items[last_section_end..new_data_size], 0); - - // Write the section header - use our stored offset - const new_section_offset = self.section_headers_offset + @sizeOf(SectionHeader) * self.num_sections; - - // Check bounds before writing - if (new_section_offset + @sizeOf(SectionHeader) > self.data.items.len) { - return error.InsufficientSpace; + const new_sh_off = self.section_headers_offset + @sizeOf(SectionHeader) * self.num_sections; + // Bounds check against first_raw (not file length) + if (new_sh_off + @sizeOf(SectionHeader) > first_raw) { + return error.InsufficientHeaderSpace; } + std.mem.copyForwards(u8, self.data.items[new_sh_off .. new_sh_off + @sizeOf(SectionHeader)], std.mem.asBytes(&sh)); - const new_section_ptr: *SectionHeader = @ptrCast(@alignCast(self.data.items.ptr + new_section_offset)); - new_section_ptr.* = new_section; + // 8. Write payload + // At data[new_raw ..]: write LE length prefix then data + std.mem.writeInt(u32, self.data.items[new_raw..][0..4], @intCast(data_to_embed.len), .little); + @memcpy(self.data.items[new_raw + 4 ..][0..data_to_embed.len], data_to_embed); - // Write the data with size header - const data_offset = new_section.pointer_to_raw_data; - std.mem.writeInt(u32, self.data.items[data_offset..][0..4], @intCast(data_to_embed.len), .little); - @memcpy(self.data.items[data_offset + 4 ..][0..data_to_embed.len], data_to_embed); - - // Update PE header - get fresh pointer after resize - const pe_header = self.getPEHeader(); - pe_header.number_of_sections += 1; + // 9. Update headers + // Get fresh pointers after resize + const pe_after = try self.getPEHeaderMut(); + pe_after.number_of_sections += 1; self.num_sections += 1; - // Update optional header - get fresh pointer after resize - const updated_optional_header = self.getOptionalHeader(); - updated_optional_header.size_of_image = alignSize(new_section.virtual_address + new_section.virtual_size, updated_optional_header.section_alignment); - updated_optional_header.size_of_initialized_data += new_section.size_of_raw_data; + const opt_after = try self.getOptionalHeaderMut(); + // If opt.size_of_headers < new_size_of_headers + if (opt_after.size_of_headers < new_size_of_headers) { + opt_after.size_of_headers = new_size_of_headers; + } + // Calculate size_of_image: aligned end of last section + const section_va_end = new_va + sh.virtual_size; + opt_after.size_of_image = try alignUpU32(section_va_end, opt_after.section_alignment); + + // Security directory must be zero (signature invalidated by change) + const dd_ptr: *align(1) DataDirectory = &opt_after.data_directories[IMAGE_DIRECTORY_ENTRY_SECURITY]; + if (dd_ptr.virtual_address != 0 or dd_ptr.size != 0) { + dd_ptr.virtual_address = 0; + dd_ptr.size = 0; + } + + // Do not touch size_of_initialized_data (leave as is) + + // 10. 
Recompute checksum (recommended) + try self.recomputePEChecksum(); } /// Find the .bun section and return its data pub fn getBunSectionData(self: *const PEFile) ![]const u8 { - const section_headers = self.getSectionHeaders(); + const section_headers = try self.getSectionHeaders(); for (section_headers) |section| { - if (strings.eqlComptime(section.name[0..4], ".bun")) { + if (std.mem.eql(u8, section.name[0..8], &BUN_SECTION_NAME)) { if (section.size_of_raw_data < @sizeOf(u32)) { return error.InvalidBunSection; } @@ -309,9 +599,9 @@ pub const PEFile = struct { /// Get the length of the Bun section data pub fn getBunSectionLength(self: *const PEFile) !u32 { - const section_headers = self.getSectionHeaders(); + const section_headers = try self.getSectionHeaders(); for (section_headers) |section| { - if (strings.eqlComptime(section.name[0..4], ".bun")) { + if (std.mem.eql(u8, section.name[0..8], &BUN_SECTION_NAME)) { if (section.size_of_raw_data < @sizeOf(u32)) { return error.InvalidBunSection; } @@ -337,54 +627,105 @@ pub const PEFile = struct { /// Validate the PE file structure pub fn validate(self: *const PEFile) !void { - // Check DOS header - const dos_header = self.getDosHeader(); + // Check DOS & PE signatures + const dos_header = try self.getDosHeader(); if (dos_header.e_magic != DOS_SIGNATURE) { return error.InvalidDOSSignature; } - // Check PE header - const pe_header = self.getPEHeader(); + const pe_header = try self.getPEHeader(); if (pe_header.signature != PE_SIGNATURE) { return error.InvalidPESignature; } - // Check optional header - const optional_header = self.getOptionalHeader(); + // Check optional header magic is 0x20B (64-bit) + const optional_header = try self.getOptionalHeader(); if (optional_header.magic != OPTIONAL_HEADER_MAGIC_64) { return error.UnsupportedPEFormat; } - // Validate section headers - const section_headers = self.getSectionHeaders(); - for (section_headers) |section| { - if (section.pointer_to_raw_data + section.size_of_raw_data > self.data.items.len) { - return error.InvalidSectionData; + // Validate file_alignment, section_alignment sanity + if (!isPow2(optional_header.file_alignment) or !isPow2(optional_header.section_alignment)) { + return error.BadAlignment; + } + // Relational rule + if (optional_header.section_alignment < 4096) { + if (optional_header.file_alignment != optional_header.section_alignment) { + return error.InvalidPEFile; } } + + // Section headers region fits within size_of_headers and file + const section_headers_end = self.section_headers_offset + @sizeOf(SectionHeader) * self.num_sections; + if (section_headers_end > optional_header.size_of_headers or + section_headers_end > self.data.items.len) + { + return error.InvalidPEFile; + } + + // Validate each section + const section_headers = try self.getSectionHeaders(); + var max_va_end: u32 = 0; + + for (section_headers, 0..) 
|section, i| { + // If size_of_raw_data > 0, validate raw data bounds + if (section.size_of_raw_data > 0) { + if (section.pointer_to_raw_data < optional_header.size_of_headers or + section.pointer_to_raw_data + section.size_of_raw_data > self.data.items.len) + { + return error.InvalidSectionData; + } + + // Check for overlaps with other sections using correct interval test + for (section_headers[i + 1 ..]) |other| { + if (other.size_of_raw_data > 0) { + const section_start = section.pointer_to_raw_data; + const section_end = section_start + section.size_of_raw_data; + const other_start = other.pointer_to_raw_data; + const other_end = other_start + other.size_of_raw_data; + // Standard overlap test: max(start) < min(end) + if (@max(section_start, other_start) < @min(section_end, other_end)) { + return error.InvalidPEFile; // Section raw ranges overlap + } + } + } + } + + // Track max virtual address end using effective virtual size + const vs_effective = @max(section.virtual_size, section.size_of_raw_data); + const va_end = section.virtual_address + (try alignUpU32(vs_effective, optional_header.section_alignment)); + if (va_end > max_va_end) { + max_va_end = va_end; + } + } + + // Verify size_of_image equals alignUp(max(VA + alignUp(VS, SA)), SA) + const expected_size_of_image = try alignUpU32(max_va_end, optional_header.section_alignment); + if (optional_header.size_of_image != expected_size_of_image) { + return error.SizeOfImageMismatch; + } + + // Security directory should be 0,0 post-change (if we modified it) + // (This is optional validation, not critical) + + // If checksum recomputed, field should be non-zero + // (Unless we intentionally write zero, which is allowed) } }; -/// Align size to the nearest multiple of alignment -fn alignSize(size: u32, alignment: u32) u32 { - if (alignment == 0) return size; - // Check for overflow - if (size > std.math.maxInt(u32) - alignment + 1) return std.math.maxInt(u32); - return (size + alignment - 1) & ~(alignment - 1); -} - /// Utilities for PE file detection and validation pub const utils = struct { pub fn isPE(data: []const u8) bool { if (data.len < @sizeOf(PEFile.DOSHeader)) return false; - const dos_header: *const PEFile.DOSHeader = @ptrCast(@alignCast(data.ptr)); - if (dos_header.e_magic != PEFile.DOS_SIGNATURE) return false; + const dos: *align(1) const PEFile.DOSHeader = @ptrCast(data.ptr); + if (dos.e_magic != PEFile.DOS_SIGNATURE) return false; - if (data.len < dos_header.e_lfanew + @sizeOf(PEFile.PEHeader)) return false; + const off = dos.e_lfanew; + if (off < @sizeOf(PEFile.DOSHeader) or off > data.len -| @sizeOf(PEFile.PEHeader)) return false; - const pe_header: *const PEFile.PEHeader = @ptrCast(@alignCast(data.ptr + dos_header.e_lfanew)); - return pe_header.signature == PEFile.PE_SIGNATURE; + const pe: *align(1) const PEFile.PEHeader = @ptrCast(data[off..].ptr); + return pe.signature == PEFile.PE_SIGNATURE; } }; @@ -398,10 +739,8 @@ pub const BUN_COMPILED_SECTION_NAME = ".bun"; extern "C" fn Bun__getStandaloneModuleGraphPELength() u32; extern "C" fn Bun__getStandaloneModuleGraphPEData() ?[*]u8; -const std = @import("std"); - const bun = @import("bun"); -const strings = bun.strings; +const std = @import("std"); const mem = std.mem; const Allocator = mem.Allocator; From 14b62e6904a5b49fa60544f56cc3256eeb62e480 Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Thu, 25 Sep 2025 18:06:21 -0700 Subject: [PATCH 13/43] chore(build): Add build:debug:noasan and remove build:debug:asan (#22982) ### What does this PR do? 
Adds a `bun run build:debug:noasan` run script and deletes the `bun run build:debug:asan` rule. ### How did you verify your code works? Ran the change locally. --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index b723c8dbf0..fff1134c97 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,7 @@ "bd:v": "(bun run --silent build:debug &> /tmp/bun.debug.build.log || (cat /tmp/bun.debug.build.log && rm -rf /tmp/bun.debug.build.log && exit 1)) && rm -f /tmp/bun.debug.build.log && ./build/debug/bun-debug", "bd": "BUN_DEBUG_QUIET_LOGS=1 bun --silent bd:v", "build:debug": "export COMSPEC=\"C:\\Windows\\System32\\cmd.exe\" && bun ./scripts/build.mjs -GNinja -DCMAKE_BUILD_TYPE=Debug -B build/debug --log-level=NOTICE", - "build:debug:asan": "bun ./scripts/build.mjs -GNinja -DCMAKE_BUILD_TYPE=Debug -DENABLE_ASAN=ON -B build/debug-asan --log-level=NOTICE", + "build:debug:noasan": "export COMSPEC=\"C:\\Windows\\System32\\cmd.exe\" && bun ./scripts/build.mjs -GNinja -DCMAKE_BUILD_TYPE=Debug -DENABLE_ASAN=OFF -B build/debug --log-level=NOTICE", "build:release": "bun ./scripts/build.mjs -GNinja -DCMAKE_BUILD_TYPE=Release -B build/release", "build:ci": "bun ./scripts/build.mjs -GNinja -DCMAKE_BUILD_TYPE=Release -DCMAKE_VERBOSE_MAKEFILE=ON -DCI=true -B build/release-ci --verbose --fresh", "build:assert": "bun ./scripts/build.mjs -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo -DENABLE_ASSERTIONS=ON -DENABLE_LOGS=ON -B build/release-assert", From 51ce3bc269495afc6145c33bb67230cdd5342f6e Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Thu, 25 Sep 2025 18:03:22 -0800 Subject: [PATCH 14/43] [publish images] ci: ensure tests that require docker have it available (#22781) --- cmake/targets/BuildBun.cmake | 1 + scripts/bootstrap.sh | 13 +++++++-- scripts/utils.mjs | 6 ++++ src/bun.js/api/YAMLObject.zig | 2 +- src/bun.js/api/bun/udp_socket.zig | 2 +- src/bun.js/api/server.zig | 5 +--- src/bun.js/api/server/ServerWebSocket.zig | 22 ++++---------- src/bun.js/bindings/CatchScope.zig | 2 +- src/bun.js/bindings/JSValue.zig | 23 +++------------ src/bun.js/bindings/SQLClient.cpp | 26 +++++++++-------- src/bun.js/bindings/bindings.cpp | 9 ++---- src/bun.js/node/util/parse_args.zig | 6 ++-- src/bun.js/test/expect/toThrow.zig | 2 +- src/bun.js/webcore/Sink.zig | 5 +--- src/codegen/cppbind.ts | 34 +++++++++++++--------- src/codegen/shared-types.ts | 1 + src/dns.zig | 8 ++--- src/sql/mysql/js/JSMySQLConnection.zig | 4 +-- src/sql/mysql/protocol/ResultSet.zig | 4 +-- src/sql/postgres/DataCell.zig | 4 +-- src/sql/postgres/PostgresSQLConnection.zig | 2 +- src/sql/shared/SQLDataCell.zig | 26 +++++++++++++++++ src/valkey/js_valkey_functions.zig | 7 +---- src/valkey/valkey_protocol.zig | 2 +- test/cli/install/bun-install-proxy.test.ts | 4 +-- test/harness.ts | 6 ++++ test/internal/ban-limits.json | 4 +-- test/js/bun/s3/s3.test.ts | 6 ++-- test/js/bun/symbols.test.ts | 2 +- test/js/sql/local-sql.test.ts | 12 ++++---- test/js/valkey/test-utils.ts | 4 +-- 31 files changed, 135 insertions(+), 119 deletions(-) diff --git a/cmake/targets/BuildBun.cmake b/cmake/targets/BuildBun.cmake index ac6104c398..5e8795c1bd 100644 --- a/cmake/targets/BuildBun.cmake +++ b/cmake/targets/BuildBun.cmake @@ -1011,6 +1011,7 @@ if(LINUX) -Wl,--wrap=exp2 -Wl,--wrap=expf -Wl,--wrap=fcntl64 + -Wl,--wrap=gettid -Wl,--wrap=log -Wl,--wrap=log2 -Wl,--wrap=log2f diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index a72f614a28..3537285e05 100755 --- a/scripts/bootstrap.sh +++ 
b/scripts/bootstrap.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Version: 18 +# Version: 19 # A script that installs the dependencies needed to build and test Bun. # This should work on macOS and Linux with a POSIX shell. @@ -685,6 +685,8 @@ install_common_software() { apt-transport-https \ software-properties-common fi + install_packages \ + libc6-dbg ;; dnf) install_packages \ @@ -1193,7 +1195,7 @@ install_docker() { execute_sudo amazon-linux-extras install docker ;; amzn-* | alpine-*) - install_packages docker + install_packages docker docker-cli-compose ;; *) sh="$(require sh)" @@ -1208,10 +1210,17 @@ install_docker() { if [ -f "$systemctl" ]; then execute_sudo "$systemctl" enable docker fi + if [ "$os" = "linux" ] && [ "$distro" = "alpine" ]; then + execute doas rc-update add docker default + execute doas rc-service docker start + fi getent="$(which getent)" if [ -n "$("$getent" group docker)" ]; then usermod="$(which usermod)" + if [ -z "$usermod" ]; then + usermod="$(sudo which usermod)" + fi if [ -f "$usermod" ]; then execute_sudo "$usermod" -aG docker "$user" fi diff --git a/scripts/utils.mjs b/scripts/utils.mjs index cc7b205150..604227f9cd 100755 --- a/scripts/utils.mjs +++ b/scripts/utils.mjs @@ -2866,6 +2866,12 @@ export function printEnvironment() { spawnSync([shell, "-c", "free -m -w"], { stdio: "inherit" }); } }); + startGroup("Docker", () => { + const shell = which(["sh", "bash"]); + if (shell) { + spawnSync([shell, "-c", "docker ps"], { stdio: "inherit" }); + } + }); } if (isWindows) { startGroup("Disk (win)", () => { diff --git a/src/bun.js/api/YAMLObject.zig b/src/bun.js/api/YAMLObject.zig index ad5e5d8da9..3d18dfa4ec 100644 --- a/src/bun.js/api/YAMLObject.zig +++ b/src/bun.js/api/YAMLObject.zig @@ -1030,7 +1030,7 @@ const ParserCtx = struct { const key_str = try key.toBunString(ctx.global); defer key_str.deref(); - obj.putMayBeIndex(ctx.global, &key_str, value); + try obj.putMayBeIndex(ctx.global, &key_str, value); } return obj; diff --git a/src/bun.js/api/bun/udp_socket.zig b/src/bun.js/api/bun/udp_socket.zig index e34eadfa88..1d0a78a14a 100644 --- a/src/bun.js/api/bun/udp_socket.zig +++ b/src/bun.js/api/bun/udp_socket.zig @@ -625,7 +625,7 @@ pub const UDPSocket = struct { if (val.asArrayBuffer(globalThis)) |arrayBuffer| { break :brk arrayBuffer.slice(); } else if (val.isString()) { - break :brk val.toString(globalThis).toSlice(globalThis, alloc).slice(); + break :brk (try val.toJSString(globalThis)).toSlice(globalThis, alloc).slice(); } else { return globalThis.throwInvalidArguments("Expected ArrayBufferView or string as payload", .{}); } diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 0ea9ff488b..aa17c485e1 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -719,10 +719,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d } { - var js_string = message_value.toString(globalThis); - if (globalThis.hasException()) { - return .zero; - } + var js_string = try message_value.toJSString(globalThis); const view = js_string.view(globalThis); const slice = view.toSlice(bun.default_allocator); defer slice.deinit(); diff --git a/src/bun.js/api/server/ServerWebSocket.zig b/src/bun.js/api/server/ServerWebSocket.zig index ffd5bec449..a244b2dc29 100644 --- a/src/bun.js/api/server/ServerWebSocket.zig +++ b/src/bun.js/api/server/ServerWebSocket.zig @@ -438,10 +438,7 @@ pub fn publish( } { - var js_string = message_value.toString(globalThis); - if (globalThis.hasException()) { - return .zero; - } + var js_string = try 
message_value.toJSString(globalThis); const view = js_string.view(globalThis); const slice = view.toSlice(bun.default_allocator); defer slice.deinit(); @@ -505,10 +502,7 @@ pub fn publishText( return globalThis.throw("publishText requires a non-empty message", .{}); } - var js_string = message_value.toString(globalThis); - if (globalThis.hasException()) { - return .zero; - } + var js_string = try message_value.toJSString(globalThis); const view = js_string.view(globalThis); const slice = view.toSlice(bun.default_allocator); defer slice.deinit(); @@ -756,10 +750,7 @@ pub fn send( } { - var js_string = message_value.toString(globalThis); - if (globalThis.hasException()) { - return .zero; - } + var js_string = try message_value.toJSString(globalThis); const view = js_string.view(globalThis); const slice = view.toSlice(bun.default_allocator); defer slice.deinit(); @@ -814,10 +805,7 @@ pub fn sendText( return globalThis.throw("sendText expects a string", .{}); } - var js_string = message_value.toString(globalThis); - if (globalThis.hasException()) { - return .zero; - } + var js_string = try message_value.toJSString(globalThis); const view = js_string.view(globalThis); const slice = view.toSlice(bun.default_allocator); defer slice.deinit(); @@ -997,7 +985,7 @@ inline fn sendPing( }, } } else if (value.isString()) { - var string_value = value.toString(globalThis).toSlice(globalThis, bun.default_allocator); + var string_value = (try value.toJSString(globalThis)).toSlice(globalThis, bun.default_allocator); defer string_value.deinit(); const buffer = string_value.slice(); diff --git a/src/bun.js/bindings/CatchScope.zig b/src/bun.js/bindings/CatchScope.zig index 22c6014159..d20f42cc28 100644 --- a/src/bun.js/bindings/CatchScope.zig +++ b/src/bun.js/bindings/CatchScope.zig @@ -77,7 +77,7 @@ pub const CatchScope = struct { /// Intended for use with `try`. Returns if there is already a pending exception or if traps cause /// an exception to be thrown (this is the same as how RETURN_IF_EXCEPTION behaves in C++) - pub fn returnIfException(self: *CatchScope) bun.JSError!void { + pub fn returnIfException(self: *CatchScope) !void { if (self.exceptionIncludingTraps() != null) return error.JSError; } diff --git a/src/bun.js/bindings/JSValue.zig b/src/bun.js/bindings/JSValue.zig index a6cc5029e8..f8a54edcd2 100644 --- a/src/bun.js/bindings/JSValue.zig +++ b/src/bun.js/bindings/JSValue.zig @@ -347,12 +347,11 @@ pub const JSValue = enum(i64) { @compileError("Unsupported key type in put(). Expected ZigString or bun.String, got " ++ @typeName(Key)); } } - extern fn JSC__JSValue__putMayBeIndex(target: JSValue, globalObject: *JSGlobalObject, key: *const String, value: jsc.JSValue) void; /// Note: key can't be numeric (if so, use putMayBeIndex instead) /// Same as `.put` but accepts both non-numeric and numeric keys. /// Prefer to use `.put` if the key is guaranteed to be non-numeric (e.g. 
known at comptime) - pub inline fn putMayBeIndex(this: JSValue, globalObject: *JSGlobalObject, key: *const String, value: JSValue) void { - JSC__JSValue__putMayBeIndex(this, globalObject, key, value); + pub fn putMayBeIndex(this: JSValue, globalObject: *JSGlobalObject, key: *const String, value: JSValue) bun.JSError!void { + return bun.cpp.JSC__JSValue__putMayBeIndex(this, globalObject, key, value); } extern fn JSC__JSValue__putToPropertyKey(target: JSValue, globalObject: *JSGlobalObject, key: jsc.JSValue, value: jsc.JSValue) void; @@ -1191,10 +1190,8 @@ pub const JSValue = enum(i64) { return getZigString(this, global).toSliceZ(allocator); } - extern fn JSC__JSValue__toString(this: JSValue, globalThis: *JSGlobalObject) *JSString; - /// On exception, this returns the empty string. - pub fn toString(this: JSValue, globalThis: *JSGlobalObject) *JSString { - return JSC__JSValue__toString(this, globalThis); + pub fn toJSString(this: JSValue, globalThis: *JSGlobalObject) bun.JSError!*JSString { + return bun.cpp.JSC__JSValue__toStringOrNull(this, globalThis); } extern fn JSC__JSValue__jsonStringify(this: JSValue, globalThis: *JSGlobalObject, indent: u32, out: *bun.String) void; @@ -1202,17 +1199,6 @@ pub const JSValue = enum(i64) { return bun.jsc.fromJSHostCallGeneric(globalThis, @src(), JSC__JSValue__jsonStringify, .{ this, globalThis, indent, out }); } - extern fn JSC__JSValue__toStringOrNull(this: JSValue, globalThis: *JSGlobalObject) ?*JSString; - // Calls JSValue::toStringOrNull. Returns error on exception. - pub fn toJSString(this: JSValue, globalThis: *JSGlobalObject) bun.JSError!*JSString { - var scope: ExceptionValidationScope = undefined; - scope.init(globalThis, @src()); - defer scope.deinit(); - const maybe_string = JSC__JSValue__toStringOrNull(this, globalThis); - scope.assertExceptionPresenceMatches(maybe_string == null); - return maybe_string orelse error.JSError; - } - /// Call `toString()` on the JSValue and clone the result. 
pub fn toSliceOrNull(this: JSValue, globalThis: *JSGlobalObject) bun.JSError!ZigString.Slice { const str = try bun.String.fromJS(this, globalThis); @@ -2424,7 +2410,6 @@ const ArrayBuffer = jsc.ArrayBuffer; const C_API = bun.jsc.C; const CatchScope = jsc.CatchScope; const DOMURL = jsc.DOMURL; -const ExceptionValidationScope = jsc.ExceptionValidationScope; const JSArrayIterator = jsc.JSArrayIterator; const JSCell = jsc.JSCell; const JSGlobalObject = jsc.JSGlobalObject; diff --git a/src/bun.js/bindings/SQLClient.cpp b/src/bun.js/bindings/SQLClient.cpp index c9c8f41313..00c08ab91d 100644 --- a/src/bun.js/bindings/SQLClient.cpp +++ b/src/bun.js/bindings/SQLClient.cpp @@ -161,10 +161,10 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel return jsNumber(cell.value.unsigned_integer); break; case DataCellTag::Bigint: - return JSC::JSBigInt::createFrom(globalObject, cell.value.bigint); + RELEASE_AND_RETURN(scope, JSC::JSBigInt::createFrom(globalObject, cell.value.bigint)); break; case DataCellTag::UnsignedBigint: - return JSC::JSBigInt::createFrom(globalObject, cell.value.unsigned_bigint); + RELEASE_AND_RETURN(scope, JSC::JSBigInt::createFrom(globalObject, cell.value.unsigned_bigint)); break; case DataCellTag::Boolean: return jsBoolean(cell.value.boolean); @@ -189,6 +189,7 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel if (cell.value.json) { auto str = WTF::String(cell.value.json); JSC::JSValue json = JSC::JSONParse(globalObject, str); + RETURN_IF_EXCEPTION(scope, {}); return json; } return jsNull(); @@ -198,14 +199,10 @@ static JSC::JSValue toJS(JSC::VM& vm, JSC::JSGlobalObject* globalObject, DataCel uint32_t length = cell.value.array.length; for (uint32_t i = 0; i < length; i++) { JSValue result = toJS(vm, globalObject, cell.value.array.cells[i]); - if (result.isEmpty()) [[unlikely]] { - return {}; - } - + RETURN_IF_EXCEPTION(scope, {}); args.append(result); } - - return JSC::constructArray(globalObject, static_cast(nullptr), args); + RELEASE_AND_RETURN(scope, JSC::constructArray(globalObject, static_cast(nullptr), args)); } case DataCellTag::TypedArray: { JSC::JSType type = static_cast(cell.value.typed_array.type); @@ -343,6 +340,7 @@ static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, uint32_t co // -> { "8": 1, "2": 2, "3": 3 } // 8 > count object->putDirectIndex(globalObject, cell.index, value); + RETURN_IF_EXCEPTION(scope, {}); } } else { uint32_t structureOffsetIndex = 0; @@ -356,6 +354,7 @@ static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, uint32_t co ASSERT(!cell.isNamedColumn()); ASSERT(!cell.isDuplicateColumn()); object->putDirectIndex(globalObject, cell.index, value); + RETURN_IF_EXCEPTION(scope, {}); } else if (cell.isNamedColumn()) { JSValue value = toJS(vm, globalObject, cell); RETURN_IF_EXCEPTION(scope, {}); @@ -387,6 +386,7 @@ static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, uint32_t co JSValue value = toJS(vm, globalObject, cell); RETURN_IF_EXCEPTION(scope, {}); array->putDirectIndex(globalObject, i, value); + RETURN_IF_EXCEPTION(scope, {}); } return array; } @@ -399,20 +399,22 @@ static JSC::JSValue toJS(JSC::Structure* structure, DataCell* cells, uint32_t co } static JSC::JSValue toJS(JSC::JSArray* array, JSC::Structure* structure, DataCell* cells, uint32_t count, JSC::JSGlobalObject* globalObject, Bun::BunStructureFlags flags, BunResultMode result_mode, ExternColumnIdentifier* namesPtr, uint32_t namesCount) { + auto& vm = JSC::getVM(globalObject); + 
auto scope = DECLARE_THROW_SCOPE(vm); JSValue value = toJS(structure, cells, count, globalObject, flags, result_mode, namesPtr, namesCount); - if (value.isEmpty()) - return {}; + RETURN_IF_EXCEPTION(scope, {}); if (array) { array->push(globalObject, value); + RETURN_IF_EXCEPTION(scope, {}); return array; } auto* newArray = JSC::constructEmptyArray(globalObject, static_cast(nullptr), 1); - if (!newArray) - return {}; + RETURN_IF_EXCEPTION(scope, {}); newArray->putDirectIndex(globalObject, 0, value); + RETURN_IF_EXCEPTION(scope, {}); return newArray; } diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp index fd10d44f99..496c5e14f8 100644 --- a/src/bun.js/bindings/bindings.cpp +++ b/src/bun.js/bindings/bindings.cpp @@ -3713,7 +3713,7 @@ void JSC__JSValue__putToPropertyKey(JSC::EncodedJSValue JSValue0, JSC::JSGlobalO object->putDirectMayBeIndex(arg1, pkey, value); } -extern "C" void JSC__JSValue__putMayBeIndex(JSC::EncodedJSValue target, JSC::JSGlobalObject* globalObject, const BunString* key, JSC::EncodedJSValue value) +extern "C" [[ZIG_EXPORT(check_slow)]] void JSC__JSValue__putMayBeIndex(JSC::EncodedJSValue target, JSC::JSGlobalObject* globalObject, const BunString* key, JSC::EncodedJSValue value) { auto& vm = JSC::getVM(globalObject); ThrowScope scope = DECLARE_THROW_SCOPE(vm); @@ -4346,12 +4346,7 @@ JSC::JSObject* JSC__JSValue__toObject(JSC::EncodedJSValue JSValue0, JSC::JSGloba return value.toObject(arg1); } -JSC::JSString* JSC__JSValue__toString(JSC::EncodedJSValue JSValue0, JSC::JSGlobalObject* arg1) -{ - JSC::JSValue value = JSC::JSValue::decode(JSValue0); - return value.toString(arg1); -}; -JSC::JSString* JSC__JSValue__toStringOrNull(JSC::EncodedJSValue JSValue0, JSC::JSGlobalObject* arg1) +[[ZIG_EXPORT(null_is_throw)]] JSC::JSString* JSC__JSValue__toStringOrNull(JSC::EncodedJSValue JSValue0, JSC::JSGlobalObject* arg1) { JSC::JSValue value = JSC::JSValue::decode(JSValue0); return value.toStringOrNull(arg1); diff --git a/src/bun.js/node/util/parse_args.zig b/src/bun.js/node/util/parse_args.zig index 37513e185e..5904dff0aa 100644 --- a/src/bun.js/node/util/parse_args.zig +++ b/src/bun.js/node/util/parse_args.zig @@ -272,10 +272,10 @@ fn storeOption(globalThis: *JSGlobalObject, option_name: ValueRef, option_value: } else { var value_list = try JSValue.createEmptyArray(globalThis, 1); try value_list.putIndex(globalThis, 0, new_value); - values.putMayBeIndex(globalThis, &key, value_list); + try values.putMayBeIndex(globalThis, &key, value_list); } } else { - values.putMayBeIndex(globalThis, &key, new_value); + try values.putMayBeIndex(globalThis, &key, new_value); } } @@ -723,7 +723,7 @@ pub fn parseArgs(globalThis: *JSGlobalObject, callframe: *jsc.CallFrame) bun.JSE if (!option.long_name.eqlComptime("__proto__")) { if (try state.values.getOwn(globalThis, option.long_name) == null) { log(" Setting \"{}\" to default value", .{option.long_name}); - state.values.putMayBeIndex(globalThis, &option.long_name, default_value); + try state.values.putMayBeIndex(globalThis, &option.long_name, default_value); } } } diff --git a/src/bun.js/test/expect/toThrow.zig b/src/bun.js/test/expect/toThrow.zig index fda78e44b7..0fd4e1f28c 100644 --- a/src/bun.js/test/expect/toThrow.zig +++ b/src/bun.js/test/expect/toThrow.zig @@ -23,7 +23,7 @@ pub fn toThrow(this: *Expect, globalThis: *JSGlobalObject, callFrame: *CallFrame } } else if (value.isString()) { // `.toThrow("") behaves the same as `.toThrow()` - const s = value.toString(globalThis); + const s = try 
value.toJSString(globalThis); if (s.length() == 0) break :brk .zero; } break :brk value; diff --git a/src/bun.js/webcore/Sink.zig b/src/bun.js/webcore/Sink.zig index f26a613feb..6ad9e07c9a 100644 --- a/src/bun.js/webcore/Sink.zig +++ b/src/bun.js/webcore/Sink.zig @@ -403,10 +403,7 @@ pub fn JSSink(comptime SinkType: type, comptime abi_name: []const u8) type { return globalThis.throwValue(globalThis.toTypeError(.INVALID_ARG_TYPE, "write() expects a string, ArrayBufferView, or ArrayBuffer", .{})); } - const str = arg.toString(globalThis); - if (globalThis.hasException()) { - return .zero; - } + const str = try arg.toJSString(globalThis); const view = str.view(globalThis); diff --git a/src/codegen/cppbind.ts b/src/codegen/cppbind.ts index ba01d3bcad..c269256cbe 100644 --- a/src/codegen/cppbind.ts +++ b/src/codegen/cppbind.ts @@ -45,8 +45,8 @@ To run manually: - **[[ZIG_NONNULL]]** - Mark pointer parameters as non-nullable: ```cpp - [[ZIG_EXPORT(nothrow)]] void process([[ZIG_NONNULL]] JSGlobalObject* globalThis, - [[ZIG_NONNULL]] JSValue* values, + [[ZIG_EXPORT(nothrow)]] void process([[ZIG_NONNULL]] JSGlobalObject* globalThis, + [[ZIG_NONNULL]] JSValue* values, size_t count) { ... } ``` Generates: `pub extern fn process(globalThis: *jsc.JSGlobalObject, values: [*]const jsc.JSValue) void;` @@ -397,7 +397,7 @@ function processFunction(ctx: ParseContext, node: SyntaxNode, tag: ExportTag): C }; } -type ExportTag = "check_slow" | "zero_is_throw" | "false_is_throw" | "nothrow"; +type ExportTag = "check_slow" | "zero_is_throw" | "false_is_throw" | "null_is_throw" | "nothrow"; const sharedTypesText = await Bun.file("src/codegen/shared-types.ts").text(); const sharedTypesLines = sharedTypesText.split("\n"); @@ -570,7 +570,8 @@ async function processFile(parser: CppParser, file: string, allFunctions: CppFn[ tagStr === "nothrow" || tagStr === "zero_is_throw" || tagStr === "check_slow" || - tagStr === "false_is_throw" + tagStr === "false_is_throw" || + tagStr === "null_is_throw" ) { tag = tagStr; } else if (tagStr === "print") { @@ -580,7 +581,7 @@ async function processFile(parser: CppParser, file: string, allFunctions: CppFn[ } else { appendError( nodePosition(tagIdentifier, ctx), - "tag must be nothrow, zero_is_throw, check_slow, or false_is_throw: " + tagStr, + "tag must be nothrow, zero_is_throw, check_slow, false_is_throw, or null_is_throw: " + tagStr, ); tag = "nothrow"; } @@ -639,7 +640,7 @@ function generateZigFn( resultSourceLinks: string[], cfg: Cfg, ): void { - const returnType = generateZigType(fn.returnType, null); + let returnType = generateZigType(fn.returnType, null); if (resultBindings.length) resultBindings.push(""); resultBindings.push(generateZigSourceComment(cfg, resultSourceLinks, fn)); if (fn.tag === "nothrow") { @@ -667,7 +668,7 @@ function generateZigFn( ); } resultBindings.push( - `pub inline fn ${formatZigName(fn.name)}(${generateZigParameterList(fn.parameters, globalThisArg)}) bun.JSError!${returnType} {`, + `pub fn ${formatZigName(fn.name)}(${generateZigParameterList(fn.parameters, globalThisArg)}) error{JSError}!${returnType} {`, ` if (comptime Environment.ci_assert) {`, ` var scope: jsc.CatchScope = undefined;`, ` scope.init(${formatZigName(globalThisArg.name)}, @src());`, @@ -697,9 +698,16 @@ function generateZigFn( if (returnType !== "bool") { appendError(fn.position, "ZIG_EXPORT(false_is_throw) is only allowed for functions that return bool"); } + returnType = "void"; + } else if (fn.tag === "null_is_throw") { + equalsValue = "null"; + if (!returnType.startsWith("?*")) 
{ + appendError(fn.position, "ZIG_EXPORT(null_is_throw) is only allowed for functions that return optional pointer"); + } + returnType = returnType.slice(1); } else assertNever(fn.tag); resultBindings.push( - `pub inline fn ${formatZigName(fn.name)}(${generateZigParameterList(fn.parameters, globalThisArg)}) bun.JSError!${fn.tag === "false_is_throw" ? "void" : returnType} {`, + `pub fn ${formatZigName(fn.name)}(${generateZigParameterList(fn.parameters, globalThisArg)}) error{JSError}!${returnType} {`, ` if (comptime Environment.ci_assert) {`, ` var scope: jsc.ExceptionValidationScope = undefined;`, ` scope.init(${formatZigName(globalThisArg.name)}, @src());`, @@ -707,11 +715,11 @@ function generateZigFn( ``, ` const value = raw.${formatZigName(fn.name)}(${fn.parameters.map(p => formatZigName(p.name)).join(", ")});`, ` scope.assertExceptionPresenceMatches(value == ${equalsValue});`, - ` return if (value == ${equalsValue}) error.JSError ${fn.tag === "false_is_throw" ? "" : "else value"};`, + ` return if (value == ${equalsValue}) error.JSError ${fn.tag === "false_is_throw" ? "" : "else value"}${fn.tag === "null_is_throw" ? ".?" : ""};`, ` } else {`, ` const value = raw.${formatZigName(fn.name)}(${fn.parameters.map(p => formatZigName(p.name)).join(", ")});`, ` if (value == ${equalsValue}) return error.JSError;`, - ...(fn.tag === "false_is_throw" ? [] : [` return value;`]), + ...(fn.tag === "false_is_throw" ? [] : [` return value${fn.tag === "null_is_throw" ? ".?" : ""};`]), ` }`, `}`, ); @@ -733,14 +741,14 @@ async function main() { if (!dstDir) { console.error( String.raw` - _ _ _ + _ _ _ | | (_) | | ___ _ __ _ __ | |__ _ _ __ __| | / __| '_ \| '_ \| '_ \| | '_ \ / _' | | (__| |_) | |_) | |_) | | | | | (_| | \___| .__/| .__/|_.__/|_|_| |_|\__,_| - | | | | - |_| |_| + | | | | + |_| |_| `.slice(1), ); console.error("Usage: bun src/codegen/cppbind src build/debug/codegen"); diff --git a/src/codegen/shared-types.ts b/src/codegen/shared-types.ts index ec2b2b6e24..9611a3983d 100644 --- a/src/codegen/shared-types.ts +++ b/src/codegen/shared-types.ts @@ -56,6 +56,7 @@ export const sharedTypes: Record = { "JSC::SourceProvider": "bun.jsc.SourceProvider", "JSC::CallFrame": "bun.jsc.CallFrame", "JSC::JSObject": "bun.jsc.JSObject", + "JSC::JSString": "bun.jsc.JSString", }; export const bannedTypes: Record = { diff --git a/src/dns.zig b/src/dns.zig index e3bc939830..0875024c4c 100644 --- a/src/dns.zig +++ b/src/dns.zig @@ -150,7 +150,7 @@ pub const GetAddrInfo = struct { if (value.isString()) { return try map.fromJS(globalObject, value) orelse { - if (value.toString(globalObject).length() == 0) { + if ((try value.toJSString(globalObject)).length() == 0) { return .unspecified; } @@ -211,7 +211,7 @@ pub const GetAddrInfo = struct { if (value.isString()) { return try map.fromJS(globalObject, value) orelse { - if (value.toString(globalObject).length() == 0) + if ((try value.toJSString(globalObject)).length() == 0) return .unspecified; return error.InvalidSocketType; @@ -251,7 +251,7 @@ pub const GetAddrInfo = struct { if (value.isString()) { return try map.fromJS(globalObject, value) orelse { - const str = value.toString(globalObject); + const str = try value.toJSString(globalObject); if (str.length() == 0) return .unspecified; @@ -301,7 +301,7 @@ pub const GetAddrInfo = struct { if (value.isString()) { return try label.fromJS(globalObject, value) orelse { - if (value.toString(globalObject).length() == 0) { + if ((try value.toJSString(globalObject)).length() == 0) { return default; } diff --git 
a/src/sql/mysql/js/JSMySQLConnection.zig b/src/sql/mysql/js/JSMySQLConnection.zig index b1b52c0608..92b31c1b5d 100644 --- a/src/sql/mysql/js/JSMySQLConnection.zig +++ b/src/sql/mysql/js/JSMySQLConnection.zig @@ -667,7 +667,7 @@ pub fn onConnectionEstabilished(this: *@This()) void { pub fn onQueryResult(this: *@This(), request: *JSMySQLQuery, result: MySQLQueryResult) void { request.resolve(this.getQueriesArray(), result); } -pub fn onResultRow(this: *@This(), request: *JSMySQLQuery, statement: *MySQLStatement, Context: type, reader: NewReader(Context)) error{ShortRead}!void { +pub fn onResultRow(this: *@This(), request: *JSMySQLQuery, statement: *MySQLStatement, Context: type, reader: NewReader(Context)) (error{ ShortRead, JSError })!void { const result_mode = request.getResultMode(); var stack_fallback = std.heap.stackFallback(4096, bun.default_allocator); const allocator = stack_fallback.get(); @@ -700,7 +700,7 @@ pub fn onResultRow(this: *@This(), request: *JSMySQLQuery, statement: *MySQLStat }; const pending_value = request.getPendingValue() orelse .js_undefined; // Process row data - const row_value = row.toJS( + const row_value = try row.toJS( this.#globalObject, pending_value, structure, diff --git a/src/sql/mysql/protocol/ResultSet.zig b/src/sql/mysql/protocol/ResultSet.zig index d5a06d117f..1d49650869 100644 --- a/src/sql/mysql/protocol/ResultSet.zig +++ b/src/sql/mysql/protocol/ResultSet.zig @@ -8,7 +8,7 @@ pub const Row = struct { bigint: bool = false, globalObject: *jsc.JSGlobalObject, - pub fn toJS(this: *Row, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: SQLDataCell.Flags, result_mode: SQLQueryResultMode, cached_structure: ?CachedStructure) JSValue { + pub fn toJS(this: *Row, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: SQLDataCell.Flags, result_mode: SQLQueryResultMode, cached_structure: ?CachedStructure) !JSValue { var names: ?[*]jsc.JSObject.ExternColumnIdentifier = null; var names_count: u32 = 0; if (cached_structure) |c| { @@ -18,7 +18,7 @@ pub const Row = struct { } } - return SQLDataCell.JSC__constructObjectFromDataCell( + return SQLDataCell.constructObjectFromDataCell( globalObject, array, structure, diff --git a/src/sql/postgres/DataCell.zig b/src/sql/postgres/DataCell.zig index 98a721a735..74f32034f3 100644 --- a/src/sql/postgres/DataCell.zig +++ b/src/sql/postgres/DataCell.zig @@ -906,7 +906,7 @@ pub const Putter = struct { count: usize = 0, globalObject: *jsc.JSGlobalObject, - pub fn toJS(this: *Putter, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: SQLDataCell.Flags, result_mode: PostgresSQLQueryResultMode, cached_structure: ?PostgresCachedStructure) JSValue { + pub fn toJS(this: *Putter, globalObject: *jsc.JSGlobalObject, array: JSValue, structure: JSValue, flags: SQLDataCell.Flags, result_mode: PostgresSQLQueryResultMode, cached_structure: ?PostgresCachedStructure) !JSValue { var names: ?[*]jsc.JSObject.ExternColumnIdentifier = null; var names_count: u32 = 0; if (cached_structure) |c| { @@ -916,7 +916,7 @@ pub const Putter = struct { } } - return SQLDataCell.JSC__constructObjectFromDataCell( + return SQLDataCell.constructObjectFromDataCell( globalObject, array, structure, diff --git a/src/sql/postgres/PostgresSQLConnection.zig b/src/sql/postgres/PostgresSQLConnection.zig index 0d2445a5ff..e176281713 100644 --- a/src/sql/postgres/PostgresSQLConnection.zig +++ b/src/sql/postgres/PostgresSQLConnection.zig @@ -1429,7 +1429,7 @@ pub fn on(this: *PostgresSQLConnection, 
comptime MessageType: @Type(.enum_litera }; const pending_value = PostgresSQLQuery.js.pendingValueGetCached(thisValue) orelse .zero; pending_value.ensureStillAlive(); - const result = putter.toJS( + const result = try putter.toJS( this.globalObject, pending_value, structure, diff --git a/src/sql/shared/SQLDataCell.zig b/src/sql/shared/SQLDataCell.zig index 1cf73d6edb..059ba6e71d 100644 --- a/src/sql/shared/SQLDataCell.zig +++ b/src/sql/shared/SQLDataCell.zig @@ -141,6 +141,32 @@ pub const SQLDataCell = extern struct { _: u29 = 0, }; + // TODO: cppbind isn't yet able to detect slice parameters when the next is uint32_t + pub fn constructObjectFromDataCell( + globalObject: *jsc.JSGlobalObject, + encodedArrayValue: jsc.JSValue, + encodedStructureValue: jsc.JSValue, + cells: [*]SQLDataCell, + count: u32, + flags: SQLDataCell.Flags, + result_mode: u8, + namesPtr: ?[*]bun.jsc.JSObject.ExternColumnIdentifier, + namesCount: u32, + ) !jsc.JSValue { + if (comptime bun.Environment.ci_assert) { + var scope: jsc.ExceptionValidationScope = undefined; + scope.init(globalObject, @src()); + defer scope.deinit(); + const value = JSC__constructObjectFromDataCell(globalObject, encodedArrayValue, encodedStructureValue, cells, count, flags, result_mode, namesPtr, namesCount); + scope.assertExceptionPresenceMatches(value == .zero); + return if (value == .zero) error.JSError else value; + } else { + const value = JSC__constructObjectFromDataCell(globalObject, encodedArrayValue, encodedStructureValue, cells, count, flags, result_mode, namesPtr, namesCount); + if (value == .zero) return error.JSError; + return value; + } + } + pub extern fn JSC__constructObjectFromDataCell( *jsc.JSGlobalObject, JSValue, diff --git a/src/valkey/js_valkey_functions.zig b/src/valkey/js_valkey_functions.zig index 79cb990e92..4b6d2daad6 100644 --- a/src/valkey/js_valkey_functions.zig +++ b/src/valkey/js_valkey_functions.zig @@ -833,12 +833,7 @@ fn fromJS(globalObject: *jsc.JSGlobalObject, value: JSValue) !?JSArgument { if (value.isNumber()) { // Allow numbers to be passed as strings. 
- const str = value.toString(globalObject); - if (globalObject.hasException()) { - @branchHint(.unlikely); - return error.JSError; - } - + const str = try value.toJSString(globalObject); return try JSArgument.fromJSMaybeFile(globalObject, bun.default_allocator, str.toJS(), true); } diff --git a/src/valkey/valkey_protocol.zig b/src/valkey/valkey_protocol.zig index 39e27dd2dc..a02514f7c6 100644 --- a/src/valkey/valkey_protocol.zig +++ b/src/valkey/valkey_protocol.zig @@ -288,7 +288,7 @@ pub const RESPValue = union(RESPType) { defer key_str.deref(); const js_value = try entry.value.toJSWithOptions(globalObject, options); - js_obj.putMayBeIndex(globalObject, &key_str, js_value); + try js_obj.putMayBeIndex(globalObject, &key_str, js_value); } return js_obj; }, diff --git a/test/cli/install/bun-install-proxy.test.ts b/test/cli/install/bun-install-proxy.test.ts index 85acc8711c..40672217c0 100644 --- a/test/cli/install/bun-install-proxy.test.ts +++ b/test/cli/install/bun-install-proxy.test.ts @@ -1,11 +1,11 @@ import { beforeAll, it } from "bun:test"; import { exec } from "child_process"; import { rm } from "fs/promises"; -import { bunEnv, bunExe, isDockerEnabled, tempDirWithFiles } from "harness"; +import { bunEnv, bunExe, dockerExe, isDockerEnabled, tempDirWithFiles } from "harness"; import { join } from "path"; import { promisify } from "util"; const execAsync = promisify(exec); -const dockerCLI = Bun.which("docker") as string; +const dockerCLI = dockerExe() as string; const SQUID_URL = "http://127.0.0.1:3128"; if (isDockerEnabled()) { beforeAll(async () => { diff --git a/test/harness.ts b/test/harness.ts index c0d2ac7b8f..9b46505c57 100644 --- a/test/harness.ts +++ b/test/harness.ts @@ -860,6 +860,9 @@ export function dockerExe(): string | null { export function isDockerEnabled(): boolean { const dockerCLI = dockerExe(); if (!dockerCLI) { + if (isCI && isLinux) { + throw new Error("A functional `docker` is required in CI for some tests."); + } return false; } @@ -872,6 +875,9 @@ export function isDockerEnabled(): boolean { const info = execSync(`"${dockerCLI}" info`, { stdio: ["ignore", "pipe", "inherit"] }); return info.toString().indexOf("Server Version:") !== -1; } catch { + if (isCI && isLinux) { + throw new Error("A functional `docker` is required in CI for some tests."); + } return false; } } diff --git a/test/internal/ban-limits.json b/test/internal/ban-limits.json index 6980821328..bee82aa703 100644 --- a/test/internal/ban-limits.json +++ b/test/internal/ban-limits.json @@ -22,8 +22,8 @@ "allocator.ptr !=": 1, "allocator.ptr ==": 0, "global.hasException": 28, - "globalObject.hasException": 48, - "globalThis.hasException": 133, + "globalObject.hasException": 47, + "globalThis.hasException": 127, "std.StringArrayHashMap(": 1, "std.StringArrayHashMapUnmanaged(": 11, "std.StringHashMap(": 0, diff --git a/test/js/bun/s3/s3.test.ts b/test/js/bun/s3/s3.test.ts index 6de01baacc..8c8034785a 100644 --- a/test/js/bun/s3/s3.test.ts +++ b/test/js/bun/s3/s3.test.ts @@ -1,9 +1,9 @@ import type { S3Options } from "bun"; -import { S3Client, s3 as defaultS3, file, randomUUIDv7, which } from "bun"; +import { S3Client, s3 as defaultS3, file, randomUUIDv7 } from "bun"; import { describe, expect, it } from "bun:test"; import child_process from "child_process"; import { randomUUID } from "crypto"; -import { bunEnv, bunExe, getSecret, isCI, isDockerEnabled, tempDirWithFiles } from "harness"; +import { bunEnv, bunExe, dockerExe, getSecret, isCI, isDockerEnabled, tempDirWithFiles } from "harness"; import path 
from "path"; const s3 = (...args) => defaultS3.file(...args); const S3 = (...args) => new S3Client(...args); @@ -11,7 +11,7 @@ const S3 = (...args) => new S3Client(...args); // Import docker-compose helper import * as dockerCompose from "../../../docker/index.ts"; -const dockerCLI = which("docker") as string; +const dockerCLI = dockerExe() as string; type S3Credentials = S3Options & { service: string; }; diff --git a/test/js/bun/symbols.test.ts b/test/js/bun/symbols.test.ts index 621871b6d6..ecc4ed441d 100644 --- a/test/js/bun/symbols.test.ts +++ b/test/js/bun/symbols.test.ts @@ -33,7 +33,7 @@ if (process.platform === "linux") { throw new Error(`Found glibc symbols > 2.26. This breaks Amazon Linux 2 and Vercel. ${Bun.inspect.table(errors, { colors: true })} -To fix this, add it to -Wl,-wrap=symbol in the linker flags and update workaround-missing-symbols.cpp.`); +To fix this, add it to -Wl,--wrap=symbol in the linker flags and update workaround-missing-symbols.cpp.`); } }); diff --git a/test/js/sql/local-sql.test.ts b/test/js/sql/local-sql.test.ts index 527bb57e8b..b9af87ae07 100644 --- a/test/js/sql/local-sql.test.ts +++ b/test/js/sql/local-sql.test.ts @@ -1,6 +1,6 @@ import { SQL } from "bun"; import { afterAll, expect, test } from "bun:test"; -import { bunEnv, bunExe, isDockerEnabled, tempDirWithFiles } from "harness"; +import { bunEnv, bunExe, dockerExe, isDockerEnabled, tempDirWithFiles } from "harness"; import path from "path"; const postgres = (...args) => new SQL(...args); @@ -9,7 +9,7 @@ import net from "net"; import { promisify } from "util"; const execAsync = promisify(exec); -const dockerCLI = Bun.which("docker") as string; +const dockerCLI = dockerExe() as string; async function findRandomPort() { return new Promise((resolve, reject) => { @@ -199,7 +199,7 @@ if (isDockerEnabled()) { const searchs = await db\` WITH cte AS ( - SELECT + SELECT post.id, post."content", post.created_at AS "createdAt", @@ -214,10 +214,10 @@ if (isDockerEnabled()) { \${fragment} ORDER BY post.created_at DESC ) - SELECT - * + SELECT + * FROM cte - -- LIMIT 5 + -- LIMIT 5 \`; return Response.json(searchs); } catch { diff --git a/test/js/valkey/test-utils.ts b/test/js/valkey/test-utils.ts index 18d8e3c728..34ce3347cc 100644 --- a/test/js/valkey/test-utils.ts +++ b/test/js/valkey/test-utils.ts @@ -1,6 +1,6 @@ import { RedisClient, type SpawnOptions } from "bun"; import { afterAll, beforeAll, expect } from "bun:test"; -import { bunEnv, isCI, randomPort, tempDirWithFiles } from "harness"; +import { bunEnv, dockerExe, isCI, randomPort, tempDirWithFiles } from "harness"; import path from "path"; import * as dockerCompose from "../../docker/index.ts"; @@ -8,7 +8,7 @@ import { UnixDomainSocketProxy } from "../../unix-domain-socket-proxy.ts"; import * as fs from "node:fs"; import * as os from "node:os"; -const dockerCLI = Bun.which("docker") as string; +const dockerCLI = dockerExe() as string; export const isEnabled = !!dockerCLI && (() => { From 5a709a2dbf3894dd54689dc2c841e06e12152abb Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Thu, 25 Sep 2025 18:58:44 -0800 Subject: [PATCH 15/43] node:tty: use terminal VT mode on Windows (#21161) mirrors: https://github.com/nodejs/node/pull/58358 --- src/deps/libuv.zig | 1 + src/io/source.zig | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/deps/libuv.zig b/src/deps/libuv.zig index e3b1837f1b..5a59f777ce 100644 --- a/src/deps/libuv.zig +++ b/src/deps/libuv.zig @@ -1471,6 +1471,7 @@ pub const struct_uv_tty_s = extern struct { normal = 0, raw 
= 1, io = 2, + vt = 3, }; pub fn setMode(this: *uv_tty_t, mode: Mode) ReturnCode { diff --git a/src/io/source.zig index 647c5a8165..aeff0ca5bb 100644 --- a/src/io/source.zig +++ b/src/io/source.zig @@ -224,7 +224,11 @@ export fn Source__setRawModeStdin(raw: bool) c_int { .result => |tty| tty, .err => |e| return e.errno, }; - if (tty.setMode(if (raw) .raw else .normal).toError(.uv_tty_set_mode)) |err| { + // UV_TTY_MODE_RAW_VT is a variant of UV_TTY_MODE_RAW that enables control sequence processing on the TTY implementer side, + // rather than having libuv translate keypress events into control sequences, aligning behavior more closely with + // POSIX platforms. This is also required to support some control sequences at all on Windows, such as bracketed paste mode. + // The Node.js readline implementation handles differences between these modes. + if (tty.setMode(if (raw) .vt else .normal).toError(.uv_tty_set_mode)) |err| { return err.errno; } return 0; From 2039ab182d8fecc57db3bc75b1ca7307292ee8c4 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Thu, 25 Sep 2025 22:34:49 -0700 Subject: [PATCH 16/43] Remove stale path assertion on Windows (#22988) ### What does this PR do? This assertion is occasionally incorrect, and was originally added as a workaround for the lack of proper error handling in Zig's std library. We've since fixed that, so this assertion is no longer needed. ### How did you verify your code works? --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/string/immutable.zig | 1 - src/string/immutable/paths.zig | 19 ------------------- src/sys.zig | 7 ++----- src/sys_uv.zig | 18 ------------------ 4 files changed, 2 insertions(+), 43 deletions(-) diff --git a/src/string/immutable.zig b/src/string/immutable.zig index b478beb491..2f577f69b5 100644 --- a/src/string/immutable.zig +++ b/src/string/immutable.zig @@ -2313,7 +2313,6 @@ pub const addNTPathPrefix = paths_.addNTPathPrefix; pub const addNTPathPrefixIfNeeded = paths_.addNTPathPrefixIfNeeded; pub const addLongPathPrefix = paths_.addLongPathPrefix; pub const addLongPathPrefixIfNeeded = paths_.addLongPathPrefixIfNeeded; -pub const assertIsValidWindowsPath = paths_.assertIsValidWindowsPath; pub const charIsAnySlash = paths_.charIsAnySlash; pub const cloneNormalizingSeparators = paths_.cloneNormalizingSeparators; pub const fromWPath = paths_.fromWPath; diff --git a/src/string/immutable/paths.zig b/src/string/immutable/paths.zig index f7cf4045d1..079d6f5cf4 100644 --- a/src/string/immutable/paths.zig +++ b/src/string/immutable/paths.zig @@ -288,24 +288,6 @@ fn isUNCPath(comptime T: type, path: []const T) bool { !bun.path.Platform.windows.isSeparatorT(T, path[2]) and path[2] != '.'; } -pub fn assertIsValidWindowsPath(comptime T: type, path: []const T) void { - if (Environment.allow_assert and Environment.isWindows) { - if (bun.path.Platform.windows.isAbsoluteT(T, path) and - isWindowsAbsolutePathMissingDriveLetter(T, path) and - // is it a null device path? that's not an error. it's just a weird file path. - !eqlComptimeT(T, path, "\\\\.\\NUL") and !eqlComptimeT(T, path, "\\\\.\\nul") and !eqlComptimeT(T, path, "\\nul") and !eqlComptimeT(T, path, "\\NUL") and !isUNCPath(T, path)) - { - std.debug.panic("Internal Error: Do not pass posix paths to Windows APIs, was given '{s}'" ++ if (Environment.isDebug) " (missing a root like 'C:\\', see PosixToWinNormalizer for why this is an assertion)" else ". 
Please open an issue on GitHub with a reproduction.", .{ - if (T == u8) path else bun.fmt.utf16(path), - }); - } - if (hasPrefixComptimeType(T, path, ":/") and Environment.isDebug) { - std.debug.panic("Path passed to windows API '{s}' is almost certainly invalid. Where did the drive letter go?", .{ - if (T == u8) path else bun.fmt.utf16(path), - }); - } - } -} pub fn toWPathMaybeDir(wbuf: []u16, utf8: []const u8, comptime add_trailing_lash: bool) [:0]u16 { bun.unsafeAssert(wbuf.len > 0); @@ -518,7 +500,6 @@ const assert = bun.assert; const strings = bun.strings; const copyUTF16IntoUTF8 = strings.copyUTF16IntoUTF8; -const eqlComptimeT = strings.eqlComptimeT; const hasPrefixComptime = strings.hasPrefixComptime; const hasPrefixComptimeType = strings.hasPrefixComptimeType; const hasPrefixComptimeUTF16 = strings.hasPrefixComptimeUTF16; diff --git a/src/sys.zig b/src/sys.zig index a800b68d42..4bf87f89f5 100644 --- a/src/sys.zig +++ b/src/sys.zig @@ -660,7 +660,7 @@ pub fn mkdirA(file_path: []const u8, flags: mode_t) Maybe(void) { const wbuf = bun.w_path_buffer_pool.get(); defer bun.w_path_buffer_pool.put(wbuf); const wpath = bun.strings.toKernel32Path(wbuf, file_path); - assertIsValidWindowsPath(u16, wpath); + return Maybe(void).errnoSysP( kernel32.CreateDirectoryW(wpath.ptr, null), .mkdir, @@ -811,7 +811,7 @@ fn openDirAtWindowsNtPath( const no_follow = options.no_follow; const can_rename_or_delete = options.can_rename_or_delete; const read_only = options.read_only; - assertIsValidWindowsPath(u16, path); + const base_flags = w.STANDARD_RIGHTS_READ | w.FILE_READ_ATTRIBUTES | w.FILE_READ_EA | w.SYNCHRONIZE | w.FILE_TRAVERSE; const iterable_flag: u32 = if (iterable) w.FILE_LIST_DIRECTORY else 0; @@ -1010,7 +1010,6 @@ pub fn openFileAtWindowsNtPath( // this path is probably already backslash normalized so we're only going to check for '.\' // const path = if (bun.strings.hasPrefixComptimeUTF16(path_maybe_leading_dot, ".\\")) path_maybe_leading_dot[2..] 
else path_maybe_leading_dot; // bun.assert(!bun.strings.hasPrefixComptimeUTF16(path_maybe_leading_dot, "./")); - assertIsValidWindowsPath(u16, path); var result: windows.HANDLE = undefined; @@ -2624,7 +2623,6 @@ pub fn mmap( } pub fn mmapFile(path: [:0]const u8, flags: std.c.MAP, wanted_size: ?usize, offset: usize) Maybe([]align(page_size_min) u8) { - assertIsValidWindowsPath(u8, path); const fd = switch (open(path, bun.O.RDWR, 0)) { .result => |fd| fd, .err => |err| return .{ .err = err }, @@ -4156,7 +4154,6 @@ const MAX_PATH_BYTES = bun.MAX_PATH_BYTES; const c = bun.c; // translated c headers const jsc = bun.jsc; const libc_stat = bun.Stat; -const assertIsValidWindowsPath = bun.strings.assertIsValidWindowsPath; const darwin_nocancel = bun.darwin.nocancel; const windows = bun.windows; diff --git a/src/sys_uv.zig b/src/sys_uv.zig index 7ac44f005a..4e01d9bb5f 100644 --- a/src/sys_uv.zig +++ b/src/sys_uv.zig @@ -19,8 +19,6 @@ pub const access = bun.sys.access; // Note: `req = undefined; req.deinit()` has a safety-check in a debug build pub fn open(file_path: [:0]const u8, c_flags: i32, _perm: bun.Mode) Maybe(bun.FileDescriptor) { - assertIsValidWindowsPath(u8, file_path); - var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); @@ -41,7 +39,6 @@ pub fn open(file_path: [:0]const u8, c_flags: i32, _perm: bun.Mode) Maybe(bun.Fi } pub fn mkdir(file_path: [:0]const u8, flags: bun.Mode) Maybe(void) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_mkdir(uv.Loop.get(), &req, file_path.ptr, flags, null); @@ -54,7 +51,6 @@ pub fn mkdir(file_path: [:0]const u8, flags: bun.Mode) Maybe(void) { } pub fn chmod(file_path: [:0]const u8, flags: bun.Mode) Maybe(void) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); @@ -81,7 +77,6 @@ pub fn fchmod(fd: FileDescriptor, flags: bun.Mode) Maybe(void) { } pub fn statfs(file_path: [:0]const u8) Maybe(bun.StatFS) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_statfs(uv.Loop.get(), &req, file_path.ptr, null); @@ -94,7 +89,6 @@ pub fn statfs(file_path: [:0]const u8) Maybe(bun.StatFS) { } pub fn chown(file_path: [:0]const u8, uid: uv.uv_uid_t, gid: uv.uv_uid_t) Maybe(void) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_chown(uv.Loop.get(), &req, file_path.ptr, uid, gid, null); @@ -121,7 +115,6 @@ pub fn fchown(fd: FileDescriptor, uid: uv.uv_uid_t, gid: uv.uv_uid_t) Maybe(void } pub fn rmdir(file_path: [:0]const u8) Maybe(void) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_rmdir(uv.Loop.get(), &req, file_path.ptr, null); @@ -134,7 +127,6 @@ pub fn rmdir(file_path: [:0]const u8) Maybe(void) { } pub fn unlink(file_path: [:0]const u8) Maybe(void) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_unlink(uv.Loop.get(), &req, file_path.ptr, null); @@ -147,7 +139,6 @@ pub fn unlink(file_path: [:0]const u8) Maybe(void) { } pub fn readlink(file_path: [:0]const u8, buf: []u8) Maybe([:0]u8) { - assertIsValidWindowsPath(u8, file_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); // Edge cases: http://docs.libuv.org/en/v1.x/fs.html#c.uv_fs_realpath @@ -172,8 +163,6 @@ pub fn readlink(file_path: [:0]const u8, buf: 
[]u8) Maybe([:0]u8) { } pub fn rename(from: [:0]const u8, to: [:0]const u8) Maybe(void) { - assertIsValidWindowsPath(u8, from); - assertIsValidWindowsPath(u8, to); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_rename(uv.Loop.get(), &req, from.ptr, to.ptr, null); @@ -187,8 +176,6 @@ pub fn rename(from: [:0]const u8, to: [:0]const u8) Maybe(void) { } pub fn link(from: [:0]const u8, to: [:0]const u8) Maybe(void) { - assertIsValidWindowsPath(u8, from); - assertIsValidWindowsPath(u8, to); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_link(uv.Loop.get(), &req, from.ptr, to.ptr, null); @@ -201,8 +188,6 @@ pub fn link(from: [:0]const u8, to: [:0]const u8) Maybe(void) { } pub fn symlinkUV(target: [:0]const u8, new_path: [:0]const u8, flags: c_int) Maybe(void) { - assertIsValidWindowsPath(u8, target); - assertIsValidWindowsPath(u8, new_path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_symlink(uv.Loop.get(), &req, target.ptr, new_path.ptr, flags, null); @@ -267,7 +252,6 @@ pub fn fsync(fd: FileDescriptor) Maybe(void) { } pub fn stat(path: [:0]const u8) Maybe(bun.Stat) { - assertIsValidWindowsPath(u8, path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_stat(uv.Loop.get(), &req, path.ptr, null); @@ -280,7 +264,6 @@ pub fn stat(path: [:0]const u8) Maybe(bun.Stat) { } pub fn lstat(path: [:0]const u8) Maybe(bun.Stat) { - assertIsValidWindowsPath(u8, path); var req: uv.fs_t = uv.fs_t.uninitialized; defer req.deinit(); const rc = uv.uv_fs_lstat(uv.Loop.get(), &req, path.ptr, null); @@ -402,5 +385,4 @@ const bun = @import("bun"); const Environment = bun.Environment; const FileDescriptor = bun.FileDescriptor; const Maybe = bun.sys.Maybe; -const assertIsValidWindowsPath = bun.strings.assertIsValidWindowsPath; const uv = bun.windows.libuv; From 656747bcf118c4da54dfe8e9a716005e7318b896 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Thu, 25 Sep 2025 22:41:02 -0700 Subject: [PATCH 17/43] Fix vm destruction assertion failure in udp socket, reduce usage of protect() (#22986) ### What does this PR do? ### How did you verify your code works? 
--------- Co-authored-by: Claude Bot Co-authored-by: Claude Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/bun.js/api/bun/udp_socket.zig | 267 +++++++++++++++--------------- src/bun.js/api/sockets.classes.ts | 2 +- src/bun.js/jsc/array_buffer.zig | 4 +- test/internal/ban-limits.json | 2 +- 4 files changed, 138 insertions(+), 137 deletions(-) diff --git a/src/bun.js/api/bun/udp_socket.zig b/src/bun.js/api/bun/udp_socket.zig index 1d0a78a14a..0be967a24a 100644 --- a/src/bun.js/api/bun/udp_socket.zig +++ b/src/bun.js/api/bun/udp_socket.zig @@ -14,21 +14,23 @@ fn onClose(socket: *uws.udp.Socket) callconv(.C) void { const this: *UDPSocket = bun.cast(*UDPSocket, socket.user().?); this.closed = true; this.poll_ref.disable(); - _ = this.js_refcount.fetchSub(1, .monotonic); + this.this_value.downgrade(); + this.socket = null; } fn onDrain(socket: *uws.udp.Socket) callconv(.C) void { jsc.markBinding(@src()); const this: *UDPSocket = bun.cast(*UDPSocket, socket.user().?); - const callback = this.config.on_drain; - if (callback == .zero) return; + const thisValue = this.this_value.tryGet() orelse return; + const callback = UDPSocket.js.gc.on_drain.get(thisValue) orelse return; + if (callback.isEmptyOrUndefinedOrNull()) return; const vm = jsc.VirtualMachine.get(); const event_loop = vm.eventLoop(); event_loop.enter(); defer event_loop.exit(); - _ = callback.call(this.globalThis, this.thisValue, &.{this.thisValue}) catch |err| { + _ = callback.call(this.globalThis, thisValue, &.{thisValue}) catch |err| { this.callErrorHandler(.zero, this.globalThis.takeException(err)); }; } @@ -37,10 +39,12 @@ fn onData(socket: *uws.udp.Socket, buf: *uws.udp.PacketBuffer, packets: c_int) c jsc.markBinding(@src()); const udpSocket: *UDPSocket = bun.cast(*UDPSocket, socket.user().?); - const callback = udpSocket.config.on_data; - if (callback == .zero) return; + const thisValue = udpSocket.this_value.tryGet() orelse return; + const callback = UDPSocket.js.gc.on_data.get(thisValue) orelse return; + if (callback.isEmptyOrUndefinedOrNull()) return; const globalThis = udpSocket.globalThis; + defer thisValue.ensureStillAlive(); var i: c_int = 0; while (i < packets) : (i += 1) { @@ -73,12 +77,6 @@ fn onData(socket: *uws.udp.Socket, buf: *uws.udp.PacketBuffer, packets: c_int) c const slice = buf.getPayload(i); - const loop = udpSocket.vm.eventLoop(); - loop.enter(); - defer loop.exit(); - _ = udpSocket.js_refcount.fetchAdd(1, .monotonic); - defer _ = udpSocket.js_refcount.fetchSub(1, .monotonic); - const span = std.mem.span(hostname.?); var hostname_string = if (scope_id) |id| blk: { if (comptime !bun.Environment.isWindows) { @@ -91,13 +89,18 @@ fn onData(socket: *uws.udp.Socket, buf: *uws.udp.PacketBuffer, packets: c_int) c break :blk bun.handleOom(bun.String.createFormat("{s}%{d}", .{ span, id })); } else bun.String.init(span); - _ = callback.call(globalThis, udpSocket.thisValue, &.{ - udpSocket.thisValue, + const loop = udpSocket.vm.eventLoop(); + loop.enter(); + defer loop.exit(); + defer thisValue.ensureStillAlive(); + + _ = callback.call(globalThis, thisValue, &.{ + thisValue, udpSocket.config.binary_type.toJS(slice, globalThis) catch return, // TODO: properly propagate exception upwards .jsNumber(port), hostname_string.transferToJS(globalThis), }) catch |err| { - udpSocket.callErrorHandler(.zero, udpSocket.globalThis.takeException(err)); + udpSocket.callErrorHandler(.zero, globalThis.takeException(err)); }; } } @@ -112,37 +115,20 @@ pub const UDPSocketConfig = struct { const 
ConnectConfig = struct { port: u16, - address: [:0]u8, + address: bun.String, }; - hostname: [:0]u8, + hostname: bun.String = .empty, connect: ?ConnectConfig = null, - port: u16, - flags: i32, + port: u16 = 0, + flags: i32 = 0, binary_type: jsc.ArrayBuffer.BinaryType = .Buffer, - on_data: JSValue = .zero, - on_drain: JSValue = .zero, - on_error: JSValue = .zero, - pub fn fromJS(globalThis: *JSGlobalObject, options: JSValue) bun.JSError!This { + pub fn fromJS(globalThis: *JSGlobalObject, options: JSValue, thisValue: JSValue) bun.JSError!This { if (options.isEmptyOrUndefinedOrNull() or !options.isObject()) { return globalThis.throwInvalidArguments("Expected an object", .{}); } - const hostname = brk: { - if (try options.getTruthy(globalThis, "hostname")) |value| { - if (!value.isString()) { - return globalThis.throwInvalidArguments("Expected \"hostname\" to be a string", .{}); - } - const str = value.toBunString(globalThis) catch @panic("unreachable"); - defer str.deref(); - break :brk bun.handleOom(str.toOwnedSliceZ(default_allocator)); - } else { - break :brk bun.handleOom(default_allocator.dupeZ(u8, "0.0.0.0")); - } - }; - defer if (globalThis.hasException()) default_allocator.free(hostname); - const port: u16 = brk: { if (try options.getTruthy(globalThis, "port")) |value| { const number = try value.coerceToInt32(globalThis); @@ -160,12 +146,25 @@ pub const UDPSocketConfig = struct { else 0; + const hostname = brk: { + if (try options.getTruthy(globalThis, "hostname")) |value| { + if (!value.isString()) { + return globalThis.throwInvalidArguments("Expected \"hostname\" to be a string", .{}); + } + break :brk try value.toBunString(globalThis); + } else { + break :brk bun.String.static("0.0.0.0"); + } + }; + var config = This{ .hostname = hostname, .port = port, .flags = flags, }; + errdefer config.deinit(); + if (try options.getTruthy(globalThis, "socket")) |socket| { if (!socket.isObject()) { return globalThis.throwInvalidArguments("Expected \"socket\" to be an object", .{}); @@ -186,15 +185,8 @@ pub const UDPSocketConfig = struct { if (!value.isCell() or !value.isCallable()) { return globalThis.throwInvalidArguments("Expected \"socket.{s}\" to be a function", .{handler.@"0"}); } - @field(config, handler.@"1") = value.withAsyncContextIfNeeded(globalThis); - } - } - } - - defer { - if (globalThis.hasException()) { - if (config.connect) |connect| { - default_allocator.free(connect.address); + const callback = value.withAsyncContextIfNeeded(globalThis); + UDPSocket.js.gc.set(@field(UDPSocket.js.gc, handler.@"1"), thisValue, globalThis, callback); } } } @@ -217,9 +209,7 @@ pub const UDPSocketConfig = struct { }; const connect_port = try connect_port_js.coerceToInt32(globalThis); - const str = try connect_host_js.toBunString(globalThis); - defer str.deref(); - const connect_host = bun.handleOom(str.toOwnedSliceZ(default_allocator)); + const connect_host = try connect_host_js.toBunString(globalThis); config.connect = .{ .port = if (connect_port < 1 or connect_port > 0xffff) 0 else @as(u16, @intCast(connect_port)), @@ -227,28 +217,13 @@ pub const UDPSocketConfig = struct { }; } - config.protect(); - return config; } - pub fn protect(this: This) void { - inline for (handlers) |handler| { - @field(this, handler.@"1").protect(); - } - } - - pub fn unprotect(this: This) void { - inline for (handlers) |handler| { - @field(this, handler.@"1").unprotect(); - } - } - - pub fn deinit(this: This) void { - this.unprotect(); - default_allocator.free(this.hostname); - if (this.connect) |val| { - 
default_allocator.free(val.address); + pub fn deinit(this: *This) void { + this.hostname.deref(); + if (this.connect) |*val| { + val.address.deref(); } } }; @@ -258,11 +233,11 @@ pub const UDPSocket = struct { config: UDPSocketConfig, - socket: *uws.udp.Socket, + socket: ?*uws.udp.Socket = null, loop: *uws.Loop, globalThis: *JSGlobalObject, - thisValue: JSValue = .zero, + this_value: JSRef = JSRef.empty(), jsc_ref: jsc.Ref = jsc.Ref.init(), poll_ref: Async.KeepAlive = Async.KeepAlive.init(), @@ -270,7 +245,6 @@ pub const UDPSocket = struct { closed: bool = false, connect_info: ?ConnectInfo = null, vm: *jsc.VirtualMachine, - js_refcount: std.atomic.Value(usize) = std.atomic.Value(usize).init(1), const ConnectInfo = struct { port: u16, @@ -281,101 +255,113 @@ pub const UDPSocket = struct { pub const fromJS = js.fromJS; pub const fromJSDirect = js.fromJSDirect; - pub fn hasPendingActivity(this: *This) callconv(.C) bool { - return this.js_refcount.load(.monotonic) > 0; - } - pub const new = bun.TrivialNew(@This()); pub fn udpSocket(globalThis: *JSGlobalObject, options: JSValue) bun.JSError!JSValue { log("udpSocket", .{}); - const config = try UDPSocketConfig.fromJS(globalThis, options); - const vm = globalThis.bunVM(); var this = This.new(.{ - .socket = undefined, - .config = config, + .socket = null, + .config = .{}, .globalThis = globalThis, .loop = uws.Loop.get(), .vm = vm, }); + errdefer { + this.closed = true; + if (this.socket) |socket| { + this.socket = null; + socket.close(); + } + + // Do not deinit, rely on GC to free it. + } + const thisValue = this.toJS(globalThis); + thisValue.ensureStillAlive(); + this.this_value.setStrong(thisValue, globalThis); + + this.config = try UDPSocketConfig.fromJS(globalThis, options, thisValue); var err: i32 = 0; - if (uws.udp.Socket.create( + const hostname_slice = this.config.hostname.toUTF8(bun.default_allocator); + defer hostname_slice.deinit(); + const hostname_z = bun.handleOom(bun.default_allocator.dupeZ(u8, hostname_slice.slice())); + defer bun.default_allocator.free(hostname_z); + + this.socket = uws.udp.Socket.create( this.loop, onData, onDrain, onClose, - config.hostname, - config.port, - config.flags, + hostname_z, + this.config.port, + this.config.flags, &err, this, - )) |socket| { - this.socket = socket; - } else { + ) orelse { this.closed = true; - defer this.deinit(); if (err != 0) { const code = @tagName(bun.sys.SystemErrno.init(@as(c_int, @intCast(err))).?); const sys_err = jsc.SystemError{ .errno = err, .code = bun.String.static(code), - .message = bun.handleOom(bun.String.createFormat("bind {s} {s}", .{ code, config.hostname })), + .message = bun.handleOom(bun.String.createFormat("bind {s} {}", .{ code, this.config.hostname })), }; const error_value = sys_err.toErrorInstance(globalThis); - error_value.put(globalThis, "address", try bun.String.createUTF8ForJS(globalThis, config.hostname)); + error_value.put(globalThis, "address", this.config.hostname.toJS(globalThis)); + return globalThis.throwValue(error_value); } + return globalThis.throw("Failed to bind socket", .{}); - } + }; - errdefer { - this.socket.close(); - this.deinit(); - } - - if (config.connect) |connect| { - const ret = this.socket.connect(connect.address, connect.port); + if (this.config.connect) |*connect| { + const address_slice = connect.address.toUTF8(bun.default_allocator); + defer address_slice.deinit(); + const address_z = bun.handleOom(bun.default_allocator.dupeZ(u8, address_slice.slice())); + defer bun.default_allocator.free(address_z); + const ret = 
this.socket.?.connect(address_z, connect.port); if (ret != 0) { - if (bun.sys.Maybe(void).errnoSys(ret, .connect)) |sys_err| { - return globalThis.throwValue(try sys_err.toJS(globalThis)); + if (bun.sys.Maybe(void).errnoSys(ret, .connect)) |*sys_err| { + return globalThis.throwValue(sys_err.err.toJS(globalThis)); } if (bun.c_ares.Error.initEAI(ret)) |eai_err| { - return globalThis.throwValue(eai_err.toJSWithSyscallAndHostname(globalThis, "connect", connect.address)); + return globalThis.throwValue(eai_err.toJSWithSyscallAndHostname(globalThis, "connect", address_slice.slice())); } } this.connect_info = .{ .port = connect.port }; } this.poll_ref.ref(vm); - const thisValue = this.toJS(globalThis); - thisValue.ensureStillAlive(); - this.thisValue = thisValue; return jsc.JSPromise.resolvedPromiseValue(globalThis, thisValue); } pub fn callErrorHandler( this: *This, - thisValue: JSValue, + thisValue_: JSValue, err: JSValue, ) void { - const callback = this.config.on_error; + const thisValue = if (thisValue_ == .zero) this.this_value.tryGet() orelse return else thisValue_; + const callback = This.js.gc.on_error.get(thisValue) orelse .zero; const globalThis = this.globalThis; const vm = globalThis.bunVM(); if (err.isTerminationException()) { return; } - if (callback == .zero) { + if (callback.isEmptyOrUndefinedOrNull()) { _ = vm.uncaughtException(globalThis, err, false); return; } - _ = callback.call(globalThis, thisValue, &.{err}) catch |e| globalThis.reportActiveExceptionAsUnhandled(e); + const event_loop = vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + _ = callback.call(globalThis, thisValue, &.{err.toError() orelse err}) catch |e| globalThis.reportActiveExceptionAsUnhandled(e); } pub fn setBroadcast(this: *This, globalThis: *JSGlobalObject, callframe: *CallFrame) bun.JSError!JSValue { @@ -389,7 +375,7 @@ pub const UDPSocket = struct { } const enabled = arguments[0].toBoolean(); - const res = this.socket.setBroadcast(enabled); + const res = this.socket.?.setBroadcast(enabled); if (getUSError(res, .setsockopt, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); @@ -409,7 +395,7 @@ pub const UDPSocket = struct { } const enabled = arguments[0].toBoolean(); - const res = this.socket.setMulticastLoopback(enabled); + const res = this.socket.?.setMulticastLoopback(enabled); if (getUSError(res, .setsockopt, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); @@ -435,12 +421,14 @@ pub const UDPSocket = struct { var interface = std.mem.zeroes(std.posix.sockaddr.storage); + const socket = this.socket orelse return globalThis.throw("Socket is closed", .{}); + const res = if (arguments.len > 1 and try parseAddr(this, globalThis, .jsNumber(0), arguments[1], &interface)) blk: { if (addr.family != interface.family) { return globalThis.throwInvalidArguments("Family mismatch between address and interface", .{}); } - break :blk this.socket.setMembership(&addr, &interface, drop); - } else this.socket.setMembership(&addr, null, drop); + break :blk socket.setMembership(&addr, &interface, drop); + } else socket.setMembership(&addr, null, drop); if (getUSError(res, .setsockopt, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); @@ -483,12 +471,14 @@ pub const UDPSocket = struct { var interface: std.posix.sockaddr.storage = undefined; + const socket = this.socket orelse return globalThis.throw("Socket is closed", .{}); + const res = if (arguments.len > 2 and try parseAddr(this, globalThis, .jsNumber(0), arguments[2], &interface)) blk: { if 
(source_addr.family != interface.family) { return globalThis.throwInvalidArguments("Family mismatch among source, group and interface addresses", .{}); } - break :blk this.socket.setSourceSpecificMembership(&source_addr, &group_addr, &interface, drop); - } else this.socket.setSourceSpecificMembership(&source_addr, &group_addr, null, drop); + break :blk socket.setSourceSpecificMembership(&source_addr, &group_addr, &interface, drop); + } else socket.setSourceSpecificMembership(&source_addr, &group_addr, null, drop); if (getUSError(res, .setsockopt, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); @@ -521,7 +511,9 @@ pub const UDPSocket = struct { return .false; } - const res = this.socket.setMulticastInterface(&addr); + const socket = this.socket orelse return globalThis.throw("Socket is closed", .{}); + + const res = socket.setMulticastInterface(&addr); if (getUSError(res, .setsockopt, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); @@ -572,7 +564,7 @@ pub const UDPSocket = struct { } const ttl = try arguments[0].coerceToInt32(globalThis); - const res = function(this.socket, ttl); + const res = function(this.socket.?, ttl); if (getUSError(res, .setsockopt, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); @@ -651,7 +643,8 @@ pub const UDPSocket = struct { if (i != array_len) { return globalThis.throwInvalidArguments("Mismatch between array length property and number of items", .{}); } - const res = this.socket.send(payloads, lens, addr_ptrs); + const socket = this.socket orelse return globalThis.throw("Socket is closed", .{}); + const res = socket.send(payloads, lens, addr_ptrs); if (getUSError(res, .send, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); } @@ -709,7 +702,8 @@ pub const UDPSocket = struct { } }; - const res = this.socket.send(&.{payload.ptr}, &.{payload.len}, &.{addr_ptr}); + const socket = this.socket orelse return globalThis.throw("Socket is closed", .{}); + const res = socket.send(&.{payload.ptr}, &.{payload.len}, &.{addr_ptr}); if (getUSError(res, .send, true)) |err| { return globalThis.throwValue(try err.toJS(globalThis)); } @@ -796,7 +790,12 @@ pub const UDPSocket = struct { _: *JSGlobalObject, _: *CallFrame, ) bun.JSError!JSValue { - if (!this.closed) this.socket.close(); + if (!this.closed) { + const socket = this.socket orelse return .js_undefined; + this.socket = null; + socket.close(); + this.this_value.downgrade(); + } return .js_undefined; } @@ -809,12 +808,12 @@ pub const UDPSocket = struct { } const options = args.ptr[0]; - const config = try UDPSocketConfig.fromJS(globalThis, options); + const thisValue = this.this_value.tryGet() orelse return .js_undefined; + const config = try UDPSocketConfig.fromJS(globalThis, options, thisValue); - config.protect(); var previous_config = this.config; - previous_config.unprotect(); this.config = config; + previous_config.deinit(); return .js_undefined; } @@ -824,13 +823,12 @@ pub const UDPSocket = struct { } pub fn getHostname(this: *This, _: *JSGlobalObject) JSValue { - const hostname = jsc.ZigString.init(this.config.hostname); - return hostname.toJS(this.globalThis); + return this.config.hostname.toJS(this.globalThis); } pub fn getPort(this: *This, _: *JSGlobalObject) JSValue { if (this.closed) return .js_undefined; - return JSValue.jsNumber(this.socket.boundPort()); + return JSValue.jsNumber(this.socket.?.boundPort()); } fn createSockAddr(globalThis: *JSGlobalObject, address_bytes: []const u8, port: u16) JSValue { @@ -842,10 +840,10 
@@ pub const UDPSocket = struct { if (this.closed) return .js_undefined; var buf: [64]u8 = [_]u8{0} ** 64; var length: i32 = 64; - this.socket.boundIp(&buf, &length); + this.socket.?.boundIp(&buf, &length); const address_bytes = buf[0..@as(usize, @intCast(length))]; - const port = this.socket.boundPort(); + const port = this.socket.?.boundPort(); return createSockAddr(globalThis, address_bytes, @intCast(port)); } @@ -854,7 +852,7 @@ pub const UDPSocket = struct { const connect_info = this.connect_info orelse return .js_undefined; var buf: [64]u8 = [_]u8{0} ** 64; var length: i32 = 64; - this.socket.remoteIp(&buf, &length); + this.socket.?.remoteIp(&buf, &length); const address_bytes = buf[0..@as(usize, @intCast(length))]; return createSockAddr(globalThis, address_bytes, connect_info.port); @@ -874,15 +872,15 @@ pub const UDPSocket = struct { pub fn finalize(this: *This) void { log("Finalize {*}", .{this}); + this.this_value.finalize(); this.deinit(); } pub fn deinit(this: *This) void { - // finalize is only called when js_refcount reaches 0 - // js_refcount can only reach 0 when the socket is closed - bun.assert(this.closed); + bun.assert(this.closed or this.vm.isShuttingDown()); this.poll_ref.disable(); this.config.deinit(); + this.this_value.deinit(); bun.destroy(this); } @@ -919,15 +917,15 @@ pub const UDPSocket = struct { const connect_port = connect_port_js.asInt32(); const port: u16 = if (connect_port < 1 or connect_port > 0xffff) 0 else @as(u16, @intCast(connect_port)); - if (this.socket.connect(connect_host, port) == -1) { + if (this.socket.?.connect(connect_host, port) == -1) { return globalThis.throw("Failed to connect socket", .{}); } this.connect_info = .{ .port = port, }; - js.addressSetCached(callFrame.this(), globalThis, .zero); - js.remoteAddressSetCached(callFrame.this(), globalThis, .zero); + This.js.addressSetCached(callFrame.this(), globalThis, .zero); + This.js.remoteAddressSetCached(callFrame.this(), globalThis, .zero); return .js_undefined; } @@ -945,7 +943,7 @@ pub const UDPSocket = struct { return globalObject.throw("Socket is closed", .{}); } - if (this.socket.disconnect() == -1) { + if (this.socket.?.disconnect() == -1) { return globalObject.throw("Failed to disconnect socket", .{}); } this.connect_info = null; @@ -965,5 +963,6 @@ const default_allocator = bun.default_allocator; const jsc = bun.jsc; const CallFrame = jsc.CallFrame; const JSGlobalObject = jsc.JSGlobalObject; +const JSRef = jsc.JSRef; const JSValue = jsc.JSValue; const SocketAddress = jsc.API.SocketAddress; diff --git a/src/bun.js/api/sockets.classes.ts b/src/bun.js/api/sockets.classes.ts index 833651452e..13d89453ff 100644 --- a/src/bun.js/api/sockets.classes.ts +++ b/src/bun.js/api/sockets.classes.ts @@ -302,7 +302,7 @@ export default [ JSType: "0b11101110", finalize: true, construct: true, - hasPendingActivity: true, + values: ["on_data", "on_drain", "on_error"], proto: { send: { fn: "send", diff --git a/src/bun.js/jsc/array_buffer.zig b/src/bun.js/jsc/array_buffer.zig index 87bd1f59ec..9751b476b6 100644 --- a/src/bun.js/jsc/array_buffer.zig +++ b/src/bun.js/jsc/array_buffer.zig @@ -435,7 +435,9 @@ pub const ArrayBuffer = extern struct { pub fn fromJSValue(globalThis: *jsc.JSGlobalObject, input: jsc.JSValue) bun.JSError!?BinaryType { if (input.isString()) { - return Map.getWithEql(try input.toBunString(globalThis), bun.String.eqlComptime); + const str = try input.toBunString(globalThis); + defer str.deref(); + return Map.getWithEql(str, bun.String.eqlComptime); } return null; diff --git 
a/test/internal/ban-limits.json b/test/internal/ban-limits.json index bee82aa703..39f2c58676 100644 --- a/test/internal/ban-limits.json +++ b/test/internal/ban-limits.json @@ -23,7 +23,7 @@ "allocator.ptr ==": 0, "global.hasException": 28, "globalObject.hasException": 47, - "globalThis.hasException": 127, + "globalThis.hasException": 125, "std.StringArrayHashMap(": 1, "std.StringArrayHashMapUnmanaged(": 11, "std.StringHashMap(": 0, From c4519c75521220ade3d125ac328f1fb9dbdfe9b2 Mon Sep 17 00:00:00 2001 From: pfg Date: Thu, 25 Sep 2025 23:47:46 -0700 Subject: [PATCH 18/43] Add --randomize --seed flag (#22987) Outputs the seed when randomizing. Adds --seed flag to reproduce a random order. Seeds might not produce the same order across operating systems / bun versions. Fixes #11847 --------- Co-authored-by: Claude Bot Co-authored-by: Claude Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- docs/cli/test.md | 30 +++ src/bun.js/test/Order.zig | 8 +- src/bun.js/test/bun_test.zig | 7 +- src/bun.js/test/jest.zig | 2 +- src/cli.zig | 1 + src/cli/Arguments.zig | 9 + src/cli/test_command.zig | 17 +- test/cli/test/test-randomize.fixture.ts | 3 + test/cli/test/test-randomize.test.ts | 323 ++++++------------------ 9 files changed, 149 insertions(+), 251 deletions(-) create mode 100644 test/cli/test/test-randomize.fixture.ts diff --git a/docs/cli/test.md b/docs/cli/test.md index bd636bcc39..e3729b8b62 100644 --- a/docs/cli/test.md +++ b/docs/cli/test.md @@ -117,6 +117,36 @@ Use the `--rerun-each` flag to run each test multiple times. This is useful for $ bun test --rerun-each 100 ``` +## Randomize test execution order + +Use the `--randomize` flag to run tests in a random order. This helps detect tests that depend on shared state or execution order. + +```sh +$ bun test --randomize +``` + +When using `--randomize`, the seed used for randomization will be displayed in the test summary: + +```sh +$ bun test --randomize +# ... test output ... + --seed=12345 + 2 pass + 8 fail +Ran 10 tests across 2 files. [50.00ms] +``` + +### Reproducible random order with `--seed` + +Use the `--seed` flag to specify a seed for the randomization. This allows you to reproduce the same test order when debugging order-dependent failures. + +```sh +# Reproduce a previous randomized run +$ bun test --seed 123456 +``` + +The `--seed` flag implies `--randomize`, so you don't need to specify both. Using the same seed value will always produce the same test execution order, making it easier to debug intermittent failures caused by test interdependencies. + ## Bail out with `--bail` Use the `--bail` flag to abort the test run early after a pre-determined number of test failures. By default Bun will run all tests and report all failures, but sometimes in CI environments it's preferable to terminate earlier to reduce CPU usage. 
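The reproducibility that `--seed` promises falls out of the shuffle being driven by a deterministic PRNG: seeding `std.Random.DefaultPrng` (the generator the patch below uses) with the same value replays the same random sequence, and therefore the same order. Here is a minimal standalone Zig sketch of that idea — not Bun's actual scheduler; the test names and the seed value are made up for illustration:

```zig
const std = @import("std");

pub fn main() void {
    // e.g. the value a previous run printed as `--seed=123456`
    const seed: u32 = 123456;

    // DefaultPrng is deterministic: the same seed always yields the
    // same random sequence, and therefore the same shuffle.
    var prng = std.Random.DefaultPrng.init(seed);
    const random = prng.random();

    var tests = [_][]const u8{ "alpha", "beta", "gamma", "delta" };
    random.shuffle([]const u8, &tests);

    for (tests) |name| {
        std.debug.print("{s}\n", .{name});
    }
}
```

As the commit message notes, this only reproduces the order for the same PRNG implementation and shuffle logic, which is why the same seed is not guaranteed to produce the same order across operating systems or Bun versions.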
diff --git a/src/bun.js/test/Order.zig b/src/bun.js/test/Order.zig index 3adae8f64b..9f220fafc8 100644 --- a/src/bun.js/test/Order.zig +++ b/src/bun.js/test/Order.zig @@ -37,8 +37,8 @@ pub const AllOrderResult = struct { } }; pub const Config = struct { - always_use_hooks: bool = false, - randomize: bool = false, + always_use_hooks: bool, + randomize: ?std.Random, }; pub fn generateAllOrder(this: *Order, entries: []const *ExecutionEntry, _: Config) bun.JSError!AllOrderResult { const start = this.groups.items.len; @@ -63,9 +63,7 @@ pub fn generateOrderDescribe(this: *Order, current: *DescribeScope, cfg: Config) const beforeall_order: AllOrderResult = if (use_hooks) try generateAllOrder(this, current.beforeAll.items, cfg) else .empty; // shuffle entries if randomize flag is set - if (cfg.randomize) { - var prng = std.Random.DefaultPrng.init(bun.fastRandom()); - const random = prng.random(); + if (cfg.randomize) |random| { random.shuffle(TestScheduleEntry, current.entries.items); } diff --git a/src/bun.js/test/bun_test.zig b/src/bun.js/test/bun_test.zig index 1150c8c916..f25c65fb06 100644 --- a/src/bun.js/test/bun_test.zig +++ b/src/bun.js/test/bun_test.zig @@ -514,8 +514,11 @@ pub const BunTest = struct { defer order.deinit(); const has_filter = if (this.reporter) |reporter| if (reporter.jest.filter_regex) |_| true else false else false; - const should_randomize = if (this.reporter) |reporter| reporter.jest.randomize else false; - const cfg: Order.Config = .{ .always_use_hooks = this.collection.root_scope.base.only == .no and !has_filter, .randomize = should_randomize }; + const should_randomize: ?std.Random = if (this.reporter) |reporter| reporter.jest.randomize else null; + const cfg: Order.Config = .{ + .always_use_hooks = this.collection.root_scope.base.only == .no and !has_filter, + .randomize = should_randomize, + }; const beforeall_order: Order.AllOrderResult = if (cfg.always_use_hooks or this.collection.root_scope.base.has_callback) try order.generateAllOrder(this.buntest.hook_scope.beforeAll.items, cfg) else .empty; try order.generateOrderDescribe(this.collection.root_scope, cfg); beforeall_order.setFailureSkipTo(&order); diff --git a/src/bun.js/test/jest.zig b/src/bun.js/test/jest.zig index 013bbcb5f9..8ec2dbdfd4 100644 --- a/src/bun.js/test/jest.zig +++ b/src/bun.js/test/jest.zig @@ -55,7 +55,7 @@ pub const TestRunner = struct { only: bool = false, run_todo: bool = false, concurrent: bool = false, - randomize: bool = false, + randomize: ?std.Random = null, concurrent_test_glob: ?[]const []const u8 = null, last_file: u64 = 0, bail: u32 = 0, diff --git a/src/cli.zig b/src/cli.zig index f192591414..4461e8f92c 100644 --- a/src/cli.zig +++ b/src/cli.zig @@ -340,6 +340,7 @@ pub const Command = struct { only: bool = false, concurrent: bool = false, randomize: bool = false, + seed: ?u32 = null, concurrent_test_glob: ?[]const []const u8 = null, bail: u32 = 0, coverage: TestCommand.CodeCoverageOptions = .{}, diff --git a/src/cli/Arguments.zig b/src/cli/Arguments.zig index 1828060979..324c8165e2 100644 --- a/src/cli/Arguments.zig +++ b/src/cli/Arguments.zig @@ -198,6 +198,7 @@ pub const test_only_params = [_]ParamType{ clap.parseParam("--todo Include tests that are marked with \"test.todo()\"") catch unreachable, clap.parseParam("--concurrent Treat all tests as `test.concurrent()` tests") catch unreachable, clap.parseParam("--randomize Run tests in random order") catch unreachable, + clap.parseParam("--seed Set the random seed for test randomization") catch unreachable, 
clap.parseParam("--coverage Generate a coverage profile") catch unreachable, clap.parseParam("--coverage-reporter ... Report coverage in 'text' and/or 'lcov'. Defaults to 'text'.") catch unreachable, clap.parseParam("--coverage-dir Directory for coverage files. Defaults to 'coverage'.") catch unreachable, @@ -497,6 +498,14 @@ pub fn parse(allocator: std.mem.Allocator, ctx: Command.Context, comptime cmd: C ctx.test_options.run_todo = args.flag("--todo"); ctx.test_options.concurrent = args.flag("--concurrent"); ctx.test_options.randomize = args.flag("--randomize"); + + if (args.option("--seed")) |seed_str| { + ctx.test_options.randomize = true; + ctx.test_options.seed = std.fmt.parseInt(u32, seed_str, 10) catch { + Output.prettyErrorln("error: Invalid seed value: {s}", .{seed_str}); + std.process.exit(1); + }; + } } ctx.args.absolute_working_dir = cwd; diff --git a/src/cli/test_command.zig b/src/cli/test_command.zig index 924c415445..e155123ab8 100644 --- a/src/cli/test_command.zig +++ b/src/cli/test_command.zig @@ -1280,6 +1280,11 @@ pub const TestCommand = struct { bun.jsc.initialize(false); HTTPThread.init(&.{}); + const enable_random = ctx.test_options.randomize; + const seed: u32 = if (enable_random) ctx.test_options.seed orelse @truncate(bun.fastRandom()) else 0; // seed is limited to u32 so storing it in js doesn't lose precision + var random_instance: ?std.Random.DefaultPrng = if (enable_random) std.Random.DefaultPrng.init(seed) else null; + const random = if (random_instance) |*instance| instance.random() else null; + var snapshot_file_buf = std.ArrayList(u8).init(ctx.allocator); var snapshot_values = Snapshots.ValuesHashMap.init(ctx.allocator); var snapshot_counts = bun.StringHashMap(usize).init(ctx.allocator); @@ -1301,7 +1306,7 @@ pub const TestCommand = struct { .allocator = ctx.allocator, .default_timeout_ms = ctx.test_options.default_timeout_ms, .concurrent = ctx.test_options.concurrent, - .randomize = ctx.test_options.randomize, + .randomize = random, .concurrent_test_glob = ctx.test_options.concurrent_test_glob, .run_todo = ctx.test_options.run_todo, .only = ctx.test_options.only, @@ -1475,6 +1480,11 @@ pub const TestCommand = struct { const search_count = scanner.search_count; if (test_files.len > 0) { + // Randomize the order of test files if --randomize flag is set + if (random) |rand| { + rand.shuffle(PathString, test_files); + } + vm.hot_reload = ctx.debug.hot_reload; switch (vm.hot_reload) { @@ -1607,6 +1617,11 @@ pub const TestCommand = struct { const did_label_filter_out_all_tests = summary.didLabelFilterOutAllTests() and reporter.jest.unhandled_errors_between_tests == 0; if (!did_label_filter_out_all_tests) { + // Display the random seed if tests were randomized + if (random != null) { + Output.prettyError(" --seed={d}\n", .{seed}); + } + if (summary.pass > 0) { Output.prettyError("", .{}); } diff --git a/test/cli/test/test-randomize.fixture.ts b/test/cli/test/test-randomize.fixture.ts new file mode 100644 index 0000000000..a7d9f18359 --- /dev/null +++ b/test/cli/test/test-randomize.fixture.ts @@ -0,0 +1,3 @@ +test.each(Array.from({ length: 100 }, (_, i) => i + 1))("many %d", item => { + console.log(item); +}); diff --git a/test/cli/test/test-randomize.test.ts b/test/cli/test/test-randomize.test.ts index b53dfc2804..d5754ba0d7 100644 --- a/test/cli/test/test-randomize.test.ts +++ b/test/cli/test/test-randomize.test.ts @@ -1,256 +1,95 @@ import { expect, test } from "bun:test"; -import { bunEnv, bunExe, tempDir } from "harness"; -import { join } from "path"; +import 
{ bunEnv, bunExe, tempDirWithFiles } from "harness"; -test("--randomize flag randomizes test execution order", async () => { - // Create a test file with multiple tests that output their names - using dir = tempDir("test-randomize", {}); - const testFile = join(String(dir), "order.test.js"); +// test: +// --randomize randomizes +// output produces a seed which produces the same result +// --seed produces the same result twice - await Bun.write( - testFile, - ` - import { test } from "bun:test"; - - test("test-01", () => { - console.log("test-01"); - }); - - test("test-02", () => { - console.log("test-02"); - }); - - test("test-03", () => { - console.log("test-03"); - }); - - test("test-04", () => { - console.log("test-04"); - }); - - test("test-05", () => { - console.log("test-05"); - }); - - test("test-06", () => { - console.log("test-06"); - }); - - test("test-07", () => { - console.log("test-07"); - }); - - test("test-08", () => { - console.log("test-08"); - }); - `, - ); - - // Run without --randomize to get the default order - await using defaultProc = Bun.spawn({ - cmd: [bunExe(), "test", testFile], +const unsortedOrder = Array.from({ length: 100 }, (_, i) => i + 1); +async function runFixture(flags: string[]): Promise<{ order: number[]; seed: number | null }> { + const proc = await Bun.spawn([bunExe(), "test", ...flags], { env: bunEnv, - stdout: "pipe", - stderr: "pipe", - cwd: String(dir), + stdio: ["pipe", "pipe", "pipe"], }); + const exitCode = await proc.exited; + const stdout = await proc.stdout.text(); + const stderr = await proc.stderr.text(); + expect(exitCode).toBe(0); + const stdoutOrder = stdout + .split("\n") + .map(l => l.trim()) + .filter(l => l && !isNaN(+l)) + .map(l => +l); + const seed = stderr.includes("--seed") ? +(stderr.match(/--seed=(-?\d+)/)?.[1] + "") : null; + return { order: stdoutOrder, seed: seed }; +} - const [defaultOut, defaultErr, defaultExit] = await Promise.all([ - defaultProc.stdout.text(), - defaultProc.stderr.text(), - defaultProc.exited, +const sortNumbers = (a: number, b: number) => a - b; +test("--randomize and --seed work", async () => { + const fixture = import.meta.dir + "/test-randomize.fixture.ts"; + + // with --randomize + const { order: randomizedOrder, seed: randomizedSeed } = await runFixture([fixture, "--randomize"]); + expect(randomizedSeed).toBeFinite(); + expect(randomizedOrder.toSorted(sortNumbers)).toEqual(unsortedOrder); + expect(randomizedOrder).not.toEqual(unsortedOrder); + + // different randomized run is different + const { order: differentRandomizedOrder, seed: differentRandomizedSeed } = await runFixture([fixture, "--randomize"]); + expect(differentRandomizedOrder.toSorted(sortNumbers)).toEqual(unsortedOrder); + expect(differentRandomizedOrder).not.toEqual(unsortedOrder); + expect(differentRandomizedOrder).not.toEqual(randomizedOrder); + expect(differentRandomizedSeed).not.toEqual(randomizedSeed); + + // with same seed as first run + const { order: seededOrder, seed: seededSeed } = await runFixture([fixture, "--seed", "" + randomizedSeed]); + expect(seededOrder).toEqual(randomizedOrder); + expect(seededSeed).toEqual(randomizedSeed); + + // with both randomize and seed parameter + const { order: randomizedAndSeededOrder, seed: randomizedAndSeededSeed } = await runFixture([ + fixture, + "--randomize", + "--seed", + "" + randomizedSeed, ]); + expect(randomizedAndSeededOrder).toEqual(randomizedOrder); + expect(randomizedAndSeededSeed).toEqual(randomizedSeed); - expect(defaultExit).toBe(0); + // without seed + const { order: 
unseededOrder, seed: unseededSeed } = await runFixture([fixture]); + expect(unseededOrder).toEqual(unsortedOrder); + expect(unseededSeed).toBeNull(); +}); - // Extract test execution order from output - const defaultTests = defaultOut.match(/test-\d+/g) || []; - expect(defaultTests.length).toBe(8); - - // Run multiple times WITH --randomize to find a different order - let foundDifferentOrder = false; - const maxAttempts = 20; // Increase attempts since randomization might occasionally match - - for (let i = 0; i < maxAttempts; i++) { - await using randomProc = Bun.spawn({ - cmd: [bunExe(), "test", testFile, "--randomize"], - env: bunEnv, - stdout: "pipe", - stderr: "pipe", - cwd: String(dir), - }); - - const [randomOut, randomErr, randomExit] = await Promise.all([ - randomProc.stdout.text(), - randomProc.stderr.text(), - randomProc.exited, - ]); - - expect(randomExit).toBe(0); - - const randomTests = randomOut.match(/test-\d+/g) || []; - expect(randomTests.length).toBe(8); - - // Check if all tests ran (just different order) - const sortedRandom = [...randomTests].sort(); - const sortedDefault = [...defaultTests].sort(); - expect(sortedRandom).toEqual(sortedDefault); - - // Check if order is different - const orderIsDifferent = randomTests.some((test, index) => test !== defaultTests[index]); - if (orderIsDifferent) { - foundDifferentOrder = true; - break; - } - } - - // With 8 tests and 20 attempts, the probability of not finding a different order - // by pure chance is (1/8!)^20 which is astronomically small - expect(foundDifferentOrder).toBe(true); -}, 30000); // 30 second timeout for this test - -test("--randomize flag works with describe blocks", async () => { - using dir = tempDir("test-randomize-describe", {}); - const testFile = join(String(dir), "describe.test.js"); - - await Bun.write( - testFile, - ` - import { test, describe } from "bun:test"; - - describe("Suite-A", () => { - test("A1", () => { - console.log("A1"); - }); - - test("A2", () => { - console.log("A2"); - }); - - test("A3", () => { - console.log("A3"); - }); - }); - - describe("Suite-B", () => { - test("B1", () => { - console.log("B1"); - }); - - test("B2", () => { - console.log("B2"); - }); - }); - - describe("Suite-C", () => { - test("C1", () => { - console.log("C1"); - }); - - test("C2", () => { - console.log("C2"); - }); - }); - `, +test("randomizes order of files", async () => { + const dir = tempDirWithFiles( + "randomize-order-of-files", + Object.fromEntries( + Array.from({ length: 20 }, (_, i) => [ + `test${i + 1}.test.ts`, + `test("test ${i + 1}", () => { console.log(${i + 1}); });`, + ]), + ), ); - // Run without --randomize - await using defaultProc = Bun.spawn({ - cmd: [bunExe(), "test", testFile], - env: bunEnv, - stdout: "pipe", - stderr: "pipe", - cwd: String(dir), - }); + const { order: unrandomizedOrder, seed: unrandomizedSeed } = await runFixture([dir]); + const { order: anotherUnrandomizedOrder, seed: anotherUnrandomizedSeed } = await runFixture([dir]); + expect(unrandomizedSeed).toBeNull(); + expect(anotherUnrandomizedSeed).toBeNull(); + expect(anotherUnrandomizedOrder).toEqual(unrandomizedOrder); - const [defaultOut, defaultErr, defaultExit] = await Promise.all([ - defaultProc.stdout.text(), - defaultProc.stderr.text(), - defaultProc.exited, - ]); + const { order: randomizedOrder, seed: randomizedSeed } = await runFixture([dir, "--randomize"]); + expect(randomizedSeed).toBeFinite(); + expect(unrandomizedOrder).not.toEqual(randomizedOrder); - expect(defaultExit).toBe(0); + const { order: 
anotherRandomizedOrder, seed: anotherRandomizedSeed } = await runFixture([dir, "--randomize"]); + expect(anotherRandomizedOrder).not.toEqual(randomizedOrder); + expect(anotherRandomizedSeed).not.toEqual(randomizedSeed); - const defaultTests = defaultOut.match(/[ABC]\d/g) || []; - expect(defaultTests.length).toBe(7); - - // Run with --randomize multiple times - let foundDifferentOrder = false; - - for (let i = 0; i < 20; i++) { - await using randomProc = Bun.spawn({ - cmd: [bunExe(), "test", testFile, "--randomize"], - env: bunEnv, - stdout: "pipe", - stderr: "pipe", - cwd: String(dir), - }); - - const [randomOut, randomErr, randomExit] = await Promise.all([ - randomProc.stdout.text(), - randomProc.stderr.text(), - randomProc.exited, - ]); - - expect(randomExit).toBe(0); - - const randomTests = randomOut.match(/[ABC]\d/g) || []; - expect(randomTests.length).toBe(7); - - // Verify all tests ran - expect([...randomTests].sort()).toEqual([...defaultTests].sort()); - - // Check if order is different - const orderIsDifferent = randomTests.some((test, index) => test !== defaultTests[index]); - if (orderIsDifferent) { - foundDifferentOrder = true; - break; - } - } - - expect(foundDifferentOrder).toBe(true); -}, 30000); - -test("without --randomize flag tests run in consistent order", async () => { - using dir = tempDir("test-consistent", {}); - const testFile = join(String(dir), "consistent.test.js"); - - await Bun.write( - testFile, - ` - import { test } from "bun:test"; - - test("test-1", () => { console.log("1"); }); - test("test-2", () => { console.log("2"); }); - test("test-3", () => { console.log("3"); }); - test("test-4", () => { console.log("4"); }); - test("test-5", () => { console.log("5"); }); - `, - ); - - const runs = []; - - // Run 5 times without --randomize - for (let i = 0; i < 5; i++) { - await using proc = Bun.spawn({ - cmd: [bunExe(), "test", testFile], - env: bunEnv, - stdout: "pipe", - stderr: "pipe", - cwd: String(dir), - }); - - const [out, err, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); - - expect(exitCode).toBe(0); - - const order = out.match(/\d/g) || []; - runs.push(order.join("")); - } - - // All runs should have the same order - const firstRun = runs[0]; - for (const run of runs) { - expect(run).toBe(firstRun); - } -}, 20000); + // test with --seed + const { order: seededOrder, seed: seededSeed } = await runFixture([dir, "--seed", "" + randomizedSeed]); + expect(seededOrder).toEqual(randomizedOrder); + expect(seededSeed).toEqual(randomizedSeed); +}); From 5457d76bcb9c59dd63633aea09a70689927dcab3 Mon Sep 17 00:00:00 2001 From: robobun Date: Thu, 25 Sep 2025 23:52:56 -0700 Subject: [PATCH 19/43] Fix double-free in createArgv function (#22978) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fixed a double-free bug in the `createArgv` function in `node_process.zig` ## Details The `createArgv` function had two `defer allocator.free(args)` statements: - One on line 164 - Another on line 192 (now removed) This would cause the same memory to be freed twice when the function returned, leading to undefined behavior. Fixes #22975 ## Test plan The existing process.argv tests should continue to pass with this fix. 
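For illustration, the bug shape reduces to a duplicated `defer` in a single scope; a minimal sketch (hypothetical function, not the actual `createArgv` body):

```zig
const std = @import("std");

// Both defers run at scope exit, so the same slice is freed twice,
// which is undefined behavior.
fn buggy(allocator: std.mem.Allocator) !void {
    const args = try allocator.alloc(u8, 16);
    defer allocator.free(args); // intended cleanup
    // ... populate and use args ...
    defer allocator.free(args); // duplicate free: the bug
}
```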
🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude Bot
Co-authored-by: Claude
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Dylan Conway
Co-authored-by: Jarred Sumner
---
 src/bun.js/node/node_process.zig              |   2 -
 .../22978-createargv-double-free.test.ts      | 100 ++++++++++++++++++
 2 files changed, 100 insertions(+), 2 deletions(-)
 create mode 100644 test/regression/issue/22978-createargv-double-free.test.ts

diff --git a/src/bun.js/node/node_process.zig b/src/bun.js/node/node_process.zig
index 07966abedd..b78d0d04da 100644
--- a/src/bun.js/node/node_process.zig
+++ b/src/bun.js/node/node_process.zig
@@ -189,8 +189,6 @@ fn createArgv(globalObject: *jsc.JSGlobalObject) callconv(.C) jsc.JSValue {
         }
     }
 
-    defer allocator.free(args);
-
     if (vm.worker) |worker| {
         for (worker.argv) |arg| {
             args_list.appendAssumeCapacity(bun.String.init(arg));
diff --git a/test/regression/issue/22978-createargv-double-free.test.ts b/test/regression/issue/22978-createargv-double-free.test.ts
new file mode 100644
index 0000000000..579874c154
--- /dev/null
+++ b/test/regression/issue/22978-createargv-double-free.test.ts
@@ -0,0 +1,100 @@
+import { expect, test } from "bun:test";
+import { bunEnv, bunExe, tempDir } from "harness";
+
+test("process.argv with many arguments doesn't double-free", async () => {
+  // The stack fallback buffer in createArgv is 32 * @sizeOf(jsc.ZigString)
+  // We need more than 32 arguments to trigger heap allocation
+  // Using 129 arguments comfortably exceeds the stack buffer
+  const manyArgs = Array.from({ length: 129 }, (_, i) => `arg${i}`);
+
+  using dir = tempDir("argv-test", {
+    "check-argv.js": `
+      // Just access process.argv to trigger the createArgv function
+      const argv = process.argv;
+      console.log(JSON.stringify({
+        length: argv.length,
+        // Check that all arguments are present and valid
+        hasAllArgs: argv.slice(2).every((arg, i) => arg === \`arg\${i}\`),
+        // The first two should be the executable and script path
+        hasExe: argv[0].includes("bun"),
+        hasScript: argv[1].endsWith("check-argv.js")
+      }));
+    `,
+  });
+
+  await using proc = Bun.spawn({
+    cmd: [bunExe(), "check-argv.js", ...manyArgs],
+    env: bunEnv,
+    cwd: String(dir),
+    stdout: "pipe",
+    stderr: "pipe",
+  });
+
+  const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]);
+
+  // If there was a double-free, ASAN would catch it and the process would crash
+  expect(exitCode).toBe(0);
+  expect(stderr).toBe("");
+
+  const result = JSON.parse(stdout.trim());
+  expect(result.length).toBe(131); // exe + script + 129 args
+  expect(result.hasAllArgs).toBe(true);
+  expect(result.hasExe).toBe(true);
+  expect(result.hasScript).toBe(true);
+});
+
+test.todo("process.argv with many arguments in worker", async () => {
+  // Test the worker code path as well
+  const manyArgs = Array.from({ length: 129 }, (_, i) => `worker-arg${i}`);
+
+  using dir = tempDir("argv-worker-test", {
+    "worker.js": `
+      const { parentPort, workerData } = require("worker_threads");
+      const argv = process.argv;
+      parentPort.postMessage({
+        length: argv.length,
+        hasAllArgs: workerData.every((arg, i) => argv[i + 2] === arg),
+        hasExe: argv[0].includes("bun"),
+        hasScript: argv[1] === "[worker eval]" || argv[1].endsWith("worker.js")
+      });
+    `,
+    "main.js": `
+      const { Worker } = require("worker_threads");
+      const args = ${JSON.stringify(manyArgs)};
+
+      const worker = new Worker("./worker.js", {
+        workerData: args,
+        argv: args
+      });
+ worker.on("message", (msg) => { + console.log(JSON.stringify(msg)); + process.exit(0); + }); + + worker.on("error", (err) => { + console.error("Worker error:", err); + process.exit(1); + }); + `, + }); + + await using proc = Bun.spawn({ + cmd: [bunExe(), "main.js"], + env: bunEnv, + cwd: String(dir), + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([proc.stdout.text(), proc.stderr.text(), proc.exited]); + + expect(exitCode).toBe(0); + expect(stderr).toBe(""); + + const result = JSON.parse(stdout.trim()); + expect(result.length).toBe(131); // exe + script + 129 args + expect(result.hasAllArgs).toBe(true); + expect(result.hasExe).toBe(true); + expect(result.hasScript).toBe(true); +}); From c63fa996d12e999125f16d045d953fa67f49c36f Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Fri, 26 Sep 2025 00:17:58 -0700 Subject: [PATCH 20/43] package.json: add amazonlinux machine script --- package.json | 1 + 1 file changed, 1 insertion(+) diff --git a/package.json b/package.json index fff1134c97..77db0e218c 100644 --- a/package.json +++ b/package.json @@ -87,6 +87,7 @@ "machine:linux:ubuntu": "./scripts/machine.mjs ssh --cloud=aws --arch=x64 --instance-type c7i.2xlarge --os=linux --distro=ubuntu --release=25.04", "machine:linux:debian": "./scripts/machine.mjs ssh --cloud=aws --arch=x64 --instance-type c7i.2xlarge --os=linux --distro=debian --release=12", "machine:linux:alpine": "./scripts/machine.mjs ssh --cloud=aws --arch=x64 --instance-type c7i.2xlarge --os=linux --distro=alpine --release=3.21", + "machine:linux:amazonlinux": "./scripts/machine.mjs ssh --cloud=aws --arch=x64 --instance-type c7i.2xlarge --os=linux --distro=amazonlinux --release=2023", "machine:windows:2019": "./scripts/machine.mjs ssh --cloud=aws --arch=x64 --instance-type c7i.2xlarge --os=windows --release=2019", "sync-webkit-source": "bun ./scripts/sync-webkit-source.ts" } From 90c7a4e886d24ae8d1087b472d8ea4f8d6080b0d Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Fri, 26 Sep 2025 00:24:02 -0700 Subject: [PATCH 21/43] update no-validate-leaksan.txt --- test/no-validate-leaksan.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/no-validate-leaksan.txt b/test/no-validate-leaksan.txt index bd927cf137..33330a1ee0 100644 --- a/test/no-validate-leaksan.txt +++ b/test/no-validate-leaksan.txt @@ -402,3 +402,4 @@ test/js/web/abort/abort.test.ts test/js/third_party/resvg/bbox.test.js test/regression/issue/10139.test.ts test/js/bun/udp/udp_socket.test.ts +test/cli/init/init.test.ts From 064ecc37fd5ea23db03a1e425482d9c26faf5083 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 26 Sep 2025 00:33:30 -0700 Subject: [PATCH 22/43] Move Bun__JSRequest__calculateEstimatedByteSize earlier (#22993) ### What does this PR do? ### How did you verify your code works? 
--- src/bun.js/bindings/JSBunRequest.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/bun.js/bindings/JSBunRequest.cpp b/src/bun.js/bindings/JSBunRequest.cpp index 1e5fecfb54..34452e2b15 100644 --- a/src/bun.js/bindings/JSBunRequest.cpp +++ b/src/bun.js/bindings/JSBunRequest.cpp @@ -21,6 +21,8 @@ static JSC_DECLARE_CUSTOM_GETTER(jsJSBunRequestGetCookies); static JSC_DECLARE_HOST_FUNCTION(jsJSBunRequestClone); +extern "C" void Bun__JSRequest__calculateEstimatedByteSize(void* requestPtr); + static const HashTableValue JSBunRequestPrototypeValues[] = { { "params"_s, static_cast(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, jsJSBunRequestGetParams, nullptr } }, { "cookies"_s, static_cast(JSC::PropertyAttribute::CustomAccessor | JSC::PropertyAttribute::ReadOnly | JSC::PropertyAttribute::DontDelete), NoIntrinsic, { HashTableValue::GetterSetterType, jsJSBunRequestGetCookies, nullptr } }, @@ -29,6 +31,10 @@ static const HashTableValue JSBunRequestPrototypeValues[] = { JSBunRequest* JSBunRequest::create(JSC::VM& vm, JSC::Structure* structure, void* sinkPtr, JSObject* params) { + // Do this **extremely** early, before we create the JSValue. + // We do not want to risk the GC running before this function is called. + Bun__JSRequest__calculateEstimatedByteSize(sinkPtr); + JSBunRequest* ptr = new (NotNull, JSC::allocateCell(vm)) JSBunRequest(vm, structure, sinkPtr); ptr->finishCreation(vm, params); return ptr; @@ -124,13 +130,12 @@ JSBunRequest::JSBunRequest(JSC::VM& vm, JSC::Structure* structure, void* sinkPtr { } extern "C" size_t Request__estimatedSize(void* requestPtr); -extern "C" void Bun__JSRequest__calculateEstimatedByteSize(void* requestPtr); + void JSBunRequest::finishCreation(JSC::VM& vm, JSObject* params) { Base::finishCreation(vm); m_params.setMayBeNull(vm, this, params); m_cookies.clear(); - Bun__JSRequest__calculateEstimatedByteSize(this->wrapped()); auto size = Request__estimatedSize(this->wrapped()); vm.heap.reportExtraMemoryAllocated(this, size); From ea735c341fb7c84a66af4a14fa69aa5ec7253d62 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 26 Sep 2025 01:46:26 -0700 Subject: [PATCH 23/43] Bump WebKit (#22957) ### What does this PR do? ### How did you verify your code works? 
--------- Co-authored-by: Claude Bot Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- cmake/tools/SetupWebKit.cmake | 2 +- .../bindings/node/crypto/CryptoGenDhKeyPair.h | 2 +- .../bindings/webcore/JSMessageEventCustom.cpp | 18 +++-- src/bun.js/bindings/webcore/MessageEvent.cpp | 14 ++-- .../webcore/PerformanceUserTiming.cpp | 38 +++++----- .../webcore/SerializedScriptValue.cpp | 8 +- .../bindings/webcrypto/SubtleCrypto.cpp | 73 ++++++++++--------- 7 files changed, 83 insertions(+), 72 deletions(-) diff --git a/cmake/tools/SetupWebKit.cmake b/cmake/tools/SetupWebKit.cmake index 1804cd0848..5aee190f85 100644 --- a/cmake/tools/SetupWebKit.cmake +++ b/cmake/tools/SetupWebKit.cmake @@ -2,7 +2,7 @@ option(WEBKIT_VERSION "The version of WebKit to use") option(WEBKIT_LOCAL "If a local version of WebKit should be used instead of downloading") if(NOT WEBKIT_VERSION) - set(WEBKIT_VERSION 495c25e24927ba03277ae225cd42811588d03ff8) + set(WEBKIT_VERSION 69fa2714ab5f917c2d15501ff8cfdccfaea78882) endif() string(SUBSTRING ${WEBKIT_VERSION} 0 16 WEBKIT_VERSION_PREFIX) diff --git a/src/bun.js/bindings/node/crypto/CryptoGenDhKeyPair.h b/src/bun.js/bindings/node/crypto/CryptoGenDhKeyPair.h index 5ed7e1a3ef..845cb4b90c 100644 --- a/src/bun.js/bindings/node/crypto/CryptoGenDhKeyPair.h +++ b/src/bun.js/bindings/node/crypto/CryptoGenDhKeyPair.h @@ -33,7 +33,7 @@ public: ncrypto::EVPKeyCtxPointer setup(); static std::optional fromJS(JSC::JSGlobalObject* globalObject, JSC::ThrowScope& scope, const JSC::GCOwnedDataScope& typeView, JSC::JSValue optionsValue, const KeyEncodingConfig& config); - std::variant m_prime; + WTF::Variant m_prime; uint32_t m_generator; }; diff --git a/src/bun.js/bindings/webcore/JSMessageEventCustom.cpp b/src/bun.js/bindings/webcore/JSMessageEventCustom.cpp index e67b3b4b56..eed83eafc4 100644 --- a/src/bun.js/bindings/webcore/JSMessageEventCustom.cpp +++ b/src/bun.js/bindings/webcore/JSMessageEventCustom.cpp @@ -57,14 +57,16 @@ JSC::JSValue JSMessageEvent::data(JSC::JSGlobalObject& lexicalGlobalObject) cons { auto throwScope = DECLARE_THROW_SCOPE(lexicalGlobalObject.vm()); return cachedPropertyValue(throwScope, lexicalGlobalObject, *this, wrapped().cachedData(), [this, &lexicalGlobalObject](JSC::ThrowScope&) { - return WTF::switchOn( - wrapped().data(), [this](MessageEvent::JSValueTag) -> JSC::JSValue { return wrapped().jsData().getValue(JSC::jsNull()); }, - [this, &lexicalGlobalObject](const Ref& data) { - // FIXME: Is it best to handle errors by returning null rather than throwing an exception? - return data->deserialize(lexicalGlobalObject, globalObject(), wrapped().ports(), SerializationErrorMode::NonThrowing); }, - [&lexicalGlobalObject](const String& data) { return toJS(lexicalGlobalObject, data); }, - [this, &lexicalGlobalObject](const Ref& data) { return toJS>(lexicalGlobalObject, *globalObject(), data); }, - [this, &lexicalGlobalObject](const Ref& data) { return toJS>(lexicalGlobalObject, *globalObject(), data); }); + return std::visit( + WTF::makeVisitor( + [this](MessageEvent::JSValueTag) -> JSC::JSValue { return wrapped().jsData().getValue(JSC::jsNull()); }, + [this, &lexicalGlobalObject](const Ref& data) -> JSC::JSValue { + // FIXME: Is it best to handle errors by returning null rather than throwing an exception? 
+ return data->deserialize(lexicalGlobalObject, globalObject(), wrapped().ports(), SerializationErrorMode::NonThrowing); }, + [&lexicalGlobalObject](const String& data) -> JSC::JSValue { return toJS(lexicalGlobalObject, data); }, + [this, &lexicalGlobalObject](const Ref& data) -> JSC::JSValue { return toJS>(lexicalGlobalObject, *globalObject(), data); }, + [this, &lexicalGlobalObject](const Ref& data) -> JSC::JSValue { return toJS>(lexicalGlobalObject, *globalObject(), data); }), + wrapped().data()); }); } diff --git a/src/bun.js/bindings/webcore/MessageEvent.cpp b/src/bun.js/bindings/webcore/MessageEvent.cpp index c8a00d99ed..a050adfe02 100644 --- a/src/bun.js/bindings/webcore/MessageEvent.cpp +++ b/src/bun.js/bindings/webcore/MessageEvent.cpp @@ -146,12 +146,14 @@ EventInterface MessageEvent::eventInterface() const size_t MessageEvent::memoryCost() const { Locker { m_concurrentDataAccessLock }; - return WTF::switchOn( - m_data, [](JSValueTag) -> size_t { return 0; }, - [](const Ref& data) -> size_t { return data->memoryCost(); }, - [](const String& string) -> size_t { return string.sizeInBytes(); }, - [](const Ref& blob) -> size_t { return blob->memoryCost(); }, - [](const Ref& buffer) -> size_t { return buffer->byteLength(); }); + return std::visit( + WTF::makeVisitor( + [](JSValueTag) -> size_t { return 0; }, + [](const Ref& data) -> size_t { return data->memoryCost(); }, + [](const String& string) -> size_t { return string.sizeInBytes(); }, + [](const Ref& blob) -> size_t { return blob->memoryCost(); }, + [](const Ref& buffer) -> size_t { return buffer->byteLength(); }), + m_data); } } // namespace WebCore diff --git a/src/bun.js/bindings/webcore/PerformanceUserTiming.cpp b/src/bun.js/bindings/webcore/PerformanceUserTiming.cpp index 359dafe02a..d9d67ebb21 100644 --- a/src/bun.js/bindings/webcore/PerformanceUserTiming.cpp +++ b/src/bun.js/bindings/webcore/PerformanceUserTiming.cpp @@ -142,9 +142,10 @@ void PerformanceUserTiming::clearMarks(const String& markName) ExceptionOr PerformanceUserTiming::convertMarkToTimestamp(const std::variant& mark) const { - return WTF::switchOn(mark, [&](auto& value) { + return std::visit([&](auto& value) { return convertMarkToTimestamp(value); - }); + }, + mark); } ExceptionOr PerformanceUserTiming::convertMarkToTimestamp(const String& mark) const @@ -283,23 +284,24 @@ static bool isNonEmptyDictionary(const PerformanceMeasureOptions& measureOptions ExceptionOr> PerformanceUserTiming::measure(JSC::JSGlobalObject& globalObject, const String& measureName, std::optional&& startOrMeasureOptions, const String& endMark) { if (startOrMeasureOptions) { - return WTF::switchOn( - *startOrMeasureOptions, - [&](const PerformanceMeasureOptions& measureOptions) -> ExceptionOr> { - if (isNonEmptyDictionary(measureOptions)) { - if (!endMark.isNull()) - return Exception { TypeError }; - if (!measureOptions.start && !measureOptions.end) - return Exception { TypeError }; - if (measureOptions.start && measureOptions.duration && measureOptions.end) - return Exception { TypeError }; - } + return std::visit( + WTF::makeVisitor( + [&](const PerformanceMeasureOptions& measureOptions) -> ExceptionOr> { + if (isNonEmptyDictionary(measureOptions)) { + if (!endMark.isNull()) + return Exception { TypeError }; + if (!measureOptions.start && !measureOptions.end) + return Exception { TypeError }; + if (measureOptions.start && measureOptions.duration && measureOptions.end) + return Exception { TypeError }; + } - return measure(globalObject, measureName, measureOptions); - }, - [&](const 
String& startMark) { - return measure(measureName, startMark, endMark); - }); + return measure(globalObject, measureName, measureOptions); + }, + [&](const String& startMark) -> ExceptionOr> { + return measure(measureName, startMark, endMark); + }), + *startOrMeasureOptions); } return measure(measureName, {}, endMark); diff --git a/src/bun.js/bindings/webcore/SerializedScriptValue.cpp b/src/bun.js/bindings/webcore/SerializedScriptValue.cpp index 1b16e1bee7..7cf7e2bd13 100644 --- a/src/bun.js/bindings/webcore/SerializedScriptValue.cpp +++ b/src/bun.js/bindings/webcore/SerializedScriptValue.cpp @@ -6275,9 +6275,11 @@ JSValue SerializedScriptValue::deserialize(JSGlobalObject& lexicalGlobalObject, for (const auto& property : m_simpleInMemoryPropertyTable) { // We **must** clone this so that the atomic flag doesn't get set to true. JSC::Identifier identifier = JSC::Identifier::fromString(vm, property.propertyName.isolatedCopy()); - JSValue value = WTF::switchOn( - property.value, [](JSValue value) -> JSValue { return value; }, - [&](const String& string) -> JSValue { return jsString(vm, string); }); + JSValue value = std::visit( + WTF::makeVisitor( + [](JSValue value) -> JSValue { return value; }, + [&](const String& string) -> JSValue { return jsString(vm, string); }), + property.value); object->putDirect(vm, identifier, value); } diff --git a/src/bun.js/bindings/webcrypto/SubtleCrypto.cpp b/src/bun.js/bindings/webcrypto/SubtleCrypto.cpp index eddda7a9a8..4a233ea064 100644 --- a/src/bun.js/bindings/webcrypto/SubtleCrypto.cpp +++ b/src/bun.js/bindings/webcrypto/SubtleCrypto.cpp @@ -512,26 +512,28 @@ static std::optional toKeyData(SubtleCrypto::KeyFormat format, SubtleCr case SubtleCrypto::KeyFormat::Spki: case SubtleCrypto::KeyFormat::Pkcs8: case SubtleCrypto::KeyFormat::Raw: - return WTF::switchOn( - keyDataVariant, - [&promise](JsonWebKey&) -> std::optional { - promise->reject(Exception { TypeError }); - return std::nullopt; - }, - [](auto& bufferSource) -> std::optional { - return KeyData { Vector(std::span { static_cast(bufferSource->data()), bufferSource->byteLength() }) }; - }); + return std::visit( + WTF::makeVisitor( + [&promise](JsonWebKey&) -> std::optional { + promise->reject(Exception { TypeError }); + return std::nullopt; + }, + [](auto& bufferSource) -> std::optional { + return KeyData { Vector(std::span { static_cast(bufferSource->data()), bufferSource->byteLength() }) }; + }), + keyDataVariant); case SubtleCrypto::KeyFormat::Jwk: - return WTF::switchOn( - keyDataVariant, - [](JsonWebKey& webKey) -> std::optional { - normalizeJsonWebKey(webKey); - return KeyData { webKey }; - }, - [&promise](auto&) -> std::optional { - promise->reject(Exception { TypeError }); - return std::nullopt; - }); + return std::visit( + WTF::makeVisitor( + [](JsonWebKey& webKey) -> std::optional { + normalizeJsonWebKey(webKey); + return KeyData { webKey }; + }, + [&promise](auto&) -> std::optional { + promise->reject(Exception { TypeError }); + return std::nullopt; + }), + keyDataVariant); } RELEASE_ASSERT_NOT_REACHED(); @@ -815,22 +817,23 @@ void SubtleCrypto::generateKey(JSC::JSGlobalObject& state, AlgorithmIdentifier&& WeakPtr weakThis { *this }; auto callback = [index, weakThis](KeyOrKeyPair&& keyOrKeyPair) mutable { if (auto promise = getPromise(index, weakThis)) { - WTF::switchOn( - keyOrKeyPair, - [&promise](RefPtr& key) { - if ((key->type() == CryptoKeyType::Private || key->type() == CryptoKeyType::Secret) && !key->usagesBitmap()) { - rejectWithException(promise.releaseNonNull(), SyntaxError, 
""_s); - return; - } - promise->resolve>(*key); - }, - [&promise](CryptoKeyPair& keyPair) { - if (!keyPair.privateKey->usagesBitmap()) { - rejectWithException(promise.releaseNonNull(), SyntaxError, ""_s); - return; - } - promise->resolve>(keyPair); - }); + std::visit( + WTF::makeVisitor( + [&promise](RefPtr& key) { + if ((key->type() == CryptoKeyType::Private || key->type() == CryptoKeyType::Secret) && !key->usagesBitmap()) { + rejectWithException(promise.releaseNonNull(), SyntaxError, ""_s); + return; + } + promise->resolve>(*key); + }, + [&promise](CryptoKeyPair& keyPair) { + if (!keyPair.privateKey->usagesBitmap()) { + rejectWithException(promise.releaseNonNull(), SyntaxError, ""_s); + return; + } + promise->resolve>(keyPair); + }), + keyOrKeyPair); } }; auto exceptionCallback = [index, weakThis](ExceptionCode ec, const String& msg) mutable { From 17b503b389d13ce28355f852b45ae638ce90a166 Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Fri, 26 Sep 2025 03:06:18 -0700 Subject: [PATCH 24/43] Redis PUB/SUB 2.0 (#22568) ### What does this PR do? **This PR is created because [the previous PR I opened](https://github.com/oven-sh/bun/pull/21728) had some concerning issues.** Thanks @Jarred-Sumner for the help. The goal of this PR is to introduce PUB/SUB functionality to the built-in Redis client. Based on the fact that the current Redis API does not appear to have compatibility with `io-redis` or `redis-node`, I've decided to do away with existing APIs and API compatibility with these existing libraries. I have decided to base my implementation on the [`redis-node` pub/sub API](https://github.com/redis/node-redis/blob/master/docs/pub-sub.md). #### Random Things That Happened - [x] Refactored the build scripts so that `valgrind` can be disabled. - [x] Added a `numeric` namespace in `harness.ts` with useful mathematical libraries. - [x] Added a mechanism in `cppbind.ts` to disable static assertions (specifically to allow `check_slow` even when returning a `JSValue`). Implemented via `// NOLINT[NEXTLINE]?\(.*\)` macros. - [x] Fixed inconsistencies in error handling of `JSMap`. ### How did you verify your code works? I've written a set of unit tests to hopefully catch the major use-cases of this feature. They all appear to pass. #### Future Improvements I would have a lot more confidence in our Redis implementation if we tested it with a test suite running over a network which emulates a high network failure rate. There are large amounts of edge cases that are worthwhile to grab, but I think we can roll that out in a future PR. ### Future Tasks - [ ] Tests over flaky network - [ ] Use the custom private members over `_`. 
--------- Co-authored-by: Jarred Sumner --- build.zig | 12 +- cmake/Options.cmake | 6 +- cmake/targets/BuildBun.cmake | 31 +- cmake/tools/SetupZig.cmake | 1 + docs/api/redis.md | 99 +++- docs/nav.ts | 2 +- packages/bun-types/redis.d.ts | 262 +++++++-- src/bun.js/api/BunObject.zig | 8 +- src/bun.js/api/valkey.classes.ts | 4 +- src/bun.js/bindings/ErrorCode.ts | 31 +- src/bun.js/bindings/JSGlobalObject.zig | 7 + src/bun.js/bindings/JSMap.zig | 43 +- src/bun.js/bindings/JSPromise.zig | 11 + src/bun.js/bindings/JSRef.zig | 9 + src/bun.js/bindings/bindings.cpp | 53 +- src/bun.js/bindings/headers.h | 7 +- src/bun.js/event_loop.zig | 13 + src/deps/uws/us_socket_t.zig | 18 +- src/memory.zig | 30 + src/string/StringBuilder.zig | 9 + src/valkey/ValkeyCommand.zig | 3 +- src/valkey/js_valkey.zig | 611 ++++++++++++++++++-- src/valkey/js_valkey_functions.zig | 474 +++++++++++++-- src/valkey/valkey.zig | 184 +++++- src/valkey/valkey_protocol.zig | 12 + test/_util/numeric.ts | 192 ++++++ test/harness.ts | 13 +- test/integration/bun-types/fixture/redis.ts | 31 + test/js/valkey/test-utils.ts | 108 +++- test/js/valkey/valkey.failing-subscriber.ts | 48 ++ test/js/valkey/valkey.test.ts | 577 +++++++++++++++++- 31 files changed, 2643 insertions(+), 266 deletions(-) create mode 100644 test/_util/numeric.ts create mode 100644 test/integration/bun-types/fixture/redis.ts create mode 100644 test/js/valkey/valkey.failing-subscriber.ts diff --git a/build.zig b/build.zig index e368cac8d9..239df85c53 100644 --- a/build.zig +++ b/build.zig @@ -48,6 +48,7 @@ const BunBuildOptions = struct { /// enable debug logs in release builds enable_logs: bool = false, enable_asan: bool, + enable_valgrind: bool, tracy_callstack_depth: u16, reported_nodejs_version: Version, /// To make iterating on some '@embedFile's faster, we load them at runtime @@ -94,6 +95,7 @@ const BunBuildOptions = struct { opts.addOption(bool, "baseline", this.isBaseline()); opts.addOption(bool, "enable_logs", this.enable_logs); opts.addOption(bool, "enable_asan", this.enable_asan); + opts.addOption(bool, "enable_valgrind", this.enable_valgrind); opts.addOption([]const u8, "reported_nodejs_version", b.fmt("{}", .{this.reported_nodejs_version})); opts.addOption(bool, "zig_self_hosted_backend", this.no_llvm); opts.addOption(bool, "override_no_export_cpp_apis", this.override_no_export_cpp_apis); @@ -213,26 +215,21 @@ pub fn build(b: *Build) !void { var build_options = BunBuildOptions{ .target = target, .optimize = optimize, - .os = os, .arch = arch, - .codegen_path = codegen_path, .codegen_embed = codegen_embed, .no_llvm = no_llvm, .override_no_export_cpp_apis = override_no_export_cpp_apis, - .version = try Version.parse(bun_version), .canary_revision = canary: { const rev = b.option(u32, "canary", "Treat this as a canary build") orelse 0; break :canary if (rev == 0) null else rev; }, - .reported_nodejs_version = try Version.parse( b.option([]const u8, "reported_nodejs_version", "Reported Node.js version") orelse "0.0.0-unset", ), - .sha = sha: { const sha_buildoption = b.option([]const u8, "sha", "Force the git sha"); const sha_github = b.graph.env_map.get("GITHUB_SHA"); @@ -268,10 +265,10 @@ pub fn build(b: *Build) !void { break :sha sha; }, - .tracy_callstack_depth = b.option(u16, "tracy_callstack_depth", "") orelse 10, .enable_logs = b.option(bool, "enable_logs", "Enable logs in release") orelse false, .enable_asan = b.option(bool, "enable_asan", "Enable asan") orelse false, + .enable_valgrind = b.option(bool, "enable_valgrind", "Enable valgrind") orelse 
false, }; // zig build obj @@ -500,6 +497,7 @@ fn addMultiCheck( .codegen_path = root_build_options.codegen_path, .no_llvm = root_build_options.no_llvm, .enable_asan = root_build_options.enable_asan, + .enable_valgrind = root_build_options.enable_valgrind, .override_no_export_cpp_apis = root_build_options.override_no_export_cpp_apis, }; @@ -636,7 +634,7 @@ fn configureObj(b: *Build, opts: *BunBuildOptions, obj: *Compile) void { obj.link_function_sections = true; obj.link_data_sections = true; - if (opts.optimize == .Debug) { + if (opts.optimize == .Debug and opts.enable_valgrind) { obj.root_module.valgrind = true; } } diff --git a/cmake/Options.cmake b/cmake/Options.cmake index 3dd5220cc5..1e9b664321 100644 --- a/cmake/Options.cmake +++ b/cmake/Options.cmake @@ -60,10 +60,10 @@ endif() # Windows Code Signing Option if(WIN32) optionx(ENABLE_WINDOWS_CODESIGNING BOOL "Enable Windows code signing with DigiCert KeyLocker" DEFAULT OFF) - + if(ENABLE_WINDOWS_CODESIGNING) message(STATUS "Windows code signing: ENABLED") - + # Check for required environment variables if(NOT DEFINED ENV{SM_API_KEY}) message(WARNING "SM_API_KEY not set - code signing may fail") @@ -114,8 +114,10 @@ endif() if(DEBUG AND ((APPLE AND ARCH STREQUAL "aarch64") OR LINUX)) set(DEFAULT_ASAN ON) + set(DEFAULT_VALGRIND OFF) else() set(DEFAULT_ASAN OFF) + set(DEFAULT_VALGRIND OFF) endif() optionx(ENABLE_ASAN BOOL "If ASAN support should be enabled" DEFAULT ${DEFAULT_ASAN}) diff --git a/cmake/targets/BuildBun.cmake b/cmake/targets/BuildBun.cmake index 5e8795c1bd..888c360739 100644 --- a/cmake/targets/BuildBun.cmake +++ b/cmake/targets/BuildBun.cmake @@ -2,6 +2,8 @@ include(PathUtils) if(DEBUG) set(bun bun-debug) +elseif(ENABLE_ASAN AND ENABLE_VALGRIND) + set(bun bun-asan-valgrind) elseif(ENABLE_ASAN) set(bun bun-asan) elseif(ENABLE_VALGRIND) @@ -619,6 +621,7 @@ register_command( -Dcpu=${ZIG_CPU} -Denable_logs=$,true,false> -Denable_asan=$,true,false> + -Denable_valgrind=$,true,false> -Dversion=${VERSION} -Dreported_nodejs_version=${NODEJS_VERSION} -Dcanary=${CANARY_REVISION} @@ -886,12 +889,8 @@ if(NOT WIN32) endif() if(ENABLE_ASAN) - target_compile_options(${bun} PUBLIC - -fsanitize=address - ) - target_link_libraries(${bun} PUBLIC - -fsanitize=address - ) + target_compile_options(${bun} PUBLIC -fsanitize=address) + target_link_libraries(${bun} PUBLIC -fsanitize=address) endif() target_compile_options(${bun} PUBLIC @@ -930,12 +929,8 @@ if(NOT WIN32) ) if(ENABLE_ASAN) - target_compile_options(${bun} PUBLIC - -fsanitize=address - ) - target_link_libraries(${bun} PUBLIC - -fsanitize=address - ) + target_compile_options(${bun} PUBLIC -fsanitize=address) + target_link_libraries(${bun} PUBLIC -fsanitize=address) endif() endif() else() @@ -1063,7 +1058,7 @@ if(LINUX) ) endif() - if (NOT DEBUG AND NOT ENABLE_ASAN) + if (NOT DEBUG AND NOT ENABLE_ASAN AND NOT ENABLE_VALGRIND) target_link_options(${bun} PUBLIC -Wl,-icf=safe ) @@ -1366,12 +1361,20 @@ if(NOT BUN_CPP_ONLY) if(ENABLE_BASELINE) set(bunTriplet ${bunTriplet}-baseline) endif() - if(ENABLE_ASAN) + + if (ENABLE_ASAN AND ENABLE_VALGRIND) + set(bunTriplet ${bunTriplet}-asan-valgrind) + set(bunPath ${bunTriplet}) + elseif (ENABLE_VALGRIND) + set(bunTriplet ${bunTriplet}-valgrind) + set(bunPath ${bunTriplet}) + elseif(ENABLE_ASAN) set(bunTriplet ${bunTriplet}-asan) set(bunPath ${bunTriplet}) else() string(REPLACE bun ${bunTriplet} bunPath ${bun}) endif() + set(bunFiles ${bunExe} features.json) if(WIN32) list(APPEND bunFiles ${bun}.pdb) diff --git a/cmake/tools/SetupZig.cmake 
b/cmake/tools/SetupZig.cmake
index cbcb50a867..5143353ca0 100644
--- a/cmake/tools/SetupZig.cmake
+++ b/cmake/tools/SetupZig.cmake
@@ -90,6 +90,7 @@ register_command(
       -DZIG_PATH=${ZIG_PATH}
       -DZIG_COMMIT=${ZIG_COMMIT}
       -DENABLE_ASAN=${ENABLE_ASAN}
+      -DENABLE_VALGRIND=${ENABLE_VALGRIND}
       -DZIG_COMPILER_SAFE=${ZIG_COMPILER_SAFE}
       -P ${CWD}/cmake/scripts/DownloadZig.cmake
   SOURCES
diff --git a/docs/api/redis.md b/docs/api/redis.md
index 929babb318..3c663cb0a0 100644
--- a/docs/api/redis.md
+++ b/docs/api/redis.md
@@ -161,6 +161,102 @@ const randomTag = await redis.srandmember("tags");
 const poppedTag = await redis.spop("tags");
 ```
 
+## Pub/Sub
+
+Bun provides native bindings for the [Redis
+Pub/Sub](https://redis.io/docs/latest/develop/pubsub/) protocol. **New in Bun
+1.2.23**
+
+{% callout %}
+**🚧** — The Redis Pub/Sub feature is experimental. Although we expect it to be
+stable, we're actively looking for feedback and areas for improvement.
+{% /callout %}
+
+### Basic Usage
+
+To get started publishing messages, you can set up a publisher in
+`publisher.ts`:
+
+```typescript#publisher.ts
+import { RedisClient } from "bun";
+
+const writer = new RedisClient("redis://localhost:6379");
+await writer.connect();
+
+writer.publish("general", "Hello everyone!");
+
+writer.close();
+```
+
+In another file, create the subscriber in `subscriber.ts`:
+
+```typescript#subscriber.ts
+import { RedisClient } from "bun";
+
+const listener = new RedisClient("redis://localhost:6379");
+await listener.connect();
+
+await listener.subscribe("general", (message, channel) => {
+  console.log(`Received: ${message}`);
+});
+```
+
+In one shell, run your subscriber:
+
+```bash
+bun run subscriber.ts
+```
+
+and, in another, run your publisher:
+
+```bash
+bun run publisher.ts
+```
+
+{% callout %}
+**Note:** The subscription mode takes over the `RedisClient` connection. A
+subscribed client can only issue `subscribe()`, `unsubscribe()`, and `ping()`.
+In other words, applications that also need to send regular commands need a
+separate connection, acquirable through `.duplicate()`:
+
+```typescript
+import { RedisClient } from "bun";
+
+const redis = new RedisClient("redis://localhost:6379");
+await redis.connect();
+const subscriber = await redis.duplicate();
+
+await subscriber.subscribe("foo", () => {});
+await redis.set("bar", "baz");
+```
+
+{% /callout %}
+
+### Publishing
+
+Publishing messages is done through the `publish()` method:
+
+```typescript
+await client.publish(channelName, message);
+```
+
+### Subscriptions
+
+The Bun `RedisClient` allows you to subscribe to channels through the
+`.subscribe()` method:
+
+```typescript
+await client.subscribe(channel, (message, channel) => {});
+```
+
+You can unsubscribe through the `.unsubscribe()` method:
+
+```typescript
+await client.unsubscribe(); // Unsubscribe from all channels.
+await client.unsubscribe(channel); // Unsubscribe a particular channel.
+await client.unsubscribe(channel, listener); // Unsubscribe a particular listener.
+``` + ## Advanced Usage ### Command Execution and Pipelining @@ -482,9 +578,10 @@ When connecting to Redis servers using older versions that don't support RESP3, Current limitations of the Redis client we are planning to address in future versions: -- [ ] No dedicated API for pub/sub functionality (though you can use the raw command API) - [ ] Transactions (MULTI/EXEC) must be done through raw commands for now - [ ] Streams are supported but without dedicated methods +- [ ] Pub/Sub does not currently support binary data, nor pattern-based + subscriptions. Unsupported features: diff --git a/docs/nav.ts b/docs/nav.ts index 0e7c91bed0..3d776961a0 100644 --- a/docs/nav.ts +++ b/docs/nav.ts @@ -359,7 +359,7 @@ export default { page("api/file-io", "File I/O", { description: `Read and write files fast with Bun's heavily optimized file system API.`, }), // "`Bun.write`"), - page("api/redis", "Redis client", { + page("api/redis", "Redis Client", { description: `Bun provides a fast, native Redis client with automatic command pipelining for better performance.`, }), page("api/import-meta", "import.meta", { diff --git a/packages/bun-types/redis.d.ts b/packages/bun-types/redis.d.ts index 9cf228c92a..414b01288b 100644 --- a/packages/bun-types/redis.d.ts +++ b/packages/bun-types/redis.d.ts @@ -52,21 +52,25 @@ declare module "bun" { export namespace RedisClient { type KeyLike = string | ArrayBufferView | Blob; + type StringPubSubListener = (message: string, channel: string) => void; + + // Buffer subscriptions are not yet implemented + // type BufferPubSubListener = (message: Uint8Array, channel: string) => void; } export class RedisClient { /** * Creates a new Redis client - * @param url URL to connect to, defaults to process.env.VALKEY_URL, process.env.REDIS_URL, or "valkey://localhost:6379" + * + * @param url URL to connect to, defaults to `process.env.VALKEY_URL`, + * `process.env.REDIS_URL`, or `"valkey://localhost:6379"` * @param options Additional options * * @example * ```ts - * const valkey = new RedisClient(); - * - * await valkey.set("hello", "world"); - * - * console.log(await valkey.get("hello")); + * const redis = new RedisClient(); + * await redis.set("hello", "world"); + * console.log(await redis.get("hello")); * ``` */ constructor(url?: string, options?: RedisOptions); @@ -88,12 +92,14 @@ declare module "bun" { /** * Callback fired when the client disconnects from the Redis server + * * @param error The error that caused the disconnection */ onclose: ((this: RedisClient, error: Error) => void) | null; /** * Connect to the Redis server + * * @returns A promise that resolves when connected */ connect(): Promise; @@ -152,10 +158,12 @@ declare module "bun" { set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, px: "PX", milliseconds: number): Promise<"OK">; /** - * Set key to hold the string value with expiration at a specific Unix timestamp + * Set key to hold the string value with expiration at a specific Unix + * timestamp * @param key The key to set * @param value The value to set - * @param exat Set the specified Unix time at which the key will expire, in seconds + * @param exat Set the specified Unix time at which the key will expire, in + * seconds * @returns Promise that resolves with "OK" on success */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, exat: "EXAT", timestampSeconds: number): Promise<"OK">; @@ -179,7 +187,8 @@ declare module "bun" { * @param key The key to set * @param value The value to set * @param nx Only set the key if it does not already exist - * 
@returns Promise that resolves with "OK" on success, or null if the key already exists + * @returns Promise that resolves with "OK" on success, or null if the key + * already exists */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, nx: "NX"): Promise<"OK" | null>; @@ -188,7 +197,8 @@ declare module "bun" { * @param key The key to set * @param value The value to set * @param xx Only set the key if it already exists - * @returns Promise that resolves with "OK" on success, or null if the key does not exist + * @returns Promise that resolves with "OK" on success, or null if the key + * does not exist */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, xx: "XX"): Promise<"OK" | null>; @@ -196,8 +206,10 @@ declare module "bun" { * Set key to hold the string value and return the old value * @param key The key to set * @param value The value to set - * @param get Return the old string stored at key, or null if key did not exist - * @returns Promise that resolves with the old value, or null if key did not exist + * @param get Return the old string stored at key, or null if key did not + * exist + * @returns Promise that resolves with the old value, or null if key did not + * exist */ set(key: RedisClient.KeyLike, value: RedisClient.KeyLike, get: "GET"): Promise; @@ -243,7 +255,8 @@ declare module "bun" { /** * Determine if a key exists * @param key The key to check - * @returns Promise that resolves with true if the key exists, false otherwise + * @returns Promise that resolves with true if the key exists, false + * otherwise */ exists(key: RedisClient.KeyLike): Promise; @@ -258,7 +271,8 @@ declare module "bun" { /** * Get the time to live for a key in seconds * @param key The key to get the TTL for - * @returns Promise that resolves with the TTL, -1 if no expiry, or -2 if key doesn't exist + * @returns Promise that resolves with the TTL, -1 if no expiry, or -2 if + * key doesn't exist */ ttl(key: RedisClient.KeyLike): Promise; @@ -290,7 +304,8 @@ declare module "bun" { * Check if a value is a member of a set * @param key The set key * @param member The member to check - * @returns Promise that resolves with true if the member exists, false otherwise + * @returns Promise that resolves with true if the member exists, false + * otherwise */ sismember(key: RedisClient.KeyLike, member: string): Promise; @@ -298,7 +313,8 @@ declare module "bun" { * Add a member to a set * @param key The set key * @param member The member to add - * @returns Promise that resolves with 1 if the member was added, 0 if it already existed + * @returns Promise that resolves with 1 if the member was added, 0 if it + * already existed */ sadd(key: RedisClient.KeyLike, member: string): Promise; @@ -306,7 +322,8 @@ declare module "bun" { * Remove a member from a set * @param key The set key * @param member The member to remove - * @returns Promise that resolves with 1 if the member was removed, 0 if it didn't exist + * @returns Promise that resolves with 1 if the member was removed, 0 if it + * didn't exist */ srem(key: RedisClient.KeyLike, member: string): Promise; @@ -320,14 +337,16 @@ declare module "bun" { /** * Get a random member from a set * @param key The set key - * @returns Promise that resolves with a random member, or null if the set is empty + * @returns Promise that resolves with a random member, or null if the set + * is empty */ srandmember(key: RedisClient.KeyLike): Promise; /** * Remove and return a random member from a set * @param key The set key - * @returns Promise that resolves with the 
removed member, or null if the set is empty + * @returns Promise that resolves with the removed member, or null if the + * set is empty */ spop(key: RedisClient.KeyLike): Promise; @@ -394,28 +413,32 @@ declare module "bun" { /** * Remove and get the first element in a list * @param key The list key - * @returns Promise that resolves with the first element, or null if the list is empty + * @returns Promise that resolves with the first element, or null if the + * list is empty */ lpop(key: RedisClient.KeyLike): Promise; /** * Remove the expiration from a key * @param key The key to persist - * @returns Promise that resolves with 1 if the timeout was removed, 0 if the key doesn't exist or has no timeout + * @returns Promise that resolves with 1 if the timeout was removed, 0 if + * the key doesn't exist or has no timeout */ persist(key: RedisClient.KeyLike): Promise; /** * Get the expiration time of a key as a UNIX timestamp in milliseconds * @param key The key to check - * @returns Promise that resolves with the timestamp, or -1 if the key has no expiration, or -2 if the key doesn't exist + * @returns Promise that resolves with the timestamp, or -1 if the key has + * no expiration, or -2 if the key doesn't exist */ pexpiretime(key: RedisClient.KeyLike): Promise; /** * Get the time to live for a key in milliseconds * @param key The key to check - * @returns Promise that resolves with the TTL in milliseconds, or -1 if the key has no expiration, or -2 if the key doesn't exist + * @returns Promise that resolves with the TTL in milliseconds, or -1 if the + * key has no expiration, or -2 if the key doesn't exist */ pttl(key: RedisClient.KeyLike): Promise; @@ -429,42 +452,48 @@ declare module "bun" { /** * Get the number of members in a set * @param key The set key - * @returns Promise that resolves with the cardinality (number of elements) of the set + * @returns Promise that resolves with the cardinality (number of elements) + * of the set */ scard(key: RedisClient.KeyLike): Promise; /** * Get the length of the value stored in a key * @param key The key to check - * @returns Promise that resolves with the length of the string value, or 0 if the key doesn't exist + * @returns Promise that resolves with the length of the string value, or 0 + * if the key doesn't exist */ strlen(key: RedisClient.KeyLike): Promise; /** * Get the number of members in a sorted set * @param key The sorted set key - * @returns Promise that resolves with the cardinality (number of elements) of the sorted set + * @returns Promise that resolves with the cardinality (number of elements) + * of the sorted set */ zcard(key: RedisClient.KeyLike): Promise; /** * Remove and return members with the highest scores in a sorted set * @param key The sorted set key - * @returns Promise that resolves with the removed member and its score, or null if the set is empty + * @returns Promise that resolves with the removed member and its score, or + * null if the set is empty */ zpopmax(key: RedisClient.KeyLike): Promise; /** * Remove and return members with the lowest scores in a sorted set * @param key The sorted set key - * @returns Promise that resolves with the removed member and its score, or null if the set is empty + * @returns Promise that resolves with the removed member and its score, or + * null if the set is empty */ zpopmin(key: RedisClient.KeyLike): Promise; /** * Get one or multiple random members from a sorted set * @param key The sorted set key - * @returns Promise that resolves with a random member, or null if the set is 
empty + * @returns Promise that resolves with a random member, or null if the set + * is empty */ zrandmember(key: RedisClient.KeyLike): Promise; @@ -472,7 +501,8 @@ declare module "bun" { * Append a value to a key * @param key The key to append to * @param value The value to append - * @returns Promise that resolves with the length of the string after the append operation + * @returns Promise that resolves with the length of the string after the + * append operation */ append(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -480,7 +510,8 @@ declare module "bun" { * Set the value of a key and return its old value * @param key The key to set * @param value The value to set - * @returns Promise that resolves with the old value, or null if the key didn't exist + * @returns Promise that resolves with the old value, or null if the key + * didn't exist */ getset(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -488,7 +519,8 @@ declare module "bun" { * Prepend one or multiple values to a list * @param key The list key * @param value The value to prepend - * @returns Promise that resolves with the length of the list after the push operation + * @returns Promise that resolves with the length of the list after the push + * operation */ lpush(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -496,7 +528,8 @@ declare module "bun" { * Prepend a value to a list, only if the list exists * @param key The list key * @param value The value to prepend - * @returns Promise that resolves with the length of the list after the push operation, or 0 if the list doesn't exist + * @returns Promise that resolves with the length of the list after the push + * operation, or 0 if the list doesn't exist */ lpushx(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -504,7 +537,8 @@ declare module "bun" { * Add one or more members to a HyperLogLog * @param key The HyperLogLog key * @param element The element to add - * @returns Promise that resolves with 1 if the HyperLogLog was altered, 0 otherwise + * @returns Promise that resolves with 1 if the HyperLogLog was altered, 0 + * otherwise */ pfadd(key: RedisClient.KeyLike, element: string): Promise; @@ -512,7 +546,8 @@ declare module "bun" { * Append one or multiple values to a list * @param key The list key * @param value The value to append - * @returns Promise that resolves with the length of the list after the push operation + * @returns Promise that resolves with the length of the list after the push + * operation */ rpush(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -520,7 +555,8 @@ declare module "bun" { * Append a value to a list, only if the list exists * @param key The list key * @param value The value to append - * @returns Promise that resolves with the length of the list after the push operation, or 0 if the list doesn't exist + * @returns Promise that resolves with the length of the list after the push + * operation, or 0 if the list doesn't exist */ rpushx(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -528,7 +564,8 @@ declare module "bun" { * Set the value of a key, only if the key does not exist * @param key The key to set * @param value The value to set - * @returns Promise that resolves with 1 if the key was set, 0 if the key was not set + * @returns Promise that resolves with 1 if the key was set, 0 if the key + * was not set */ setnx(key: RedisClient.KeyLike, value: RedisClient.KeyLike): Promise; @@ -536,14 +573,16 @@ declare module "bun" { * 
Get the score associated with the given member in a sorted set * @param key The sorted set key * @param member The member to get the score for - * @returns Promise that resolves with the score of the member as a string, or null if the member or key doesn't exist + * @returns Promise that resolves with the score of the member as a string, + * or null if the member or key doesn't exist */ zscore(key: RedisClient.KeyLike, member: string): Promise; /** * Get the values of all specified keys * @param keys The keys to get - * @returns Promise that resolves with an array of values, with null for keys that don't exist + * @returns Promise that resolves with an array of values, with null for + * keys that don't exist */ mget(...keys: RedisClient.KeyLike[]): Promise<(string | null)[]>; @@ -557,37 +596,46 @@ declare module "bun" { /** * Return a serialized version of the value stored at the specified key * @param key The key to dump - * @returns Promise that resolves with the serialized value, or null if the key doesn't exist + * @returns Promise that resolves with the serialized value, or null if the + * key doesn't exist */ dump(key: RedisClient.KeyLike): Promise; /** * Get the expiration time of a key as a UNIX timestamp in seconds + * * @param key The key to check - * @returns Promise that resolves with the timestamp, or -1 if the key has no expiration, or -2 if the key doesn't exist + * @returns Promise that resolves with the timestamp, or -1 if the key has + * no expiration, or -2 if the key doesn't exist */ expiretime(key: RedisClient.KeyLike): Promise; /** * Get the value of a key and delete the key + * * @param key The key to get and delete - * @returns Promise that resolves with the value of the key, or null if the key doesn't exist + * @returns Promise that resolves with the value of the key, or null if the + * key doesn't exist */ getdel(key: RedisClient.KeyLike): Promise; /** * Get the value of a key and optionally set its expiration + * * @param key The key to get - * @returns Promise that resolves with the value of the key, or null if the key doesn't exist + * @returns Promise that resolves with the value of the key, or null if the + * key doesn't exist */ getex(key: RedisClient.KeyLike): Promise; /** * Get the value of a key and set its expiration in seconds + * * @param key The key to get * @param ex Set the specified expire time, in seconds * @param seconds The number of seconds until expiration - * @returns Promise that resolves with the value of the key, or null if the key doesn't exist + * @returns Promise that resolves with the value of the key, or null if the + * key doesn't exist */ getex(key: RedisClient.KeyLike, ex: "EX", seconds: number): Promise; @@ -602,6 +650,7 @@ declare module "bun" { /** * Get the value of a key and set its expiration at a specific Unix timestamp in seconds + * * @param key The key to get * @param exat Set the specified Unix time at which the key will expire, in seconds * @param timestampSeconds The Unix timestamp in seconds @@ -611,6 +660,7 @@ declare module "bun" { /** * Get the value of a key and set its expiration at a specific Unix timestamp in milliseconds + * * @param key The key to get * @param pxat Set the specified Unix time at which the key will expire, in milliseconds * @param timestampMilliseconds The Unix timestamp in milliseconds @@ -620,6 +670,7 @@ declare module "bun" { /** * Get the value of a key and remove its expiration + * * @param key The key to get * @param persist Remove the expiration from the key * @returns Promise that 
resolves with the value of the key, or null if the key doesn't exist
@@ -634,10 +685,133 @@ declare module "bun" {
 
     /**
      * Ping the server with a message
+     *
      * @param message The message to send to the server
      * @returns Promise that resolves with the message if the server is reachable, or throws an error if the server is not reachable
      */
     ping(message: RedisClient.KeyLike): Promise;
+
+    /**
+     * Publish a message to a Redis channel.
+     *
+     * @param channel The channel to publish to.
+     * @param message The message to publish.
+     *
+     * @returns The number of clients that received the message. Note that in a
+     * cluster this returns the total number of clients in the same node.
+     */
+    publish(channel: string, message: string): Promise;
+
+    /**
+     * Subscribe to a Redis channel.
+     *
+     * Subscribing disables automatic pipelining, so commands are dispatched
+     * immediately instead of being batched.
+     *
+     * Subscribing moves the client into a dedicated subscription state that
+     * prevents most other commands from being executed until you unsubscribe.
+     * Only {@link ping `.ping()`}, {@link subscribe `.subscribe()`}, and
+     * {@link unsubscribe `.unsubscribe()`} may be invoked while subscribed.
+     *
+     * @param channel The channel to subscribe to.
+     * @param listener The listener to call when a message is received on the
+     * channel. The listener will receive the message as the first argument and
+     * the channel as the second argument.
+     *
+     * @example
+     * ```ts
+     * await client.subscribe("my-channel", (message, channel) => {
+     *   console.log(`Received message on ${channel}: ${message}`);
+     * });
+     * ```
+     */
+    subscribe(channel: string, listener: RedisClient.StringPubSubListener): Promise;
+
+    /**
+     * Subscribe to multiple Redis channels.
+     *
+     * Subscribing disables automatic pipelining, so commands are dispatched
+     * immediately instead of being batched.
+     *
+     * Subscribing moves the client into a dedicated subscription state in
+     * which only a limited set of commands can be executed.
+     *
+     * @param channels An array of channels to subscribe to.
+     * @param listener The listener to call when a message is received on any of
+     * the subscribed channels. The listener will receive the message as the
+     * first argument and the channel as the second argument.
+     */
+    subscribe(channels: string[], listener: RedisClient.StringPubSubListener): Promise;
+
+    /**
+     * Unsubscribe from a single Redis channel.
+     *
+     * If there are no more subscribed channels, the client automatically
+     * re-enables pipelining if it was previously enabled.
+     *
+     * Once all channels have been unsubscribed from, the client leaves the
+     * subscription state and returns to normal operation. For further details
+     * on the subscription state, see {@link subscribe `.subscribe()`}.
+     *
+     * @param channel The channel to unsubscribe from.
+     */
+    unsubscribe(channel: string): Promise;
+
+    /**
+     * Remove a listener from a given Redis channel.
+     *
+     * If there are no more subscribed channels, the client automatically
+     * re-enables pipelining if it was previously enabled.
+     *
+     * Once all channels have been unsubscribed from, the client leaves the
+     * subscription state and returns to normal operation. For further details
+     * on the subscription state, see {@link subscribe `.subscribe()`}.
+     *
+     * @param channel The channel to unsubscribe from.
+     * @param listener The listener to remove. This is tested against
+     * referential equality, so you must pass the exact same listener instance
+     * as when subscribing.
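+     *
+     * @example
+     * An illustrative sketch (`client` and `onMessage` are placeholder names);
+     * the exact listener reference that was passed to
+     * {@link subscribe `.subscribe()`} must be passed back:
+     * ```ts
+     * const onMessage = (message: string, channel: string) => {
+     *   console.log(`${channel}: ${message}`);
+     * };
+     * await client.subscribe("my-channel", onMessage);
+     * // Later: removes only this listener, leaving any others subscribed.
+     * await client.unsubscribe("my-channel", onMessage);
+     * ```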
+     */
+    unsubscribe(channel: string, listener: RedisClient.StringPubSubListener): Promise;
+
+    /**
+     * Unsubscribe from all registered Redis channels.
+     *
+     * The client will automatically re-enable pipelining if it was previously
+     * enabled.
+     *
+     * Once all channels have been unsubscribed from, the client leaves the
+     * subscription state and returns to normal operation. For further details
+     * on the subscription state, see {@link subscribe `.subscribe()`}.
+     */
+    unsubscribe(): Promise;
+
+    /**
+     * Unsubscribe from multiple Redis channels.
+     *
+     * If there are no more subscribed channels, the client automatically
+     * re-enables pipelining if it was previously enabled.
+     *
+     * Once all channels have been unsubscribed from, the client leaves the
+     * subscription state and returns to normal operation. For further details
+     * on the subscription state, see {@link subscribe `.subscribe()`}.
+     *
+     * @param channels An array of channels to unsubscribe from.
+     */
+    unsubscribe(channels: string[]): Promise;
+
+    /**
+     * Create a new RedisClient instance with the same configuration as the
+     * current instance.
+     *
+     * This will open up a new connection to the Redis server.
+     */
+    duplicate(): Promise;
   }
 
   /**
diff --git a/src/bun.js/api/BunObject.zig b/src/bun.js/api/BunObject.zig
index 8cd0dfe71e..a5af653677 100644
--- a/src/bun.js/api/BunObject.zig
+++ b/src/bun.js/api/BunObject.zig
@@ -1294,7 +1294,7 @@ pub fn setTLSDefaultCiphers(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject, c
 }
 
 pub fn getValkeyDefaultClient(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) jsc.JSValue {
-    const valkey = jsc.API.Valkey.create(globalThis, &.{.js_undefined}) catch |err| {
+    var valkey = jsc.API.Valkey.createNoJs(globalThis, &.{.js_undefined}) catch |err| {
         if (err != error.JSError) {
             _ = globalThis.throwError(err, "Failed to create Redis client") catch {};
             return .zero;
@@ -1302,7 +1302,11 @@ pub fn getValkeyDefaultClient(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject)
         return .zero;
     };
 
-    return valkey.toJS(globalThis);
+    const as_js = valkey.toJS(globalThis);
+
+    valkey.this_value = jsc.JSRef.initWeak(as_js);
+
+    return as_js;
 }
 
 pub fn getValkeyClientConstructor(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) jsc.JSValue {
diff --git a/src/bun.js/api/valkey.classes.ts b/src/bun.js/api/valkey.classes.ts
index abcfeaa395..3828db06d1 100644
--- a/src/bun.js/api/valkey.classes.ts
+++ b/src/bun.js/api/valkey.classes.ts
@@ -4,6 +4,7 @@ export default [
   define({
     name: "RedisClient",
     construct: true,
+    constructNeedsThis: true,
     call: false,
     finalize: true,
     configurable: false,
@@ -226,11 +227,12 @@ export default [
       zrank: { fn: "zrank" },
       zrevrank: { fn: "zrevrank" },
       subscribe: { fn: "subscribe" },
+      duplicate: { fn: "duplicate" },
       psubscribe: { fn: "psubscribe" },
       unsubscribe: { fn: "unsubscribe" },
       punsubscribe: { fn: "punsubscribe" },
       pubsub: { fn: "pubsub" },
     },
-    values: ["onconnect", "onclose", "connectionPromise", "hello"],
+    values: ["onconnect", "onclose", "connectionPromise", "hello", "subscriptionCallbackMap"],
   }),
 ];
diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts
index 37a9ce660b..ecd0bb332c 100644
--- a/src/bun.js/bindings/ErrorCode.ts
+++ b/src/bun.js/bindings/ErrorCode.ts
@@ -276,24 +276,25 @@ const errors: ErrorCodeMapping = [
   ["ERR_OSSL_EVP_INVALID_DIGEST", Error],
   ["ERR_KEY_GENERATION_JOB_FAILED", Error],
   ["ERR_MISSING_OPTION", TypeError],
-  ["ERR_REDIS_CONNECTION_CLOSED", Error, "RedisError"],
-  
["ERR_REDIS_INVALID_RESPONSE", Error, "RedisError"], - ["ERR_REDIS_INVALID_BULK_STRING", Error, "RedisError"], - ["ERR_REDIS_INVALID_ARRAY", Error, "RedisError"], - ["ERR_REDIS_INVALID_INTEGER", Error, "RedisError"], - ["ERR_REDIS_INVALID_SIMPLE_STRING", Error, "RedisError"], - ["ERR_REDIS_INVALID_ERROR_STRING", Error, "RedisError"], - ["ERR_REDIS_TLS_NOT_AVAILABLE", Error, "RedisError"], - ["ERR_REDIS_TLS_UPGRADE_FAILED", Error, "RedisError"], ["ERR_REDIS_AUTHENTICATION_FAILED", Error, "RedisError"], - ["ERR_REDIS_INVALID_PASSWORD", Error, "RedisError"], - ["ERR_REDIS_INVALID_USERNAME", Error, "RedisError"], - ["ERR_REDIS_INVALID_DATABASE", Error, "RedisError"], - ["ERR_REDIS_INVALID_COMMAND", Error, "RedisError"], - ["ERR_REDIS_INVALID_ARGUMENT", Error, "RedisError"], - ["ERR_REDIS_INVALID_RESPONSE_TYPE", Error, "RedisError"], + ["ERR_REDIS_CONNECTION_CLOSED", Error, "RedisError"], ["ERR_REDIS_CONNECTION_TIMEOUT", Error, "RedisError"], ["ERR_REDIS_IDLE_TIMEOUT", Error, "RedisError"], + ["ERR_REDIS_INVALID_ARGUMENT", Error, "RedisError"], + ["ERR_REDIS_INVALID_ARRAY", Error, "RedisError"], + ["ERR_REDIS_INVALID_BULK_STRING", Error, "RedisError"], + ["ERR_REDIS_INVALID_COMMAND", Error, "RedisError"], + ["ERR_REDIS_INVALID_DATABASE", Error, "RedisError"], + ["ERR_REDIS_INVALID_ERROR_STRING", Error, "RedisError"], + ["ERR_REDIS_INVALID_INTEGER", Error, "RedisError"], + ["ERR_REDIS_INVALID_PASSWORD", Error, "RedisError"], + ["ERR_REDIS_INVALID_RESPONSE", Error, "RedisError"], + ["ERR_REDIS_INVALID_RESPONSE_TYPE", Error, "RedisError"], + ["ERR_REDIS_INVALID_SIMPLE_STRING", Error, "RedisError"], + ["ERR_REDIS_INVALID_STATE", Error, "RedisError"], + ["ERR_REDIS_INVALID_USERNAME", Error, "RedisError"], + ["ERR_REDIS_TLS_NOT_AVAILABLE", Error, "RedisError"], + ["ERR_REDIS_TLS_UPGRADE_FAILED", Error, "RedisError"], ["HPE_UNEXPECTED_CONTENT_LENGTH", Error], ["HPE_INVALID_TRANSFER_ENCODING", Error], ["HPE_INVALID_EOF_STATE", Error], diff --git a/src/bun.js/bindings/JSGlobalObject.zig b/src/bun.js/bindings/JSGlobalObject.zig index 5642205a9d..101ccc8781 100644 --- a/src/bun.js/bindings/JSGlobalObject.zig +++ b/src/bun.js/bindings/JSGlobalObject.zig @@ -367,6 +367,10 @@ pub const JSGlobalObject = opaque { return this.throwValue(err); } + /// Throw an Error from a formatted string. + /// + /// Note: If you are throwing an error within somewhere in the Bun API, + /// chances are you should be using `.ERR(...).throw()` instead. pub fn throw(this: *JSGlobalObject, comptime fmt: [:0]const u8, args: anytype) JSError { const instance = this.createErrorInstance(fmt, args); bun.assert(instance != .zero); @@ -789,6 +793,9 @@ pub const JSGlobalObject = opaque { return .{ .globalObject = this }; } + /// Throw an error from within the Bun runtime. + /// + /// The set of errors accepted by `ERR()` is defined in `ErrorCode.ts`. pub fn ERR(global: *JSGlobalObject, comptime code: jsc.Error, comptime fmt: [:0]const u8, args: anytype) @import("ErrorCode").ErrorBuilder(code, fmt, @TypeOf(args)) { return .{ .global = global, .args = args }; } diff --git a/src/bun.js/bindings/JSMap.zig b/src/bun.js/bindings/JSMap.zig index 5c4ce35be9..0abbd06173 100644 --- a/src/bun.js/bindings/JSMap.zig +++ b/src/bun.js/bindings/JSMap.zig @@ -1,34 +1,26 @@ +/// Opaque type for working with JavaScript `Map` objects. 
pub const JSMap = opaque { - extern fn JSC__JSMap__create(*JSGlobalObject) JSValue; + pub const create = bun.cpp.JSC__JSMap__create; + pub const set = bun.cpp.JSC__JSMap__set; - pub fn create(globalObject: *JSGlobalObject) JSValue { - return JSC__JSMap__create(globalObject); - } + /// Retrieve a value from this JS Map object. + /// + /// Note this shares semantics with the JS `Map.prototype.get` method, and + /// will return .js_undefined if a value is not found. + pub const get = bun.cpp.JSC__JSMap__get; - pub fn set(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue, value: JSValue) void { - return bun.cpp.JSC__JSMap__set(this, globalObject, key, value); - } + /// Test whether this JS Map object has a given key. + pub const has = bun.cpp.JSC__JSMap__has; - pub fn get_(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) JSValue { - return bun.cpp.JSC__JSMap__get_(this, globalObject, key); - } + /// Attempt to remove a key from this JS Map object. + pub const remove = bun.cpp.JSC__JSMap__remove; - pub fn get(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) ?JSValue { - const value = get_(this, globalObject, key); - if (value.isEmpty()) { - return null; - } - return value; - } - - pub fn has(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) bool { - return bun.cpp.JSC__JSMap__has(this, globalObject, key); - } - - pub fn remove(this: *JSMap, globalObject: *JSGlobalObject, key: JSValue) bool { - return bun.cpp.JSC__JSMap__remove(this, globalObject, key); - } + /// Retrieve the number of entries in this JS Map object. + pub const size = bun.cpp.JSC__JSMap__size; + /// Attempt to convert a `JSValue` to a `*JSMap`. + /// + /// Returns `null` if the value is not a Map. pub fn fromJS(value: JSValue) ?*JSMap { if (value.jsTypeLoose() == .Map) { return bun.cast(*JSMap, value.asEncoded().asPtr.?); @@ -41,5 +33,4 @@ pub const JSMap = opaque { const bun = @import("bun"); const jsc = bun.jsc; -const JSGlobalObject = jsc.JSGlobalObject; const JSValue = jsc.JSValue; diff --git a/src/bun.js/bindings/JSPromise.zig b/src/bun.js/bindings/JSPromise.zig index 35a581b106..3a0313241b 100644 --- a/src/bun.js/bindings/JSPromise.zig +++ b/src/bun.js/bindings/JSPromise.zig @@ -220,6 +220,9 @@ pub const JSPromise = opaque { bun.cpp.JSC__JSPromise__setHandled(this, vm); } + /// Create a new resolved promise resolving to a given value. + /// + /// Note: If you want the result as a JSValue, use `JSPromise.resolvedPromiseValue` instead. pub fn resolvedPromise(globalThis: *JSGlobalObject, value: JSValue) *JSPromise { return JSC__JSPromise__resolvedPromise(globalThis, value); } @@ -230,6 +233,9 @@ pub const JSPromise = opaque { return JSC__JSPromise__resolvedPromiseValue(globalThis, value); } + /// Create a new rejected promise rejecting to a given value. + /// + /// Note: If you want the result as a JSValue, use `JSPromise.rejectedPromiseValue` instead. pub fn rejectedPromise(globalThis: *JSGlobalObject, value: JSValue) *JSPromise { return JSC__JSPromise__rejectedPromise(globalThis, value); } @@ -275,6 +281,11 @@ pub const JSPromise = opaque { bun.cpp.JSC__JSPromise__rejectAsHandled(this, globalThis, value) catch return bun.debugAssert(false); // TODO: properly propagate exception upwards } + /// Create a new pending promise. + /// + /// Note: You should use `JSPromise.resolvedPromise` or + /// `JSPromise.rejectedPromise` if you want to create a promise that + /// is already resolved or rejected. 
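+    ///
+    /// A minimal usage sketch (hypothetical caller; `err_value` is assumed to
+    /// be a JSValue in scope, mirroring how `js_valkey.zig` settles promises
+    /// later in this patch):
+    ///
+    /// ```zig
+    /// const promise = jsc.JSPromise.create(globalThis);
+    /// // ... kick off the async work; when it fails, settle the promise:
+    /// promise.reject(globalThis, err_value);
+    /// ```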
    pub fn create(globalThis: *JSGlobalObject) *JSPromise {
        return JSC__JSPromise__create(globalThis);
    }
diff --git a/src/bun.js/bindings/JSRef.zig b/src/bun.js/bindings/JSRef.zig
index c76a4056eb..08928aa73e 100644
--- a/src/bun.js/bindings/JSRef.zig
+++ b/src/bun.js/bindings/JSRef.zig
@@ -1,3 +1,7 @@
+/// Holds a reference to a JSValue.
+///
+/// The reference is either weak (a plain JSValue) or strong, in which case it
+/// prevents the garbage collector from collecting the value.
 pub const JSRef = union(enum) {
     weak: jsc.JSValue,
     strong: jsc.Strong.Optional,
@@ -91,6 +95,11 @@ pub const JSRef = union(enum) {
         };
     }
 
+    /// Test whether this reference is a strong reference.
+    pub fn isStrong(this: *const @This()) bool {
+        return this.* == .strong;
+    }
+
     pub fn deinit(this: *@This()) void {
         switch (this.*) {
             .weak => {
diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp
index 496c5e14f8..6111976187 100644
--- a/src/bun.js/bindings/bindings.cpp
+++ b/src/bun.js/bindings/bindings.cpp
@@ -1,3 +1,12 @@
+/**
+ * Source code for JavaScriptCore bindings used by bind.
+ *
+ * This file is processed by cppbind.ts.
+ *
+ * @see cppbind.ts holds helpful tips on how to add and implement new bindings.
+ * Note that cppbind.ts also automatically runs some error-checking which
+ * can be disabled if necessary. Consult cppbind.ts for details.
+ */
 #include "root.h"
 
 #include "JavaScriptCore/ErrorType.h"
@@ -37,6 +46,7 @@
 #include "JavaScriptCore/JSArrayInlines.h"
 #include "JavaScriptCore/ErrorInstanceInlines.h"
 #include "JavaScriptCore/BigIntObject.h"
+#include "JavaScriptCore/OrderedHashTableHelper.h"
 
 #include "JavaScriptCore/JSCallbackObject.h"
 #include "JavaScriptCore/JSClassRef.h"
@@ -6463,32 +6473,49 @@ CPP_DECL WebCore::DOMFormData* WebCore__DOMFormData__fromJS(JSC::EncodedJSValue
 
 #pragma mark - JSC::JSMap
 
-CPP_DECL JSC::EncodedJSValue JSC__JSMap__create(JSC::JSGlobalObject* arg0)
+CPP_DECL [[ZIG_EXPORT(nothrow)]] JSC::EncodedJSValue JSC__JSMap__create(JSC::JSGlobalObject* arg0)
 {
-    JSC::JSMap* map = JSC::JSMap::create(arg0->vm(), arg0->mapStructure());
-    return JSC::JSValue::encode(map);
+    return JSC::JSValue::encode(JSC::JSMap::create(arg0->vm(), arg0->mapStructure()));
 }
 
-CPP_DECL [[ZIG_EXPORT(nothrow)]] JSC::EncodedJSValue JSC__JSMap__get_(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2)
-{
-    JSC::JSValue value = JSC::JSValue::decode(JSValue2);
-    return JSC::JSValue::encode(map->get(arg1, value));
-}
-CPP_DECL [[ZIG_EXPORT(nothrow)]] bool JSC__JSMap__has(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2)
+// JSMap::get never returns JSValue::zero, even in the case of an exception. The
+// best we can, therefore, do is manually test for exceptions.
+// NOLINTNEXTLINE(bun-bindgen-force-zero_is_throw-for-jsvalue)
+CPP_DECL [[ZIG_EXPORT(zero_is_throw)]] JSC::EncodedJSValue JSC__JSMap__get(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2)
 {
-    JSC::JSValue value = JSC::JSValue::decode(JSValue2);
+    auto& vm = JSC::getVM(arg1);
+    const JSC::JSValue key = JSC::JSValue::decode(JSValue2);
+
+    // On exception, JSMap::get returns JSValue::undefined and sets an exception
+    // on the VM, so we check the throw scope manually below.
+ auto scope = DECLARE_THROW_SCOPE(vm); + const JSValue value = map->get(arg1, key); + RETURN_IF_EXCEPTION(scope, {}); + return JSC::JSValue::encode(value); +} + +CPP_DECL [[ZIG_EXPORT(check_slow)]] bool JSC__JSMap__has(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) +{ + const JSC::JSValue value = JSC::JSValue::decode(JSValue2); return map->has(arg1, value); } -CPP_DECL [[ZIG_EXPORT(nothrow)]] bool JSC__JSMap__remove(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) + +CPP_DECL [[ZIG_EXPORT(check_slow)]] bool JSC__JSMap__remove(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2) { - JSC::JSValue value = JSC::JSValue::decode(JSValue2); + const JSC::JSValue value = JSC::JSValue::decode(JSValue2); return map->remove(arg1, value); } -CPP_DECL [[ZIG_EXPORT(nothrow)]] void JSC__JSMap__set(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3) + +CPP_DECL [[ZIG_EXPORT(check_slow)]] void JSC__JSMap__set(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3) { map->set(arg1, JSC::JSValue::decode(JSValue2), JSC::JSValue::decode(JSValue3)); } +CPP_DECL [[ZIG_EXPORT(check_slow)]] uint32_t JSC__JSMap__size(JSC::JSMap* map, JSC::JSGlobalObject* arg1) +{ + return map->size(); +} + CPP_DECL void JSC__VM__setControlFlowProfiler(JSC::VM* vm, bool isEnabled) { if (isEnabled) { diff --git a/src/bun.js/bindings/headers.h b/src/bun.js/bindings/headers.h index 4f3d3454b4..f02d203054 100644 --- a/src/bun.js/bindings/headers.h +++ b/src/bun.js/bindings/headers.h @@ -8,7 +8,7 @@ #define AUTO_EXTERN_C extern "C" #ifdef WIN32 - #define AUTO_EXTERN_C_ZIG extern "C" + #define AUTO_EXTERN_C_ZIG extern "C" #else #define AUTO_EXTERN_C_ZIG extern "C" __attribute__((weak)) #endif @@ -129,7 +129,7 @@ CPP_DECL void WebCore__AbortSignal__cleanNativeBindings(WebCore::AbortSignal* ar CPP_DECL JSC::EncodedJSValue WebCore__AbortSignal__create(JSC::JSGlobalObject* arg0); CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__fromJS(JSC::EncodedJSValue JSValue0); CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__ref(WebCore::AbortSignal* arg0); -CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__signal(WebCore::AbortSignal* arg0, JSC::JSGlobalObject*, uint8_t abortReason); +CPP_DECL WebCore::AbortSignal* WebCore__AbortSignal__signal(WebCore::AbortSignal* arg0, JSC::JSGlobalObject*, uint8_t abortReason); CPP_DECL JSC::EncodedJSValue WebCore__AbortSignal__toJS(WebCore::AbortSignal* arg0, JSC::JSGlobalObject* arg1); CPP_DECL void WebCore__AbortSignal__unref(WebCore::AbortSignal* arg0); @@ -186,10 +186,11 @@ CPP_DECL JSC::VM* JSC__JSGlobalObject__vm(JSC::JSGlobalObject* arg0); #pragma mark - JSC::JSMap CPP_DECL JSC::EncodedJSValue JSC__JSMap__create(JSC::JSGlobalObject* arg0); -CPP_DECL JSC::EncodedJSValue JSC__JSMap__get_(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); +CPP_DECL JSC::EncodedJSValue JSC__JSMap__get(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); CPP_DECL bool JSC__JSMap__has(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); CPP_DECL bool JSC__JSMap__remove(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2); CPP_DECL void JSC__JSMap__set(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3); +CPP_DECL uint32_t JSC__JSMap__size(JSC::JSMap* arg0, JSC::JSGlobalObject* arg1); #pragma mark - 
JSC::JSValue diff --git a/src/bun.js/event_loop.zig b/src/bun.js/event_loop.zig index dd4451e81f..b84f9c7d41 100644 --- a/src/bun.js/event_loop.zig +++ b/src/bun.js/event_loop.zig @@ -57,12 +57,25 @@ pub const Debug = if (Environment.isDebug) struct { pub inline fn exit(_: Debug) void {} }; +/// Before your code enters JavaScript at the top of the event loop, call +/// `loop.enter()`. If running a single callback, prefer `runCallback` instead. +/// +/// When we call into JavaScript, we must drain process.nextTick & microtasks +/// afterwards (so that promises run). We must only do that once per task in the +/// event loop. To make that work, we count enter/exit calls and once that +/// counter reaches 0, we drain the microtasks. +/// +/// This function increments the counter for the number of times we've entered +/// the event loop. pub fn enter(this: *EventLoop) void { log("enter() = {d}", .{this.entered_event_loop_count}); this.entered_event_loop_count += 1; this.debug.enter(); } +/// "exit" a microtask context in the event loop. +/// +/// See the documentation for `enter` for more information. pub fn exit(this: *EventLoop) void { const count = this.entered_event_loop_count; log("exit() = {d}", .{count - 1}); diff --git a/src/deps/uws/us_socket_t.zig b/src/deps/uws/us_socket_t.zig index 0af82ca70e..a67efb9286 100644 --- a/src/deps/uws/us_socket_t.zig +++ b/src/deps/uws/us_socket_t.zig @@ -13,7 +13,7 @@ pub const us_socket_t = opaque { }; pub fn open(this: *us_socket_t, comptime is_ssl: bool, is_client: bool, ip_addr: ?[]const u8) void { - debug("us_socket_open({d}, is_client: {})", .{ @intFromPtr(this), is_client }); + debug("us_socket_open({p}, is_client: {})", .{ this, is_client }); const ssl = @intFromBool(is_ssl); if (ip_addr) |ip| { @@ -25,22 +25,22 @@ pub const us_socket_t = opaque { } pub fn pause(this: *us_socket_t, ssl: bool) void { - debug("us_socket_pause({d})", .{@intFromPtr(this)}); + debug("us_socket_pause({p})", .{this}); c.us_socket_pause(@intFromBool(ssl), this); } pub fn @"resume"(this: *us_socket_t, ssl: bool) void { - debug("us_socket_resume({d})", .{@intFromPtr(this)}); + debug("us_socket_resume({p})", .{this}); c.us_socket_resume(@intFromBool(ssl), this); } pub fn close(this: *us_socket_t, ssl: bool, code: CloseCode) void { - debug("us_socket_close({d}, {s})", .{ @intFromPtr(this), @tagName(code) }); + debug("us_socket_close({p}, {s})", .{ this, @tagName(code) }); _ = c.us_socket_close(@intFromBool(ssl), this, code, null); } pub fn shutdown(this: *us_socket_t, ssl: bool) void { - debug("us_socket_shutdown({d})", .{@intFromPtr(this)}); + debug("us_socket_shutdown({p})", .{this}); c.us_socket_shutdown(@intFromBool(ssl), this); } @@ -128,25 +128,25 @@ pub const us_socket_t = opaque { pub fn write(this: *us_socket_t, ssl: bool, data: []const u8) i32 { const rc = c.us_socket_write(@intFromBool(ssl), this, data.ptr, @intCast(data.len)); - debug("us_socket_write({d}, {d}) = {d}", .{ @intFromPtr(this), data.len, rc }); + debug("us_socket_write({p}, {d}) = {d}", .{ this, data.len, rc }); return rc; } pub fn writeFd(this: *us_socket_t, data: []const u8, file_descriptor: bun.FD) i32 { if (bun.Environment.isWindows) @compileError("TODO: implement writeFd on Windows"); const rc = c.us_socket_ipc_write_fd(this, data.ptr, @intCast(data.len), file_descriptor.native()); - debug("us_socket_ipc_write_fd({d}, {d}, {d}) = {d}", .{ @intFromPtr(this), data.len, file_descriptor.native(), rc }); + debug("us_socket_ipc_write_fd({p}, {d}, {d}) = {d}", .{ this, data.len, 
file_descriptor.native(), rc });
         return rc;
     }
 
     pub fn write2(this: *us_socket_t, ssl: bool, first: []const u8, second: []const u8) i32 {
         const rc = c.us_socket_write2(@intFromBool(ssl), this, first.ptr, first.len, second.ptr, second.len);
-        debug("us_socket_write2({d}, {d}, {d}) = {d}", .{ @intFromPtr(this), first.len, second.len, rc });
+        debug("us_socket_write2({p}, {d}, {d}) = {d}", .{ this, first.len, second.len, rc });
         return rc;
     }
 
     pub fn rawWrite(this: *us_socket_t, ssl: bool, data: []const u8) i32 {
-        debug("us_socket_raw_write({d}, {d})", .{ @intFromPtr(this), data.len });
+        debug("us_socket_raw_write({p}, {d})", .{ this, data.len });
         return c.us_socket_raw_write(@intFromBool(ssl), this, data.ptr, @intCast(data.len));
     }
 
diff --git a/src/memory.zig b/src/memory.zig
index 47e54a7a65..5ed890ee3e 100644
--- a/src/memory.zig
+++ b/src/memory.zig
@@ -75,6 +75,36 @@ pub fn deinit(ptr_or_slice: anytype) void {
     }
 }
 
+/// Rebase a slice from one memory buffer to another buffer.
+///
+/// Given a slice which points into a memory buffer with base `old_base`, return
+/// a slice which points to the same offset in a new memory buffer with base
+/// `new_base`, preserving the length of the slice.
+///
+/// ```
+/// var old_base: [6]u8 = undefined;
+/// assert(@intFromPtr(&old_base) == 0x32);
+///
+///            0x32 0x33 0x34 0x35 0x36 0x37
+/// old_base  |????|????|????|????|????|????|
+///                 ^
+///                 |<-- slice --->|
+///
+/// var new_base: [6]u8 = undefined;
+/// assert(@intFromPtr(&new_base) == 0x74);
+/// const output = rebaseSlice(slice, &old_base, &new_base);
+///
+///            0x74 0x75 0x76 0x77 0x78 0x79
+/// new_base  |????|????|????|????|????|????|
+///                 ^
+///                 |<-- output -->|
+/// ```
+pub fn rebaseSlice(slice: []const u8, old_base: [*]const u8, new_base: [*]const u8) []const u8 {
+    const offset = @intFromPtr(slice.ptr) - @intFromPtr(old_base);
+    return new_base[offset..][0..slice.len];
+}
+
 const std = @import("std");
 const Allocator = std.mem.Allocator;
diff --git a/src/string/StringBuilder.zig b/src/string/StringBuilder.zig
index 91e8b1f250..ae89ee4f74 100644
--- a/src/string/StringBuilder.zig
+++ b/src/string/StringBuilder.zig
@@ -236,6 +236,15 @@ pub fn writable(this: *StringBuilder) []u8 {
     return ptr[this.len..this.cap];
 }
 
+/// Transfer ownership of the underlying memory to a slice.
+///
+/// After calling this, you are responsible for freeing the underlying memory.
+/// This StringBuilder should not be used after calling this function.
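+///
+/// A minimal usage sketch (hypothetical caller; mirrors how `createNoJs` in
+/// `js_valkey.zig` uses this later in the patch):
+///
+/// ```zig
+/// var owned: []u8 = &.{};
+/// builder.moveToSlice(&owned);
+/// // `builder` is reset to its empty state; the caller now owns `owned` and
+/// // must free it with the allocator that backed the builder.
+/// ```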
+pub fn moveToSlice(this: *StringBuilder, into_slice: *[]u8) void {
+    into_slice.* = this.allocatedSlice();
+    this.* = .{};
+}
+
 const std = @import("std");
 const Allocator = std.mem.Allocator;
diff --git a/src/valkey/ValkeyCommand.zig b/src/valkey/ValkeyCommand.zig
index dfc9c448e6..2349ecad4a 100644
--- a/src/valkey/ValkeyCommand.zig
+++ b/src/valkey/ValkeyCommand.zig
@@ -137,7 +137,7 @@ pub const Promise = struct {
         self.promise.resolve(globalObject, js_value);
     }
 
-    pub fn reject(self: *Promise, globalObject: *jsc.JSGlobalObject, jsvalue: jsc.JSValue) void {
+    pub fn reject(self: *Promise, globalObject: *jsc.JSGlobalObject, jsvalue: JSError!jsc.JSValue) void {
         self.promise.reject(globalObject, jsvalue);
     }
 
@@ -162,6 +162,7 @@ const protocol = @import("./valkey_protocol.zig");
 const std = @import("std");
 
 const bun = @import("bun");
+const JSError = bun.JSError;
 const jsc = bun.jsc;
 const node = bun.api.node;
 const Slice = jsc.ZigString.Slice;
diff --git a/src/valkey/js_valkey.zig b/src/valkey/js_valkey.zig
index edde503c59..65406c5b48 100644
--- a/src/valkey/js_valkey.zig
+++ b/src/valkey/js_valkey.zig
@@ -1,9 +1,233 @@
+pub const SubscriptionCtx = struct {
+    const Self = @This();
+
+    // TODO(markovejnovic): Consider refactoring this to use @fieldParentPtr.
+    // The reason this was not implemented is that there is no support for
+    // optional fields yet.
+    //
+    // See: https://github.com/ziglang/zig/issues/25241
+    //
+    // An alternative is to hold a flag within the context itself, indicating
+    // whether it is active or not, but that feels less clean.
+    _parent: *JSValkeyClient,
+
+    original_enable_offline_queue: bool,
+    original_enable_auto_pipelining: bool,
+
+    const ParentJS = JSValkeyClient.js;
+
+    pub fn init(parent: *JSValkeyClient, enable_offline_queue: bool, enable_auto_pipelining: bool) bun.JSError!Self {
+        const callback_map = jsc.JSMap.create(parent.globalObject);
+        const parent_this = parent.this_value.tryGet() orelse unreachable;
+
+        ParentJS.gc.set(.subscriptionCallbackMap, parent_this, parent.globalObject, callback_map);
+
+        const self = Self{
+            ._parent = parent,
+            .original_enable_offline_queue = enable_offline_queue,
+            .original_enable_auto_pipelining = enable_auto_pipelining,
+        };
+        return self;
+    }
+
+    fn subscriptionCallbackMap(this: *Self) *jsc.JSMap {
+        const parent_this = this._parent.this_value.tryGet() orelse unreachable;
+
+        const value_js = ParentJS.gc.get(.subscriptionCallbackMap, parent_this).?;
+        return jsc.JSMap.fromJS(value_js).?;
+    }
+
+    /// Get the total number of channels that this subscription context is subscribed to.
+    pub fn channelsSubscribedToCount(this: *Self, globalObject: *jsc.JSGlobalObject) bun.JSError!u32 {
+        return this.subscriptionCallbackMap().size(globalObject);
+    }
+
+    /// Test whether this context has any subscriptions. Callers must use this
+    /// to guard `deinit`.
+    pub fn hasSubscriptions(this: *Self, globalObject: *jsc.JSGlobalObject) bun.JSError!bool {
+        return (try this.channelsSubscribedToCount(globalObject)) > 0;
+    }
+
+    pub fn clearReceiveHandlers(
+        this: *Self,
+        globalObject: *jsc.JSGlobalObject,
+        channelName: JSValue,
+    ) bun.JSError!void {
+        const map = this.subscriptionCallbackMap();
+        _ = try map.remove(globalObject, channelName);
+    }
+
+    /// Remove a specific receive handler.
+    ///
+    /// Returns: The total number of remaining handlers for this channel, or
+    /// null if there were no listeners originally
+ /// + /// Note: This function will empty out the map entry if there are no more handlers registered. + pub fn removeReceiveHandler( + this: *Self, + globalObject: *jsc.JSGlobalObject, + channelName: JSValue, + callback: JSValue, + ) !?usize { + const map = this.subscriptionCallbackMap(); + + const existing = try map.get(globalObject, channelName); + if (existing.isUndefinedOrNull()) { + // Nothing to remove. + return null; + } + + // Existing is guaranteed to be an array of callbacks. + // This check is necessary because crossing between Zig and C++ is necessary because Zig doesn't know that C++ + // is side-effect-free. + if (comptime bun.Environment.isDebug) { + bun.assert(existing.isArray()); + } + + // TODO(markovejnovic): I can't find a better way to do this... I generate a new array, + // filtering out the callback we want to remove. This is woefully inefficient for large + // sets (and surprisingly fast for small sets of callbacks). + // + // Perhaps there is an avenue to build a generic iterator pattern? @taylor.fish and I have + // briefly expressed a desire for this, and I promised her I would look into it, but at + // this moment have no proposal. + var array_it = try existing.arrayIterator(globalObject); + const updated_array = try jsc.JSArray.createEmpty(globalObject, 0); + while (try array_it.next()) |iter| { + if (iter == callback) + continue; + + try updated_array.push(globalObject, iter); + } + + // Otherwise, we have ourselves an array of callbacks. We need to remove the element in the + // array that matches the callback. + _ = try map.remove(globalObject, channelName); + + // Only populate the map if we have remaining callbacks for this channel. + const new_length = try updated_array.getLength(globalObject); + + if (new_length != 0) { + try map.set(globalObject, channelName, updated_array); + } + + return new_length; + } + + /// Add a handler for receiving messages on a specific channel + pub fn upsertReceiveHandler( + this: *Self, + globalObject: *jsc.JSGlobalObject, + channelName: JSValue, + callback: JSValue, + ) bun.JSError!void { + defer this._parent.onNewSubscriptionCallbackInsert(); + const map = this.subscriptionCallbackMap(); + + var handlers_array: JSValue = undefined; + var is_new_channel = false; + const existing_handler_arr = try map.get(globalObject, channelName); + if (existing_handler_arr != .js_undefined) { + debug("Adding a new receive handler.", .{}); + // Note that we need to cover this case because maps in JSC can return undefined when the key has never been + // set. 
+            if (existing_handler_arr.isUndefined()) {
+                // Create a new array if the existing_handler_arr is undefined/null
+                handlers_array = try jsc.JSArray.createEmpty(globalObject, 0);
+                is_new_channel = true;
+            } else if (existing_handler_arr.isArray()) {
+                // Use the existing array
+                handlers_array = existing_handler_arr;
+            } else unreachable;
+        } else {
+            // No existing_handler_arr exists, create a new array
+            handlers_array = try jsc.JSArray.createEmpty(globalObject, 0);
+            is_new_channel = true;
+        }
+
+        // Append the new callback to the array
+        try handlers_array.push(globalObject, callback);
+
+        // Set the updated array back in the map
+        try map.set(globalObject, channelName, handlers_array);
+    }
+
+    pub fn getCallbacks(this: *Self, globalObject: *jsc.JSGlobalObject, channelName: JSValue) bun.JSError!?JSValue {
+        const result = try this.subscriptionCallbackMap().get(globalObject, channelName);
+        if (result == .js_undefined) {
+            return null;
+        }
+
+        return result;
+    }
+
+    /// Invoke the callbacks registered for a channel with the given arguments.
+    /// Callbacks are stored as a per-channel array of listeners.
+    pub fn invokeCallbacks(
+        this: *Self,
+        globalObject: *jsc.JSGlobalObject,
+        channelName: JSValue,
+        args: []const JSValue,
+    ) bun.JSError!void {
+        const callbacks = try this.getCallbacks(globalObject, channelName) orelse {
+            debug("No callbacks found for channel {s}", .{channelName.asString().getZigString(globalObject)});
+            return;
+        };
+
+        if (comptime bun.Environment.isDebug) {
+            bun.assert(callbacks.isArray());
+        }
+
+        const vm = jsc.VirtualMachine.get();
+        const event_loop = vm.eventLoop();
+        event_loop.enter();
+        defer event_loop.exit();
+
+        // After we go through every single callback, we will have to update the poll ref.
+        // The user may, for example, unsubscribe in the callbacks, or even stop the client.
+        defer this._parent.updatePollRef();
+
+        // Iterate over the array of callbacks and invoke each one.
+        var iter = try callbacks.arrayIterator(globalObject);
+        while (try iter.next()) |callback| {
+            if (comptime bun.Environment.isDebug) {
+                bun.assert(callback.isCallable());
+            }
+
+            event_loop.runCallback(callback, globalObject, .js_undefined, args);
+        }
+    }
+
+    /// Return whether the subscription context is ready to be deleted by the JS garbage collector.
+    pub fn isDeletable(this: *Self, global_object: *jsc.JSGlobalObject) bun.JSError!bool {
+        // The user may request .close(), in which case we can dispose of the subscription object. If that is the case,
+        // finalized will be true. Otherwise, we should treat the object as disposable if there are no active
+        // subscriptions.
+        return this._parent.client.flags.finalized or !(try this.hasSubscriptions(global_object));
+    }
+
+    pub fn deinit(this: *Self, global_object: *jsc.JSGlobalObject) void {
+        // This check crosses the Zig/C++ boundary, so it is only performed in
+        // debug builds; Zig cannot know that the C++ call is side-effect-free.
+ if (comptime bun.Environment.isDebug) { + bun.debugAssert(this.isDeletable(this._parent.globalObject) catch unreachable); + } + + if (this._parent.this_value.tryGet()) |parent_this| { + ParentJS.gc.set(.subscriptionCallbackMap, parent_this, global_object, .js_undefined); + } + } +}; + /// Valkey client wrapper for JavaScript pub const JSValkeyClient = struct { client: valkey.ValkeyClient, globalObject: *jsc.JSGlobalObject, this_value: jsc.JSRef = jsc.JSRef.empty(), poll_ref: bun.Async.KeepAlive = .{}, + + _subscription_ctx: ?SubscriptionCtx, + timer: Timer.EventLoopTimer = .{ .tag = .ValkeyConnectionTimeout, .next = .{ @@ -31,11 +255,13 @@ pub const JSValkeyClient = struct { pub const new = bun.TrivialNew(@This()); // Factory function to create a new Valkey client from JS - pub fn constructor(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!*JSValkeyClient { - return try create(globalObject, callframe.arguments()); + pub fn constructor(globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame, js_this: JSValue) bun.JSError!*JSValkeyClient { + return try create(globalObject, callframe.arguments(), js_this); } - pub fn create(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue) bun.JSError!*JSValkeyClient { + pub fn createNoJs(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue) bun.JSError!*JSValkeyClient { + const this_allocator = bun.default_allocator; + const vm = globalObject.bunVM(); const url_str = if (arguments.len < 1 or arguments[0].isUndefined()) if (vm.transpiler.env.get("REDIS_URL") orelse vm.transpiler.env.get("VALKEY_URL")) |url| @@ -46,7 +272,7 @@ pub const JSValkeyClient = struct { try arguments[0].toBunString(globalObject); defer url_str.deref(); - const url_utf8 = url_str.toUTF8WithoutRef(bun.default_allocator); + const url_utf8 = url_str.toUTF8WithoutRef(this_allocator); defer url_utf8.deinit(); const url = bun.URL.parse(url_utf8.slice()); @@ -88,7 +314,7 @@ pub const JSValkeyClient = struct { var connection_strings: []u8 = &.{}; errdefer { - bun.default_allocator.free(connection_strings); + this_allocator.free(connection_strings); } if (url.username.len > 0 or url.password.len > 0 or hostname.len > 0) { @@ -96,19 +322,21 @@ pub const JSValkeyClient = struct { b.count(url.username); b.count(url.password); b.count(hostname); - try b.allocate(bun.default_allocator); + try b.allocate(this_allocator); + defer b.deinit(this_allocator); username = b.append(url.username); password = b.append(url.password); hostname = b.append(hostname); - connection_strings = b.allocatedSlice(); + b.moveToSlice(&connection_strings); } const database = if (url.pathname.len > 0) std.fmt.parseInt(u32, url.pathname[1..], 10) catch 0 else 0; bun.analytics.Features.valkey += 1; - return JSValkeyClient.new(.{ + const client = JSValkeyClient.new(.{ .ref_count = .init(), + ._subscription_ctx = null, .client = .{ .vm = vm, .address = switch (uri) { @@ -120,10 +348,11 @@ pub const JSValkeyClient = struct { }, }, }, + .protocol = uri, .username = username, .password = password, - .in_flight = .init(bun.default_allocator), - .queue = .init(bun.default_allocator), + .in_flight = .init(this_allocator), + .queue = .init(this_allocator), .status = .disconnected, .connection_strings = connection_strings, .socket = .{ @@ -134,7 +363,7 @@ pub const JSValkeyClient = struct { }, }, .database = database, - .allocator = bun.default_allocator, + .allocator = this_allocator, .flags = .{ .enable_auto_reconnect = options.enable_auto_reconnect, .enable_offline_queue = 
options.enable_offline_queue,
@@ -146,6 +375,130 @@ pub const JSValkeyClient = struct {
             },
             .globalObject = globalObject,
         });
+
+        return client;
+    }
+
+    pub fn create(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue, js_this: JSValue) bun.JSError!*JSValkeyClient {
+        var new_client = try JSValkeyClient.createNoJs(globalObject, arguments);
+
+        // Initially, we only need to hold a weak reference to the JS object.
+        new_client.this_value = jsc.JSRef.initWeak(js_this);
+        return new_client;
+    }
+
+    /// Clone this client while remaining in the initial disconnected state.
+    ///
+    /// Note that this does not create an object with an associated this_value.
+    /// You may need to populate it yourself.
+    pub fn cloneWithoutConnecting(
+        this: *const JSValkeyClient,
+        globalObject: *jsc.JSGlobalObject,
+    ) bun.OOM!*JSValkeyClient {
+        const vm = globalObject.bunVM();
+
+        // Make a copy of connection_strings to avoid double-free
+        const connection_strings_copy = try this.client.allocator.dupe(u8, this.client.connection_strings);
+
+        // There is no need to copy username, password, and the address
+        // separately: they are slices into the connection_strings buffer, so
+        // we simply rebase them onto the copy.
+        const base_ptr = this.client.connection_strings.ptr;
+        const new_base = connection_strings_copy.ptr;
+        const username = bun.memory.rebaseSlice(this.client.username, base_ptr, new_base);
+        const password = bun.memory.rebaseSlice(this.client.password, base_ptr, new_base);
+        const orig_hostname = this.client.address.hostname();
+        const hostname = bun.memory.rebaseSlice(orig_hostname, base_ptr, new_base);
+        const new_alloc = this.client.allocator;
+
+        return JSValkeyClient.new(.{
+            .ref_count = .init(),
+            ._subscription_ctx = null,
+            .client = .{
+                .vm = vm,
+                .address = switch (this.client.protocol) {
+                    .standalone_unix, .standalone_tls_unix => .{ .unix = hostname },
+                    else => .{
+                        .host = .{
+                            .host = hostname,
+                            .port = this.client.address.host.port,
+                        },
+                    },
+                },
+                .protocol = this.client.protocol,
+                .username = username,
+                .password = password,
+                .in_flight = .init(new_alloc),
+                .queue = .init(new_alloc),
+                .status = .disconnected,
+                .connection_strings = connection_strings_copy,
+                .socket = .{
+                    .SocketTCP = .{
+                        .socket = .{
+                            .detached = {},
+                        },
+                    },
+                },
+                .database = this.client.database,
+                .allocator = new_alloc,
+                .flags = .{
+                    // Because this starts in the disconnected state, we need to reset some flags.
+                    .is_authenticated = false,
+                    // If the user manually closed the original connection, the
+                    // duplicate is also treated as manually closed.
+                    .is_manually_closed = this.client.flags.is_manually_closed,
+                    .enable_offline_queue = if (this._subscription_ctx) |*ctx| ctx.original_enable_offline_queue else this.client.flags.enable_offline_queue,
+                    .needs_to_open_socket = true,
+                    .enable_auto_reconnect = this.client.flags.enable_auto_reconnect,
+                    .is_reconnecting = false,
+                    .auto_pipelining = if (this._subscription_ctx) |*ctx| ctx.original_enable_auto_pipelining else this.client.flags.auto_pipelining,
+                    // Duplicating a finalized client means it stays finalized.
+                    .finalized = this.client.flags.finalized,
+                },
+                .max_retries = this.client.max_retries,
+                .connection_timeout_ms = this.client.connection_timeout_ms,
+                .idle_timeout_interval_ms = this.client.idle_timeout_interval_ms,
+            },
+            .globalObject = globalObject,
+        });
+    }
+
+    pub fn getOrCreateSubscriptionCtxEnteringSubscriptionMode(
+        this: *JSValkeyClient,
+    ) bun.JSError!*SubscriptionCtx {
+        if (this._subscription_ctx) |*ctx| {
+            // If we already have a subscription context, return it
+            return ctx;
+        }
+
+        // Save the original flag values and create a new subscription context
+        this._subscription_ctx = try SubscriptionCtx.init(
+            this,
+            this.client.flags.enable_offline_queue,
+            this.client.flags.auto_pipelining,
+        );
+
+        // Subscription mode requires disabling both the offline queue and
+        // auto-pipelining.
+        this.client.flags.enable_offline_queue = false;
+        this.client.flags.auto_pipelining = false;
+
+        return &(this._subscription_ctx.?);
+    }
+
+    pub fn deleteSubscriptionCtx(this: *JSValkeyClient) void {
+        if (this._subscription_ctx) |*ctx| {
+            // Restore the original flag values when leaving subscription mode
+            this.client.flags.enable_offline_queue = ctx.original_enable_offline_queue;
+            this.client.flags.auto_pipelining = ctx.original_enable_auto_pipelining;
+
+            ctx.deinit(this.globalObject);
+        }
+
+        this._subscription_ctx = null;
+    }
+
+    pub fn isSubscriber(this: *const JSValkeyClient) bool {
+        return this._subscription_ctx != null;
     }
 
     pub fn getConnected(this: *JSValkeyClient, _: *jsc.JSGlobalObject) JSValue {
@@ -159,16 +512,22 @@ pub const JSValkeyClient = struct {
         return JSValue.jsNumber(len);
     }
 
-    pub fn doConnect(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, this_value: JSValue) bun.JSError!JSValue {
+    pub fn doConnect(
+        this: *JSValkeyClient,
+        globalObject: *jsc.JSGlobalObject,
+        this_value: JSValue,
+    ) bun.JSError!JSValue {
         this.ref();
         defer this.deref();
 
         // If already connected, resolve immediately
        if (this.client.status == .connected) {
+            debug("Connecting client is already connected.", .{});
             return jsc.JSPromise.resolvedPromiseValue(globalObject, js.helloGetCached(this_value) orelse .js_undefined);
         }
 
         if (js.connectionPromiseGetCached(this_value)) |promise| {
+            debug("Connection attempt already in progress; returning the cached promise.", .{});
             return promise;
         }
 
@@ -181,12 +540,16 @@ pub const JSValkeyClient = struct {
         this.this_value.setStrong(this_value, globalObject);
 
         if (this.client.flags.needs_to_open_socket) {
+            debug("Need to open socket, starting connection process.", .{});
             this.poll_ref.ref(this.client.vm);
             this.connect() catch |err| {
                 this.poll_ref.unref(this.client.vm);
                 this.client.flags.needs_to_open_socket = true;
                 const err_value = globalObject.ERR(.SOCKET_CLOSED_BEFORE_CONNECTION, " {s} connecting to Valkey", .{@errorName(err)}).toJS();
+                const event_loop = this.client.vm.eventLoop();
+                event_loop.enter();
+                defer event_loop.exit();
                 promise_ptr.reject(globalObject, err_value);
                 return promise;
             };
@@ -274,14 +637,14 @@ pub const JSValkeyClient = struct {
 
     /// Safely remove a timer with proper reference counting and event loop keepalive
     fn removeTimer(this: *JSValkeyClient, timer: *Timer.EventLoopTimer) void {
         if (timer.state == .ACTIVE) {
-            // Store VM reference to use later
             const vm = this.client.vm;
 
             // Remove the timer from the event loop
             vm.timer.remove(timer);
 
-            // Balance the ref from addTimer
+            // this.addTimer() adds a reference to 'this' while the timer is
+            // alive; that reference is balanced here.
this.deref(); } } @@ -391,11 +754,7 @@ pub const JSValkeyClient = struct { // Callback for when Valkey client connects pub fn onValkeyConnect(this: *JSValkeyClient, value: *protocol.RESPValue) void { - // Safety check to ensure a valid connection state - if (this.client.status != .connected) { - debug("onValkeyConnect called but client status is not 'connected': {s}", .{@tagName(this.client.status)}); - return; - } + bun.debugAssert(this.client.status == .connected); const globalObject = this.globalObject; const event_loop = this.client.vm.eventLoop(); @@ -414,7 +773,16 @@ pub const JSValkeyClient = struct { if (js.connectionPromiseGetCached(this_value)) |promise| { js.connectionPromiseSetCached(this_value, globalObject, .zero); - promise.asPromise().?.resolve(globalObject, hello_value); + const js_promise = promise.asPromise().?; + if (this.client.flags.connection_promise_returns_client) { + debug("Resolving connection promise with client instance", .{}); + const this_js = this.toJS(globalObject); + js_promise.resolve(globalObject, this_js); + } else { + debug("Resolving connection promise with HELLO response", .{}); + js_promise.resolve(globalObject, hello_value); + } + this.client.flags.connection_promise_returns_client = false; } } @@ -422,6 +790,97 @@ pub const JSValkeyClient = struct { this.updatePollRef(); } + /// Invoked when the Valkey client receives a new listener. + /// + /// `SubscriptionCtx` will invoke this to communicate that it has added a new listener. + pub fn onNewSubscriptionCallbackInsert(this: *JSValkeyClient) void { + this.ref(); + defer this.deref(); + + this.client.onWritable(); + this.updatePollRef(); + } + + pub fn onValkeySubscribe(this: *JSValkeyClient, value: *protocol.RESPValue) void { + bun.debugAssert(this.isSubscriber()); + bun.debugAssert(this.this_value.isStrong()); + + this.ref(); + defer this.deref(); + + _ = value; + + this.client.onWritable(); + this.updatePollRef(); + } + + pub fn onValkeyUnsubscribe(this: *JSValkeyClient) bun.JSError!void { + bun.debugAssert(this.isSubscriber()); + bun.debugAssert(this.this_value.isStrong()); + + this.ref(); + defer this.deref(); + + var subscription_ctx = this._subscription_ctx.?; + + // Check if we have any remaining subscriptions + // If the callback map is empty, we can exit subscription mode + + // If fetching the subscription count fails, the best we can do is + // bubble the error up. 
+ const has_subs = try subscription_ctx.hasSubscriptions(this.globalObject); + if (!has_subs) { + // No more subscriptions, exit subscription mode + this.deleteSubscriptionCtx(); + } + + this.client.onWritable(); + this.updatePollRef(); + } + + pub fn onValkeyMessage(this: *JSValkeyClient, value: []protocol.RESPValue) void { + if (!this.isSubscriber()) { + debug("onMessage called but client is not in subscriber mode", .{}); + return; + } + + const globalObject = this.globalObject; + const event_loop = this.client.vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); + + // The message push should be an array with [channel, message] + if (value.len < 2) { + debug("Message array has insufficient elements: {}", .{value.len}); + return; + } + + // Extract channel and message + const channel_value = value[0].toJS(globalObject) catch { + debug("Failed to convert channel to JS", .{}); + return; + }; + const message_value = value[1].toJS(globalObject) catch { + debug("Failed to convert message to JS", .{}); + return; + }; + + // Get the subscription context + const subs_ctx = &this._subscription_ctx.?; + + // Invoke callbacks for this channel with message and channel as arguments + subs_ctx.invokeCallbacks( + globalObject, + channel_value, + &[_]JSValue{ message_value, channel_value }, + ) catch { + return; + }; + + this.client.onWritable(); + this.updatePollRef(); + } + // Callback for when Valkey client needs to reconnect pub fn onValkeyReconnect(this: *JSValkeyClient) void { // Schedule reconnection using our safe timer methods @@ -468,6 +927,9 @@ pub const JSValkeyClient = struct { &[_]JSValue{error_value}, ) catch |e| globalObject.reportActiveExceptionAsUnhandled(e); } + + // Update poll reference to allow garbage collection of disconnected clients + this.updatePollRef(); } // Callback for when Valkey client times out @@ -495,19 +957,20 @@ pub const JSValkeyClient = struct { } pub fn finalize(this: *JSValkeyClient) void { - // Since this.stopTimers impacts the reference count potentially, we - // need to ref/unref here as well. this.ref(); defer this.deref(); this.stopTimers(); - this.this_value.deinit(); - if (this.client.status == .connected or this.client.status == .connecting) { - this.client.flags.is_manually_closed = true; - } + this.this_value.finalize(); this.client.flags.finalized = true; this.client.close(); - this.deref(); + + // We do not need to free the subscription context here because we're + // guaranteed to have freed it by virtue of the fact that we are + // garbage collected now and the subscription context holds a reference + // to us. If we still had a subscription context, we would never be + // garbage collected. 
+ bun.debugAssert(this._subscription_ctx == null); } pub fn stopTimers(this: *JSValkeyClient) void { @@ -521,6 +984,7 @@ pub const JSValkeyClient = struct { } fn connect(this: *JSValkeyClient) !void { + debug("Connecting to Redis.", .{}); this.client.flags.needs_to_open_socket = false; const vm = this.client.vm; @@ -578,6 +1042,9 @@ pub const JSValkeyClient = struct { this.client.flags.needs_to_open_socket = true; const err_value = globalThis.ERR(.SOCKET_CLOSED_BEFORE_CONNECTION, " {s} connecting to Valkey", .{@errorName(err)}).toJS(); const promise = jsc.JSPromise.create(globalThis); + const event_loop = this.client.vm.eventLoop(); + event_loop.enter(); + defer event_loop.exit(); promise.reject(globalThis, err_value); return promise; }; @@ -612,25 +1079,77 @@ pub const JSValkeyClient = struct { this.client.deinit(null); this.poll_ref.disable(); this.stopTimers(); - this.this_value.deinit(); + this.this_value.finalize(); this.ref_count.assertNoRefs(); bun.destroy(this); } /// Keep the event loop alive, or don't keep it alive + /// + /// This requires this_value to be alive. pub fn updatePollRef(this: *JSValkeyClient) void { - if (!this.client.hasAnyPendingCommands() and this.client.status == .connected) { - this.poll_ref.unref(this.client.vm); - // If we don't have any pending commands and we're connected, we don't need to keep the object alive. - if (this.this_value.tryGet()) |value| { - this.this_value.setWeak(value); - } - } else if (this.client.hasAnyPendingCommands()) { + // TODO(markovejnovic): This function is such a crazy cop out. We really + // should be treating valkey as a state machine, with well-defined + // state and modes in which it tracks and manages its own lifecycle. + // This is a mess beyond belief and it is incredibly fragile. + + const has_pending_commands = this.client.hasAnyPendingCommands(); + + // isDeletable may throw an exception, and if it does, we have to assume + // that the object still has references. Best we can do is hope nothing + // catastrophic happens. + const subs_deletable: bool = if (this._subscription_ctx) |*ctx| + ctx.isDeletable(this.globalObject) catch false + else + true; + + const has_activity = has_pending_commands or !subs_deletable; + + // There's a couple cases to handle here: + if (has_activity) { + // If we currently have pending activity, we need to keep the event + // loop alive. this.poll_ref.ref(this.client.vm); - // If we have pending commands, we need to keep the object alive. - if (this.this_value == .weak) { + } else { + // There is no pending activity so it is safe to remove the event + // loop. + this.poll_ref.unref(this.client.vm); + } + + if (this.this_value.isEmpty()) { + return; + } + + // Orthogonal to this, we need to manage the strong reference to the JS + // object. + switch (this.client.status) { + .connecting, .connected => { + // Whenever we're connected, we need to keep the object alive. + // + // TODO(markovejnovic): This is a leak. + // Note this is an intentional leak. Unless the user manually + // closes the connection, the object will stay alive forever, + // even if it falls out of scope. This is kind of stupid, since + // if the object is out of scope, and isn't subscribed upon, + // how exactly is the user going to call anything on the object? + // + // It is 100% safe to drop the strong reference there and let + // the object be GC'd, but we're not doing that now. 
this.this_value.upgrade(this.globalObject); - } + }, + .disconnected, .failed => { + // If we're disconnected or failed, we need to check if we have + // any pending activity. + if (has_activity) { + // If we have pending activity, we need to keep the object + // alive. + this.this_value.upgrade(this.globalObject); + } else { + // If we don't have any pending activity, we can drop the + // strong reference. + this.this_value.downgrade(); + } + }, } } @@ -641,6 +1160,7 @@ pub const JSValkeyClient = struct { pub const decr = fns.decr; pub const del = fns.del; pub const dump = fns.dump; + pub const duplicate = fns.duplicate; pub const exists = fns.exists; pub const expire = fns.expire; pub const expiretime = fns.expiretime; @@ -756,18 +1276,21 @@ fn SocketHandler(comptime ssl: bool) type { pub const onHandshake = if (ssl) onHandshake_ else null; pub fn onClose(this: *JSValkeyClient, _: SocketType, _: i32, _: ?*anyopaque) void { + // No need to deref since this.client.onClose() invokes onValkeyClose which does the deref. + + debug("Socket closed.", .{}); + // Ensure the socket pointer is updated. this.client.socket = .{ .SocketTCP = .detached }; this.client.onClose(); + this.updatePollRef(); } pub fn onEnd(this: *JSValkeyClient, socket: SocketType) void { - // Ensure the socket pointer is updated before closing - this.client.socket = _socket(socket); - - // Do not allow half-open connections - socket.close(.normal); + _ = this; + _ = socket; + // Half-opened sockets are not allowed. } pub fn onConnectError(this: *JSValkeyClient, _: SocketType, _: i32) void { @@ -778,6 +1301,8 @@ fn SocketHandler(comptime ssl: bool) type { } pub fn onTimeout(this: *JSValkeyClient, socket: SocketType) void { + debug("Socket timed out.", .{}); + this.client.socket = _socket(socket); // Handle socket timeout } diff --git a/src/valkey/js_valkey_functions.zig b/src/valkey/js_valkey_functions.zig index 4b6d2daad6..9361e404cf 100644 --- a/src/valkey/js_valkey_functions.zig +++ b/src/valkey/js_valkey_functions.zig @@ -1,3 +1,19 @@ +fn requireNotSubscriber(this: *JSValkeyClient, function_name: []const u8) bun.JSError!void { + const fmt_string = "RedisClient.prototype.{s} cannot be called while in subscriber mode."; + + if (this.isSubscriber()) { + return this.globalObject.ERR(.REDIS_INVALID_STATE, fmt_string, .{function_name}).throw(); + } +} + +fn requireSubscriber(this: *JSValkeyClient, function_name: []const u8) bun.JSError!void { + const fmt_string = "RedisClient.prototype.{s} can only be called while in subscriber mode."; + + if (!this.isSubscriber()) { + return this.globalObject.ERR(.REDIS_INVALID_STATE, fmt_string, .{function_name}).throw(); + } +} + pub fn jsSend(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { const command = try callframe.argument(0).toBunString(globalObject); defer command.deref(); @@ -41,6 +57,8 @@ pub fn jsSend(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfram } pub fn get(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("get", "key", "string or buffer"); }; @@ -61,6 +79,8 @@ pub fn get(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn getBuffer(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try 
requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("getBuffer", "key", "string or buffer"); }; @@ -81,6 +101,8 @@ pub fn getBuffer(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callf } pub fn set(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const args_view = callframe.arguments(); var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); var args = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), args_view.len); @@ -127,6 +149,8 @@ pub fn set(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn incr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("incr", "key", "string or buffer"); }; @@ -147,6 +171,8 @@ pub fn incr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn decr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("decr", "key", "string or buffer"); }; @@ -167,6 +193,8 @@ pub fn decr(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: } pub fn exists(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("exists", "key", "string or buffer"); }; @@ -188,6 +216,8 @@ pub fn exists(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfram } pub fn expire(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("expire", "key", "string or buffer"); }; @@ -219,6 +249,8 @@ pub fn expire(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfram } pub fn ttl(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("ttl", "key", "string or buffer"); }; @@ -240,6 +272,8 @@ pub fn ttl(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement srem (remove value from a set) pub fn srem(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("srem", "key", "string or buffer"); }; @@ -265,6 +299,8 @@ pub fn srem(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement srandmember (get random member from set) pub fn srandmember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: 
*jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("srandmember", "key", "string or buffer"); }; @@ -286,6 +322,8 @@ pub fn srandmember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, cal // Implement smembers (get all members of a set) pub fn smembers(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("smembers", "key", "string or buffer"); }; @@ -307,6 +345,8 @@ pub fn smembers(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfr // Implement spop (pop a random member from a set) pub fn spop(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("spop", "key", "string or buffer"); }; @@ -328,6 +368,8 @@ pub fn spop(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement sadd (add member to a set) pub fn sadd(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("sadd", "key", "string or buffer"); }; @@ -353,6 +395,8 @@ pub fn sadd(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: // Implement sismember (check if value is member of a set) pub fn sismember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("sismember", "key", "string or buffer"); }; @@ -379,6 +423,8 @@ pub fn sismember(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callf // Implement hmget (get multiple values from hash) pub fn hmget(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType("hmget", "key", "string or buffer"); }; @@ -426,6 +472,8 @@ pub fn hmget(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe // Implement hincrby (increment hash field by integer value) pub fn hincrby(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = try callframe.argument(0).toBunString(globalObject); defer key.deref(); const field = try callframe.argument(1).toBunString(globalObject); @@ -456,6 +504,8 @@ pub fn hincrby(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callfra // Implement hincrbyfloat (increment hash field by float value) pub fn hincrbyfloat(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = try callframe.argument(0).toBunString(globalObject); 
defer key.deref(); const field = try callframe.argument(1).toBunString(globalObject); @@ -486,6 +536,8 @@ pub fn hincrbyfloat(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, ca // Implement hmset (set multiple values in hash) pub fn hmset(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + const key = try callframe.argument(0).toBunString(globalObject); defer key.deref(); @@ -578,60 +630,346 @@ pub fn ping(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: return promise.toJS(); } -pub const bitcount = compile.@"(key: RedisKey)"("bitcount", "BITCOUNT", "key").call; -pub const dump = compile.@"(key: RedisKey)"("dump", "DUMP", "key").call; -pub const expiretime = compile.@"(key: RedisKey)"("expiretime", "EXPIRETIME", "key").call; -pub const getdel = compile.@"(key: RedisKey)"("getdel", "GETDEL", "key").call; -pub const getex = compile.@"(...strings: string[])"("getex", "GETEX").call; -pub const hgetall = compile.@"(key: RedisKey)"("hgetall", "HGETALL", "key").call; -pub const hkeys = compile.@"(key: RedisKey)"("hkeys", "HKEYS", "key").call; -pub const hlen = compile.@"(key: RedisKey)"("hlen", "HLEN", "key").call; -pub const hvals = compile.@"(key: RedisKey)"("hvals", "HVALS", "key").call; -pub const keys = compile.@"(key: RedisKey)"("keys", "KEYS", "key").call; -pub const llen = compile.@"(key: RedisKey)"("llen", "LLEN", "key").call; -pub const lpop = compile.@"(key: RedisKey)"("lpop", "LPOP", "key").call; -pub const persist = compile.@"(key: RedisKey)"("persist", "PERSIST", "key").call; -pub const pexpiretime = compile.@"(key: RedisKey)"("pexpiretime", "PEXPIRETIME", "key").call; -pub const pttl = compile.@"(key: RedisKey)"("pttl", "PTTL", "key").call; -pub const rpop = compile.@"(key: RedisKey)"("rpop", "RPOP", "key").call; -pub const scard = compile.@"(key: RedisKey)"("scard", "SCARD", "key").call; -pub const strlen = compile.@"(key: RedisKey)"("strlen", "STRLEN", "key").call; -pub const @"type" = compile.@"(key: RedisKey)"("type", "TYPE", "key").call; -pub const zcard = compile.@"(key: RedisKey)"("zcard", "ZCARD", "key").call; -pub const zpopmax = compile.@"(key: RedisKey)"("zpopmax", "ZPOPMAX", "key").call; -pub const zpopmin = compile.@"(key: RedisKey)"("zpopmin", "ZPOPMIN", "key").call; -pub const zrandmember = compile.@"(key: RedisKey)"("zrandmember", "ZRANDMEMBER", "key").call; +pub const bitcount = compile.@"(key: RedisKey)"("bitcount", "BITCOUNT", "key", .not_subscriber).call; +pub const dump = compile.@"(key: RedisKey)"("dump", "DUMP", "key", .not_subscriber).call; +pub const expiretime = compile.@"(key: RedisKey)"("expiretime", "EXPIRETIME", "key", .not_subscriber).call; +pub const getdel = compile.@"(key: RedisKey)"("getdel", "GETDEL", "key", .not_subscriber).call; +pub const getex = compile.@"(...strings: string[])"("getex", "GETEX", .not_subscriber).call; +pub const hgetall = compile.@"(key: RedisKey)"("hgetall", "HGETALL", "key", .not_subscriber).call; +pub const hkeys = compile.@"(key: RedisKey)"("hkeys", "HKEYS", "key", .not_subscriber).call; +pub const hlen = compile.@"(key: RedisKey)"("hlen", "HLEN", "key", .not_subscriber).call; +pub const hvals = compile.@"(key: RedisKey)"("hvals", "HVALS", "key", .not_subscriber).call; +pub const keys = compile.@"(key: RedisKey)"("keys", "KEYS", "key", .not_subscriber).call; +pub const llen = compile.@"(key: RedisKey)"("llen", "LLEN", "key", .not_subscriber).call; +pub const lpop = compile.@"(key: 
RedisKey)"("lpop", "LPOP", "key", .not_subscriber).call; +pub const persist = compile.@"(key: RedisKey)"("persist", "PERSIST", "key", .not_subscriber).call; +pub const pexpiretime = compile.@"(key: RedisKey)"("pexpiretime", "PEXPIRETIME", "key", .not_subscriber).call; +pub const pttl = compile.@"(key: RedisKey)"("pttl", "PTTL", "key", .not_subscriber).call; +pub const rpop = compile.@"(key: RedisKey)"("rpop", "RPOP", "key", .not_subscriber).call; +pub const scard = compile.@"(key: RedisKey)"("scard", "SCARD", "key", .not_subscriber).call; +pub const strlen = compile.@"(key: RedisKey)"("strlen", "STRLEN", "key", .not_subscriber).call; +pub const @"type" = compile.@"(key: RedisKey)"("type", "TYPE", "key", .not_subscriber).call; +pub const zcard = compile.@"(key: RedisKey)"("zcard", "ZCARD", "key", .not_subscriber).call; +pub const zpopmax = compile.@"(key: RedisKey)"("zpopmax", "ZPOPMAX", "key", .not_subscriber).call; +pub const zpopmin = compile.@"(key: RedisKey)"("zpopmin", "ZPOPMIN", "key", .not_subscriber).call; +pub const zrandmember = compile.@"(key: RedisKey)"("zrandmember", "ZRANDMEMBER", "key", .not_subscriber).call; -pub const append = compile.@"(key: RedisKey, value: RedisValue)"("append", "APPEND", "key", "value").call; -pub const getset = compile.@"(key: RedisKey, value: RedisValue)"("getset", "GETSET", "key", "value").call; -pub const hget = compile.@"(key: RedisKey, value: RedisValue)"("hget", "HGET", "key", "field").call; -pub const lpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpush", "LPUSH").call; -pub const lpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpushx", "LPUSHX").call; -pub const pfadd = compile.@"(key: RedisKey, value: RedisValue)"("pfadd", "PFADD", "key", "value").call; -pub const rpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpush", "RPUSH").call; -pub const rpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpushx", "RPUSHX").call; -pub const setnx = compile.@"(key: RedisKey, value: RedisValue)"("setnx", "SETNX", "key", "value").call; -pub const zscore = compile.@"(key: RedisKey, value: RedisValue)"("zscore", "ZSCORE", "key", "value").call; +pub const append = compile.@"(key: RedisKey, value: RedisValue)"("append", "APPEND", "key", "value", .not_subscriber).call; +pub const getset = compile.@"(key: RedisKey, value: RedisValue)"("getset", "GETSET", "key", "value", .not_subscriber).call; +pub const hget = compile.@"(key: RedisKey, value: RedisValue)"("hget", "HGET", "key", "field", .not_subscriber).call; +pub const lpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpush", "LPUSH", .not_subscriber).call; +pub const lpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("lpushx", "LPUSHX", .not_subscriber).call; +pub const pfadd = compile.@"(key: RedisKey, value: RedisValue)"("pfadd", "PFADD", "key", "value", .not_subscriber).call; +pub const rpush = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpush", "RPUSH", .not_subscriber).call; +pub const rpushx = compile.@"(key: RedisKey, value: RedisValue, ...args: RedisValue)"("rpushx", "RPUSHX", .not_subscriber).call; +pub const setnx = compile.@"(key: RedisKey, value: RedisValue)"("setnx", "SETNX", "key", "value", .not_subscriber).call; +pub const zscore = compile.@"(key: RedisKey, value: RedisValue)"("zscore", "ZSCORE", "key", "value", .not_subscriber).call; -pub const del = compile.@"(key: RedisKey, ...args: RedisKey[])"("del", "DEL", 
"key").call; -pub const mget = compile.@"(key: RedisKey, ...args: RedisKey[])"("mget", "MGET", "key").call; +pub const del = compile.@"(key: RedisKey, ...args: RedisKey[])"("del", "DEL", "key", .not_subscriber).call; +pub const mget = compile.@"(key: RedisKey, ...args: RedisKey[])"("mget", "MGET", "key", .not_subscriber).call; -pub const publish = compile.@"(...strings: string[])"("publish", "PUBLISH").call; -pub const script = compile.@"(...strings: string[])"("script", "SCRIPT").call; -pub const select = compile.@"(...strings: string[])"("select", "SELECT").call; -pub const spublish = compile.@"(...strings: string[])"("spublish", "SPUBLISH").call; -pub const smove = compile.@"(...strings: string[])"("smove", "SMOVE").call; -pub const substr = compile.@"(...strings: string[])"("substr", "SUBSTR").call; -pub const hstrlen = compile.@"(...strings: string[])"("hstrlen", "HSTRLEN").call; -pub const zrank = compile.@"(...strings: string[])"("zrank", "ZRANK").call; -pub const zrevrank = compile.@"(...strings: string[])"("zrevrank", "ZREVRANK").call; -pub const subscribe = compile.@"(...strings: string[])"("subscribe", "SUBSCRIBE").call; -pub const psubscribe = compile.@"(...strings: string[])"("psubscribe", "PSUBSCRIBE").call; -pub const unsubscribe = compile.@"(...strings: string[])"("unsubscribe", "UNSUBSCRIBE").call; -pub const punsubscribe = compile.@"(...strings: string[])"("punsubscribe", "PUNSUBSCRIBE").call; -pub const pubsub = compile.@"(...strings: string[])"("pubsub", "PUBSUB").call; +pub const script = compile.@"(...strings: string[])"("script", "SCRIPT", .not_subscriber).call; +pub const select = compile.@"(...strings: string[])"("select", "SELECT", .not_subscriber).call; +pub const spublish = compile.@"(...strings: string[])"("spublish", "SPUBLISH", .not_subscriber).call; +pub const smove = compile.@"(...strings: string[])"("smove", "SMOVE", .not_subscriber).call; +pub const substr = compile.@"(...strings: string[])"("substr", "SUBSTR", .not_subscriber).call; +pub const hstrlen = compile.@"(...strings: string[])"("hstrlen", "HSTRLEN", .not_subscriber).call; +pub const zrank = compile.@"(...strings: string[])"("zrank", "ZRANK", .not_subscriber).call; +pub const zrevrank = compile.@"(...strings: string[])"("zrevrank", "ZREVRANK", .not_subscriber).call; +pub const psubscribe = compile.@"(...strings: string[])"("psubscribe", "PSUBSCRIBE", .dont_care).call; +pub const punsubscribe = compile.@"(...strings: string[])"("punsubscribe", "PUNSUBSCRIBE", .dont_care).call; +pub const pubsub = compile.@"(...strings: string[])"("pubsub", "PUBSUB", .dont_care).call; + +pub fn publish( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + callframe: *jsc.CallFrame, +) bun.JSError!JSValue { + try requireNotSubscriber(this, @src().fn_name); + + const args_view = callframe.arguments(); + var stack_fallback = std.heap.stackFallback(512, bun.default_allocator); + var args = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), args_view.len); + defer { + for (args.items) |*item| { + item.deinit(); + } + args.deinit(); + } + + const arg0 = callframe.argument(0); + if (!arg0.isString()) { + return globalObject.throwInvalidArgumentType("publish", "channel", "string"); + } + const channel = (try fromJS(globalObject, arg0)) orelse unreachable; + + args.appendAssumeCapacity(channel); + + const arg1 = callframe.argument(1); + if (!arg1.isString()) { + return globalObject.throwInvalidArgumentType("publish", "message", "string"); + } + const message = (try fromJS(globalObject, arg1)) orelse 
unreachable;
+ args.appendAssumeCapacity(message);
+
+ const promise = this.send(
+ globalObject,
+ callframe.this(),
+ &.{
+ .command = "PUBLISH",
+ .args = .{ .args = args.items },
+ },
+ ) catch |err| {
+ return protocol.valkeyErrorToJS(globalObject, "Failed to send PUBLISH command", err);
+ };
+
+ return promise.toJS();
+}
+
+pub fn subscribe(
+ this: *JSValkeyClient,
+ globalObject: *jsc.JSGlobalObject,
+ callframe: *jsc.CallFrame,
+) bun.JSError!JSValue {
+ const channel_or_many, const handler_callback = callframe.argumentsAsArray(2);
+ var stack_fallback = std.heap.stackFallback(512, bun.default_allocator);
+ var redis_channels = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), 1);
+ defer {
+ for (redis_channels.items) |*item| {
+ item.deinit();
+ }
+ redis_channels.deinit();
+ }
+
+ if (!handler_callback.isCallable()) {
+ return globalObject.throwInvalidArgumentType("subscribe", "listener", "function");
+ }
+
+ // We now need to register the callback with our subscription context, which may or may not exist.
+ var subscription_ctx = try this.getOrCreateSubscriptionCtxEnteringSubscriptionMode();
+
+ // The first argument is either a single channel or an array of channels.
+ if (channel_or_many.isArray()) {
+ if ((try channel_or_many.getLength(globalObject)) == 0) {
+ return globalObject.throwInvalidArguments("subscribe requires at least one channel", .{});
+ }
+ try redis_channels.ensureTotalCapacity(try channel_or_many.getLength(globalObject));
+
+ var array_iter = try channel_or_many.arrayIterator(globalObject);
+ while (try array_iter.next()) |channel_arg| {
+ const channel = (try fromJS(globalObject, channel_arg)) orelse {
+ return globalObject.throwInvalidArgumentType("subscribe", "channel", "string");
+ };
+ redis_channels.appendAssumeCapacity(channel);
+
+ try subscription_ctx.upsertReceiveHandler(globalObject, channel_arg, handler_callback);
+ }
+ } else if (channel_or_many.isString()) {
+ // It is a single string channel
+ const channel = (try fromJS(globalObject, channel_or_many)) orelse {
+ return globalObject.throwInvalidArgumentType("subscribe", "channel", "string");
+ };
+ redis_channels.appendAssumeCapacity(channel);
+
+ try subscription_ctx.upsertReceiveHandler(globalObject, channel_or_many, handler_callback);
+ } else {
+ return globalObject.throwInvalidArgumentType("subscribe", "channel", "string or array");
+ }
+
+ const command: valkey.Command = .{
+ .command = "SUBSCRIBE",
+ .args = .{ .args = redis_channels.items },
+ };
+ const promise = this.send(
+ globalObject,
+ callframe.this(),
+ &command,
+ ) catch |err| {
+ // If we find an error, we need to clean up the subscription context.
+ this.deleteSubscriptionCtx();
+ return protocol.valkeyErrorToJS(globalObject, "Failed to send SUBSCRIBE command", err);
+ };
+
+ return promise.toJS();
+}
+
+/// Send Redis the UNSUBSCRIBE RESP command and clean up anything necessary after the unsubscribe command.
+///
+/// The subscription context must exist when calling this function.
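As a usage sketch (not part of this patch; the URL is a placeholder), the JS-facing `subscribe` added above accepts either a single channel string or an array of channels, and invokes the listener as `(message, channel)`, matching the argument array built in `onValkeyMessage`:

```ts
import { RedisClient } from "bun";

const sub = new RedisClient("redis://localhost:6379"); // placeholder URL
await sub.connect();

// Single channel: the listener receives (message, channel).
await sub.subscribe("news", (message, channel) => {
  console.log(`${channel}: ${message}`);
});

// Array form: every listed channel is registered with the same listener.
await sub.subscribe(["alerts", "metrics"], (message, channel) => {
  console.log(`${channel}: ${message}`);
});
```

Once a client is in subscriber mode, data commands such as `get` and `set` throw a `REDIS_INVALID_STATE` error via `requireNotSubscriber` above.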
+fn sendUnsubscribeRequestAndCleanup(
+ this: *JSValkeyClient,
+ this_js: jsc.JSValue,
+ globalObject: *jsc.JSGlobalObject,
+ redis_channels: []JSArgument,
+) !jsc.JSValue {
+ // Send UNSUBSCRIBE command
+ const command: valkey.Command = .{
+ .command = "UNSUBSCRIBE",
+ .args = .{ .args = redis_channels },
+ };
+ const promise = this.send(
+ globalObject,
+ this_js,
+ &command,
+ ) catch |err| {
+ return protocol.valkeyErrorToJS(globalObject, "Failed to send UNSUBSCRIBE command", err);
+ };
+
+ // We do not delete the subscription context here, but rather when the
+ // onValkeyUnsubscribe callback is invoked.
+
+ return promise.toJS();
+}
+
+pub fn unsubscribe(
+ this: *JSValkeyClient,
+ globalObject: *jsc.JSGlobalObject,
+ callframe: *jsc.CallFrame,
+) bun.JSError!JSValue {
+ // Check if we're in subscription mode
+ try requireSubscriber(this, @src().fn_name);
+
+ const args_view = callframe.arguments();
+
+ var stack_fallback = std.heap.stackFallback(512, bun.default_allocator);
+ var redis_channels = try std.ArrayList(JSArgument).initCapacity(stack_fallback.get(), 1);
+ defer {
+ for (redis_channels.items) |*item| {
+ item.deinit();
+ }
+ redis_channels.deinit();
+ }
+
+ // If no arguments, unsubscribe from all channels
+ if (args_view.len == 0) {
+ return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items);
+ }
+
+ // The first argument can be a channel or an array of channels
+ const channel_or_many = callframe.argument(0);
+
+ // Get the subscription context
+ var subscription_ctx = this._subscription_ctx orelse {
+ return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined);
+ };
+
+ // Two arguments means .unsubscribe(channel, listener) is invoked.
+ if (callframe.arguments().len == 2) {
+ // In this case, the first argument is a channel string and the second
+ // argument is the handler to remove.
+ if (!channel_or_many.isString()) {
+ return globalObject.throwInvalidArgumentType(
+ "unsubscribe",
+ "channel",
+ "string",
+ );
+ }
+
+ const channel = channel_or_many;
+ const listener_cb = callframe.argument(1);
+
+ if (!listener_cb.isCallable()) {
+ return globalObject.throwInvalidArgumentType(
+ "unsubscribe",
+ "listener",
+ "function",
+ );
+ }
+
+ // Populate the redis_channels list with the single channel to
+ // unsubscribe from. This is important since this list is used to send
+ // the UNSUBSCRIBE command to redis. Without this, we would end up
+ // unsubscribing from all channels.
+ redis_channels.appendAssumeCapacity((try fromJS(globalObject, channel)) orelse {
+ return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string");
+ });
+
+ const remaining_listeners = subscription_ctx.removeReceiveHandler(
+ globalObject,
+ channel,
+ listener_cb,
+ ) catch {
+ return globalObject.throw(
+ "Failed to remove handler for channel {}",
+ .{channel.asString().getZigString(globalObject)},
+ );
+ } orelse {
+ // Listeners weren't present in the first place, so we can return a
+ // resolved promise.
+ return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined);
+ };
+
+ // In this case, we only want to send the unsubscribe command to redis if there are no more listeners for this
+ // channel.
+ if (remaining_listeners == 0) {
+ return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items);
+ }
+
+ // Otherwise, in order to keep the API consistent, we need to return a resolved promise.
+ return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined); + } + + if (channel_or_many.isArray()) { + if ((try channel_or_many.getLength(globalObject)) == 0) { + return globalObject.throwInvalidArguments( + "unsubscribe requires at least one channel", + .{}, + ); + } + + try redis_channels.ensureTotalCapacity(try channel_or_many.getLength(globalObject)); + // It is an array, so let's iterate over it + var array_iter = try channel_or_many.arrayIterator(globalObject); + while (try array_iter.next()) |channel_arg| { + const channel = (try fromJS(globalObject, channel_arg)) orelse { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + // Clear the handlers for this channel + try subscription_ctx.clearReceiveHandlers(globalObject, channel_arg); + } + } else if (channel_or_many.isString()) { + // It is a single string channel + const channel = (try fromJS(globalObject, channel_or_many)) orelse { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string"); + }; + redis_channels.appendAssumeCapacity(channel); + // Clear the handlers for this channel + try subscription_ctx.clearReceiveHandlers(globalObject, channel_or_many); + } else { + return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string or array"); + } + + // Now send the unsubscribe command and clean up if necessary + return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items); +} + +pub fn duplicate( + this: *JSValkeyClient, + globalObject: *jsc.JSGlobalObject, + callframe: *jsc.CallFrame, +) bun.JSError!JSValue { + _ = callframe; + + var new_client: *JSValkeyClient = try this.cloneWithoutConnecting(globalObject); + + const new_client_js = new_client.toJS(globalObject); + new_client.this_value = + if (this.client.status == .connected and !this.client.flags.is_manually_closed) + jsc.JSRef.initStrong(new_client_js, globalObject) + else + jsc.JSRef.initWeak(new_client_js); + + // If the original client is already connected and not manually closed, start connecting the new client. + if (this.client.status == .connected and !this.client.flags.is_manually_closed) { + // Use strong reference during connection to prevent premature GC + new_client.client.flags.connection_promise_returns_client = true; + return try new_client.doConnect(globalObject, new_client_js); + } + + return jsc.JSPromise.resolvedPromiseValue(globalObject, new_client_js); +} -// publish(channel: RedisValue, message: RedisValue) // script(subcommand: "LOAD", script: RedisValue) // select(index: number | string) // spublish(shardchannel: RedisValue, message: RedisValue) @@ -645,13 +983,37 @@ pub const pubsub = compile.@"(...strings: string[])"("pubsub", "PUBSUB").call; // cluster(subcommand: "KEYSLOT", key: RedisKey) const compile = struct { + pub const ClientStateRequirement = enum { + /// The client must be a subscriber (in subscription mode). + subscriber, + /// The client must not be a subscriber (not in subscription mode). + not_subscriber, + /// We don't care about the client state (subscriber or not). 
+ dont_care, + }; + + fn testCorrectState( + this: *JSValkeyClient, + js_client_prototype_function_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, + ) bun.JSError!void { + return switch (client_state_requirement) { + .subscriber => requireSubscriber(this, js_client_prototype_function_name), + .not_subscriber => requireNotSubscriber(this, js_client_prototype_function_name), + .dont_care => {}, + }; + } + pub fn @"(key: RedisKey)"( comptime name: []const u8, comptime command: []const u8, comptime arg0_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType(name, arg0_name, "string or buffer"); }; @@ -676,9 +1038,12 @@ const compile = struct { comptime name: []const u8, comptime command: []const u8, comptime arg0_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + if (callframe.argument(0).isUndefinedOrNull()) { return globalObject.throwMissingArgumentsValue(&.{arg0_name}); } @@ -722,9 +1087,12 @@ const compile = struct { comptime command: []const u8, comptime arg0_name: []const u8, comptime arg1_name: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + const key = (try fromJS(globalObject, callframe.argument(0))) orelse { return globalObject.throwInvalidArgumentType(name, arg0_name, "string or buffer"); }; @@ -752,9 +1120,12 @@ const compile = struct { pub fn @"(...strings: string[])"( comptime name: []const u8, comptime command: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + var args = try std.ArrayList(JSArgument).initCapacity(bun.default_allocator, callframe.arguments().len); defer { for (args.items) |*item| { @@ -788,9 +1159,12 @@ const compile = struct { pub fn @"(key: RedisKey, value: RedisValue, ...args: RedisValue)"( comptime name: []const u8, comptime command: []const u8, + comptime client_state_requirement: ClientStateRequirement, ) type { return struct { pub fn call(this: *JSValkeyClient, globalObject: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) bun.JSError!JSValue { + try testCorrectState(this, name, client_state_requirement); + var args = try std.ArrayList(JSArgument).initCapacity(bun.default_allocator, callframe.arguments().len); defer { for (args.items) |*item| { diff --git a/src/valkey/valkey.zig b/src/valkey/valkey.zig index 87b9cea495..f2c22bb2c6 100644 --- a/src/valkey/valkey.zig +++ b/src/valkey/valkey.zig @@ -18,6 +18,16 @@ pub const ConnectionFlags = struct { is_reconnecting: bool = false, auto_pipelining: bool = true, finalized: bool = false, + // This flag is a slight hack to allow returning the client 
instance in the + // promise which resolves when the connection is established. There are two + // modes through which a client may connect: + // 1. Connect through `client.connect()` which has the semantics of + // resolving the promise with the connection information. + // 2. Through `client.duplicate()` which creates a promise through + // `onConnect()` which resolves with the client instance itself. + // This flag is set to true in the latter case to indicate to the promise + // resolution delegation to resolve the promise with the client. + connection_promise_returns_client: bool = false, }; /// Valkey connection status @@ -28,6 +38,13 @@ pub const Status = enum { failed, }; +pub fn isActive(this: *const Status) bool { + return switch (this.*) { + .connected, .connecting => true, + else => false, + }; +} + pub const Command = @import("./ValkeyCommand.zig"); /// Valkey protocol types (standalone, TLS, Unix socket) @@ -106,6 +123,13 @@ pub const Address = union(enum) { port: u16, }, + pub fn hostname(this: *const Address) []const u8 { + return switch (this.*) { + .unix => |unix_addr| unix_addr, + .host => |h| h.host, + }; + } + pub fn connect(this: *const Address, client: *ValkeyClient, ctx: *bun.uws.SocketContext, is_tls: bool) !uws.AnySocket { switch (is_tls) { inline else => |tls| { @@ -155,6 +179,7 @@ pub const ValkeyClient = struct { username: []const u8 = "", database: u32 = 0, address: Address, + protocol: Protocol, connection_strings: []u8 = &.{}, @@ -211,6 +236,8 @@ pub const ValkeyClient = struct { } this.allocator.free(this.connection_strings); + // Note there is no need to deallocate username, password and hostname since they are + // within the this.connection_strings buffer. this.write_buffer.deinit(this.allocator); this.read_buffer.deinit(this.allocator); this.tls.deinit(); @@ -383,14 +410,14 @@ pub const ValkeyClient = struct { /// Mark the connection as failed with error message pub fn fail(this: *ValkeyClient, message: []const u8, err: protocol.RedisError) void { - debug("failed: {s}: {s}", .{ message, @errorName(err) }); + debug("failed: {s}: {}", .{ message, err }); if (this.status == .failed) return; if (this.flags.finalized) { // We can't run promises inside finalizers. if (this.queue.count + this.in_flight.count > 0) { const vm = this.vm; - const deferred_failrue = bun.new(DeferredFailure, .{ + const deferred_failure = bun.new(DeferredFailure, .{ // This memory is not owned by us. .message = bun.handleOom(bun.default_allocator.dupe(u8, message)), @@ -401,7 +428,7 @@ pub const ValkeyClient = struct { }); this.in_flight = .init(this.allocator); this.queue = .init(this.allocator); - deferred_failrue.enqueue(); + deferred_failure.enqueue(); } // Allow the finalizer to call .close() @@ -504,9 +531,9 @@ pub const ValkeyClient = struct { } /// Process data received from socket + /// + /// Caller refs / derefs. pub fn onData(this: *ValkeyClient, data: []const u8) void { - // Caller refs / derefs. - // Path 1: Buffer already has data, append and process from buffer if (this.read_buffer.remaining().len > 0) { this.read_buffer.write(this.allocator, data) catch @panic("failed to write to read buffer"); @@ -613,6 +640,69 @@ pub const ValkeyClient = struct { // If the loop finishes, the entire 'data' was processed without needing the buffer. } + /// Try handling this response as a subscriber-state response. + /// Returns `handled` if we handled it, `fallthrough` if we did not. 
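To make the promise semantics implemented by `handleSubscribeResponse` below concrete: a pending `subscribe` promise resolves with the client's current channel-subscription count, and a pending `unsubscribe` promise resolves with `undefined`. A rough sketch of the observable behavior, assuming `sub` is a connected client:

```ts
// Sketch only; `sub` is assumed to be a connected RedisClient.
const a = await sub.subscribe("a", () => {}); // resolves with 1
const b = await sub.subscribe("b", () => {}); // resolves with 2
await sub.unsubscribe("a");                   // resolves with undefined
```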
+ fn handleSubscribeResponse(this: *ValkeyClient, value: *protocol.RESPValue, pair: *ValkeyCommand.PromisePair) bun.JSError!enum { handled, fallthrough } { + // Resolve the promise with the potentially transformed value + var promise_ptr = &pair.promise; + const globalThis = this.globalObject(); + const loop = this.vm.eventLoop(); + + debug("Handling a subscribe response: {any}", .{value.*}); + loop.enter(); + defer loop.exit(); + + return switch (value.*) { + .Error => { + promise_ptr.reject(globalThis, value.toJS(globalThis)); + return .handled; + }, + .Push => |push| { + const p = this.parent(); + const subs_ctx = try p.getOrCreateSubscriptionCtxEnteringSubscriptionMode(); + const sub_count = try subs_ctx.channelsSubscribedToCount(globalThis); + + if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| { + switch (msg_type) { + .subscribe => { + this.onValkeySubscribe(value); + promise_ptr.promise.resolve(globalThis, .jsNumber(sub_count)); + return .handled; + }, + .unsubscribe => { + try this.onValkeyUnsubscribe(); + promise_ptr.promise.resolve(globalThis, .js_undefined); + return .handled; + }, + else => { + // Other push messages (message, pmessage, etc) are not handled here + @branchHint(.cold); + this.fail( + "Push message is not a subscription message.", + protocol.RedisError.InvalidResponseType, + ); + return .fallthrough; + }, + } + } else { + // We should rarely reach this point. If we're guaranteed to be handling a subscribe/unsubscribe, + // then this is an unexpected path. + @branchHint(.cold); + this.fail( + "Push message is not a subscription message.", + protocol.RedisError.InvalidResponseType, + ); + return .handled; + } + }, + else => { + // This may be a regular command response. Let's pass it down + // to the next handler. + return .fallthrough; + }, + }; + } + fn handleHelloResponse(this: *ValkeyClient, value: *protocol.RESPValue) void { debug("Processing HELLO response", .{}); @@ -705,9 +795,68 @@ pub const ValkeyClient = struct { }, }; } + // Let's load the promise pair. + var pair_maybe = this.in_flight.readItem(); + + // We handle subscriptions specially because they are not regular + // commands and their failure will potentially cause the client to drop + // out of subscriber mode. + if (this.parent().isSubscriber()) { + debug("This client is a subscriber. Handling as subscriber...", .{}); + + // There are multiple different commands we may receive in + // subscriber mode. One is from a client.subscribe() call which + // requires that a promise is in-flight, but otherwise, we may also + // receive push messages from the server that do not have an + // associated promise. + if (pair_maybe) |*pair| { + debug("There is a request in flight. 
Handling as a subscribe request...", .{}); + if ((try this.handleSubscribeResponse(value, pair)) == .handled) { + return; + } + } + + switch (value.*) { + .Error => |err| { + this.fail(err, protocol.RedisError.InvalidResponse); + return; + }, + .Push => |push| { + if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| { + switch (msg_type) { + .message => { + @branchHint(.likely); + debug("Received a message.", .{}); + this.onValkeyMessage(push.data); + return; + }, + else => { + @branchHint(.cold); + debug("Received non-message push without promise: {any}", .{push.data}); + return; + }, + } + } else { + @branchHint(.cold); + this.fail( + "Unexpected push message kind without promise", + protocol.RedisError.InvalidResponseType, + ); + return; + } + }, + else => { + // In the else case, we fall through to the regular + // handler. Subscribers can send .Push commands which have + // the same semantics as regular commands. + }, + } + + debug("Treating subscriber response as a regular command...", .{}); + } // For regular commands, get the next command+promise pair from the queue - var pair = this.in_flight.readItem() orelse { + var pair = pair_maybe orelse { debug("Received response but no promise in queue", .{}); return; }; @@ -932,7 +1081,14 @@ pub const ValkeyClient = struct { if (this.flags.enable_offline_queue) { try this.enqueue(command, &promise); } else { - promise.reject(globalThis, globalThis.ERR(.REDIS_CONNECTION_CLOSED, "Connection is closed and offline queue is disabled", .{}).toJS()); + promise.reject( + globalThis, + globalThis.ERR( + .REDIS_CONNECTION_CLOSED, + "Connection is closed and offline queue is disabled", + .{}, + ).toJS(), + ); } }, .failed => { @@ -985,6 +1141,18 @@ pub const ValkeyClient = struct { this.parent().onValkeyConnect(value); } + pub fn onValkeySubscribe(this: *ValkeyClient, value: *protocol.RESPValue) void { + this.parent().onValkeySubscribe(value); + } + + pub fn onValkeyUnsubscribe(this: *ValkeyClient) bun.JSError!void { + return this.parent().onValkeyUnsubscribe(); + } + + pub fn onValkeyMessage(this: *ValkeyClient, value: []protocol.RESPValue) void { + this.parent().onValkeyMessage(value); + } + pub fn onValkeyReconnect(this: *ValkeyClient) void { this.parent().onValkeyReconnect(); } @@ -999,9 +1167,9 @@ pub const ValkeyClient = struct { }; // Auto-pipelining - const debug = bun.Output.scoped(.Redis, .visible); +const ValkeyCommand = @import("./ValkeyCommand.zig"); const protocol = @import("./valkey_protocol.zig"); const std = @import("std"); diff --git a/src/valkey/valkey_protocol.zig b/src/valkey/valkey_protocol.zig index a02514f7c6..ed63377925 100644 --- a/src/valkey/valkey_protocol.zig +++ b/src/valkey/valkey_protocol.zig @@ -656,6 +656,18 @@ pub const Attribute = struct { } }; +pub const SubscriptionPushMessage = enum(u2) { + message, + subscribe, + unsubscribe, + + pub const map = bun.ComptimeStringMap(SubscriptionPushMessage, .{ + .{ "message", .message }, + .{ "subscribe", .subscribe }, + .{ "unsubscribe", .unsubscribe }, + }); +}; + const std = @import("std"); const bun = @import("bun"); diff --git a/test/_util/numeric.ts b/test/_util/numeric.ts new file mode 100644 index 0000000000..2b4729acec --- /dev/null +++ b/test/_util/numeric.ts @@ -0,0 +1,192 @@ +/** + * Parameter accepted by some of the algorithms in this namespace which + * controls the input/output format of numbers. 
+ * + * @example + * ```ts + * // Returns an integer + * numeric.random.between(0, 10, { domain: "integral" }); + * // Returns a floating point number + * numeric.random.between(0, 10, { domain: "floating" }); + * ``` + */ +export type FormatSpecifier = { + domain: "floating" | "integral"; +}; + +const DefaultFormatSpecifier: FormatSpecifier = { + domain: "floating", +}; + +/** + * Generate an array of evenly-spaced numbers in a range. + * + * The name iota comes from https://aplwiki.com/wiki/Index_Generator. It is + * commonly used across programming languages and libraries. + * + * @param count The total number of points to generate. + * @param step The step size between each value. + * @returns An array of evenly-spaced numbers. + */ +export function iota(count: number, step: number = 1) { + return Array.from({ length: count }, (_, i) => i * step); +} + +/** + * Create an array of linearly spaced numbers. + * + * @param start The starting value of the sequence. + * @param end The end value of the sequence. + * @param numPoints The number of points to generate. + * + * @returns An array of numbers, spaced evenly in the linear space. + */ +export function linSpace(start: number, end: number, numPoints: number): number[] { + if (numPoints <= 0) return []; + if (numPoints === 1) return [start]; + if (numPoints === 2) return [start, end]; + const step = (end - start) / (numPoints - 1); + + return iota(numPoints).map(i => start + i * step); +} + +/** + * Create an array of exponentially spaced numbers. + * + * @param start The starting value of the sequence. + * @param end The end value of the sequence. + * @param numPoints The number of points to generate. + * @param base The exponential base + * + * @returns An array of numbers, spaced evenly in the exponential space. + */ +export function expSpace(start: number, end: number, numPoints: number, base: number): number[] { + if (numPoints <= 0) return []; + if (numPoints === 1) return [start]; + + if (!Number.isFinite(base) || base <= 0 || base === 1) { + throw new Error('expSpace: "base" must be > 0 and !== 1'); + } + + // Generate exponentially spaced values from 0 to 1 + const exponentialValues = Array.from( + { length: numPoints }, + (_, i) => (Math.pow(base, i / (numPoints - 1)) - 1) / (base - 1), + ); + + // Scale and shift to fit the [start, end] range + return exponentialValues.map(t => start + t * (end - start)); +} + +export namespace stats { + /** + * Computes the Pearson correlation coefficient between two arrays of numbers. + * + * The Pearson correlation coefficient, also known as Pearson's r, is a + * statistical measure that quantifies the strength and direction of a linear + * relationship between two variables. + * + * @param xs The first array of numbers. + * @param ys The second array of numbers. + * @returns The Pearson correlation coefficient, or 0 if there is no correlation. 
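+ *
+ * @example
+ * A worked check (illustrative, added for documentation): perfectly linear
+ * data must yield a coefficient of exactly 1, up to floating-point error.
+ * ```ts
+ * stats.computePearsonCorrelation([1, 2, 3, 4], [3, 5, 7, 9]); // 1
+ * ```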
+ */
+ export function computePearsonCorrelation(xs: number[], ys: number[]): number {
+ if (xs.length !== ys.length || xs.length === 0) {
+ throw new Error("Input arrays must have the same non-zero length");
+ }
+
+ const n = xs.length;
+ const sumX = xs.reduce((a, b) => a + b, 0);
+ const sumY = ys.reduce((a, b) => a + b, 0);
+ const sumXY = xs.reduce((sum, x, i) => sum + x * ys[i], 0);
+ const sumX2 = xs.reduce((sum, x) => sum + x * x, 0);
+ const sumY2 = ys.reduce((sum, y) => sum + y * y, 0);
+
+ // Compute the Pearson correlation coefficient (r) using the formula:
+ // r = (n * Σ(xy) - Σx * Σy) / sqrt[(n * Σ(x^2) - (Σx)^2) * (n * Σ(y^2) - (Σy)^2)]
+ const numerator = n * sumXY - sumX * sumY;
+ const denominator = Math.sqrt((n * sumX2 - sumX * sumX) * (n * sumY2 - sumY * sumY));
+
+ if (denominator === 0) {
+ return 0; // Avoid division by zero; implies no correlation
+ }
+
+ return numerator / denominator;
+ }
+
+ /**
+ * Compute the slope of the best-fit line using linear regression.
+ *
+ * @param xs The independent variable.
+ * @param ys The dependent variable.
+ * @returns The slope of the best-fit line.
+ */
+ export function computeLinearSlope(xs: number[], ys: number[]): number {
+ if (xs.length !== ys.length || xs.length === 0) {
+ throw new Error("Input arrays must have the same non-zero length");
+ }
+
+ const n = xs.length;
+ const sumX = xs.reduce((a, b) => a + b, 0);
+ const sumY = ys.reduce((a, b) => a + b, 0);
+ const sumXY = xs.reduce((sum, x, i) => sum + x * ys[i], 0);
+ const sumX2 = xs.reduce((sum, x) => sum + x * x, 0);
+
+ // Compute the slope (m) using the formula:
+ // m = (n * Σ(xy) - Σx * Σy) / (n * Σ(x^2) - (Σx)^2)
+ const slope = (n * sumXY - sumX * sumY) / (n * sumX2 - sumX * sumX);
+ return slope;
+ }
+
+ /**
+ * Compute the mean (average) of an array of numbers.
+ *
+ * @param xs An array of numbers.
+ * @returns The mean of the numbers.
+ */
+ export function computeMean(xs: number[]): number {
+ return xs.reduce((a, b) => a + b, 0) / xs.length;
+ }
+
+ /**
+ * Compute the average absolute deviation of an array of numbers.
+ *
+ * The average absolute deviation (AAD) of a data set is the average of the
+ * absolute deviations from a central point.
+ *
+ * @param xs An array of numbers.
+ * @returns The average absolute deviation of the numbers.
+ */
+ export function computeAverageAbsoluteDeviation(xs: number[]): number {
+ const mean = computeMean(xs);
+ return xs.reduce((sum, x) => sum + Math.abs(x - mean), 0) / xs.length;
+ }
+}
+
+/**
+ * Utilities for numeric randomness.
+ *
+ * @todo Perhaps this does not belong in the numeric namespace.
+ */
+export namespace random {
+ /**
+ * Generate a random number within the specified range.
+ *
+ * @param min The minimum value (inclusive for integrals).
+ * @param max The maximum value (inclusive for integrals).
+ * @param format The format specifier for the random number.
+ * @returns A random number between min and max, formatted according to the specifier.
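+ *
+ * @example
+ * Illustrative draws (results are random; the integral form is inclusive
+ * of both bounds after ceiling/flooring them):
+ * ```ts
+ * numeric.random.between(0, 10);                          // e.g. 3.7521
+ * numeric.random.between(0, 10, { domain: "integral" });  // e.g. 7
+ * ```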
+ */
+ export function between(min: number, max: number, format: FormatSpecifier = DefaultFormatSpecifier): number {
+ if (!Number.isFinite(min) || !Number.isFinite(max)) throw new Error("min/max must be finite");
+ if (max < min) throw new Error("max must be >= min");
+
+ if (format.domain === "floating") {
+ return Math.random() * (max - min) + min;
+ }
+
+ const lo = Math.ceil(min);
+ const hi = Math.floor(max);
+ return Math.floor(Math.random() * (hi - lo + 1)) + lo;
+ }
+}
diff --git a/test/harness.ts b/test/harness.ts
index 9b46505c57..f54cf878a3 100644
--- a/test/harness.ts
+++ b/test/harness.ts
@@ -7,12 +7,13 @@
 import { gc as bunGC, sleepSync, spawnSync, unsafe, which, write } from "bun";
 import { heapStats } from "bun:jsc";
-import { afterAll, beforeAll, describe, expect, test } from "bun:test";
+import { beforeAll, describe, expect } from "bun:test";
 import { ChildProcess, execSync, fork } from "child_process";
 import { readdir, readFile, readlink, rm, writeFile } from "fs/promises";
 import fs, { closeSync, openSync, rmSync } from "node:fs";
 import os from "node:os";
 import { dirname, isAbsolute, join } from "path";
+import * as numeric from "_util/numeric.ts";
 
 type Awaitable<T> = T | Promise<T>;
@@ -357,7 +358,7 @@ export function bunRunAsScript(
 }
 
 export function randomLoneSurrogate() {
- const n = randomRange(0, 2);
+ const n = numeric.random.between(0, 2, { domain: "integral" });
 if (n === 0) return randomLoneHighSurrogate();
 return randomLoneLowSurrogate();
 }
@@ -370,16 +371,12 @@ export function randomInvalidSurrogatePair() {
 
 // Generates a random lone high surrogate (from the range D800-DBFF)
 export function randomLoneHighSurrogate() {
- return String.fromCharCode(randomRange(0xd800, 0xdbff));
+ return String.fromCharCode(numeric.random.between(0xd800, 0xdbff, { domain: "integral" }));
 }
 
 // Generates a random lone low surrogate (from the range DC00-DFFF)
 export function randomLoneLowSurrogate() {
- return String.fromCharCode(randomRange(0xdc00, 0xdfff));
-}
-
-function randomRange(low: number, high: number): number {
- return low + Math.floor(Math.random() * (high - low));
+ return String.fromCharCode(numeric.random.between(0xdc00, 0xdfff, { domain: "integral" }));
 }
 
 export function runWithError(cb: () => unknown): Error | undefined {
diff --git a/test/integration/bun-types/fixture/redis.ts b/test/integration/bun-types/fixture/redis.ts
new file mode 100644
index 0000000000..f7e536200d
--- /dev/null
+++ b/test/integration/bun-types/fixture/redis.ts
@@ -0,0 +1,31 @@
+import { expectType } from "./utilities";
+
+expectType(Bun.redis.publish("hello", "world")).is<Promise<number>>();
+
+const copy = await Bun.redis.duplicate();
+expectType(copy.connected).is<boolean>();
+expectType(copy).is<Bun.RedisClient>();
+
+const listener: Bun.RedisClient.StringPubSubListener = (message, channel) => {
+ expectType(message).is<string>();
+ expectType(channel).is<string>();
+};
+Bun.redis.subscribe("hello", listener);
+
+// Buffer subscriptions are not yet implemented
+// const bufferListener: Bun.RedisClient.BufferPubSubListener = (message, channel) => {
+// expectType(message).is<Uint8Array<ArrayBuffer>>();
+// expectType(channel).is<string>();
+// };
+// Bun.redis.subscribe("hello", bufferListener);
+
+expectType(
+ copy.subscribe("hello", message => {
+ expectType(message).is<string>();
+ }),
+).is<Promise<number>>();
+
+await copy.unsubscribe();
+await copy.unsubscribe("hello");
+
+expectType(copy.unsubscribe("hello", () => {})).is<Promise<void>>();
diff --git a/test/js/valkey/test-utils.ts b/test/js/valkey/test-utils.ts
index 34ce3347cc..04deb19046 100644
--- a/test/js/valkey/test-utils.ts
+++ 
b/test/js/valkey/test-utils.ts
@@ -5,8 +5,6 @@ import path from "path";
 
 import * as dockerCompose from "../../docker/index.ts";
 import { UnixDomainSocketProxy } from "../../unix-domain-socket-proxy.ts";
-import * as fs from "node:fs";
-import * as os from "node:os";
 
 const dockerCLI = dockerExe() as string;
 export const isEnabled =
@@ -165,7 +163,7 @@ async function startContainer(): Promise<void> {
 REDIS_PORT = port;
 REDIS_TLS_PORT = tlsPort;
 REDIS_HOST = redisInfo.host;
- REDIS_UNIX_SOCKET = unixSocketProxy.path; // Use the proxy socket
+ REDIS_UNIX_SOCKET = unixSocketProxy.path; // Use the proxy socket
 DEFAULT_REDIS_URL = `redis://${REDIS_HOST}:${REDIS_PORT}`;
 TLS_REDIS_URL = `rediss://${REDIS_HOST}:${REDIS_TLS_PORT}`;
 UNIX_REDIS_URL = `redis+unix://${REDIS_UNIX_SOCKET}`;
@@ -223,12 +221,12 @@ import { tmpdir } from "os";
 /**
 * Create a new client with specific connection type
 */
 export function createClient(
- connectionType: ConnectionType = ConnectionType.TCP,
- customOptions = {},
- dbId: number | undefined = undefined,
+ connectionType: ConnectionType = ConnectionType.TCP,
+ customOptions = {},
+ dbId: number | undefined = undefined,
 ) {
 let url: string;
- const mkUrl = (baseUrl: string) => dbId ? `${baseUrl}/${dbId}`: baseUrl;
+ const mkUrl = (baseUrl: string) => (dbId ? `${baseUrl}/${dbId}` : baseUrl);
 let options: any = {};
 
 context.id++;
@@ -314,6 +312,9 @@ export interface TestContext {
 redisWriteOnly?: RedisClient;
 id: number;
 restartServer: () => Promise<void>;
+ __subscriberClientPool: RedisClient[];
+ newSubscriberClient: (connectionType: ConnectionType) => Promise<RedisClient>;
+ cleanupSubscribers: () => Promise<void>;
 }
 
 // Create a singleton promise for Docker initialization
@@ -336,10 +337,30 @@ export const context: TestContext = {
 redisWriteOnly: undefined,
 id,
 restartServer: restartRedisContainer,
+ __subscriberClientPool: [],
+ newSubscriberClient: async function (connectionType: ConnectionType) {
+ const client = createClient(connectionType);
+ this.__subscriberClientPool.push(client);
+ await client.connect();
+ return client;
+ },
+ cleanupSubscribers: async function () {
+ for (const client of this.__subscriberClientPool) {
+ try {
+ await client.unsubscribe();
+ } catch {}
+
+ if (client.connected) {
+ client.close();
+ }
+ }
+
+ this.__subscriberClientPool = [];
+ },
 };
 
 export { context as ctx };
 
-if (isEnabled)
+if (isEnabled) {
 beforeAll(async () => {
 // Initialize Docker container once for all tests
 if (!dockerInitPromise) {
@@ -405,8 +426,9 @@ if (isEnabled)
 // console.warn("Test initialization failed - tests may be skipped");
 // }
 });
+}
 
-if (isEnabled)
+if (isEnabled) {
 afterAll(async () => {
 console.log("Cleaning up Redis container");
 if (!context.redis?.connected) {
@@ -426,26 +448,26 @@ if (isEnabled)
 }
 
 // Disconnect all clients
- await context.redis.close();
+ context.redis.close();
 
 if (context.redisTLS) {
- await context.redisTLS.close();
+ context.redisTLS.close();
 }
 
 if (context.redisUnix) {
- await context.redisUnix.close();
+ context.redisUnix.close();
 }
 
 if (context.redisAuth) {
- await context.redisAuth.close();
+ context.redisAuth.close();
 }
 
 if (context.redisReadOnly) {
- await context.redisReadOnly.close();
+ context.redisReadOnly.close();
 }
 
 if (context.redisWriteOnly) {
- await context.redisWriteOnly.close();
+ context.redisWriteOnly.close();
 }
 
 // Clean up Unix socket proxy if it exists
@@ -456,6 +478,7 @@
 console.error("Error during test cleanup:", err);
 }
 });
+}
 
 if (!isEnabled) {
 console.warn("Redis is not enabled, skipping tests");
 }
@@ -545,6 +568,14 @@ async 
function getRedisContainerName(): Promise<string> {
 /**
  * Restart the Redis container to simulate connection drop
+ *
+ * Restarts the container identified by the test harness and waits briefly for it
+ * to come back online (approximately 2 seconds). Use this to simulate a server
+ * restart or connection drop during tests.
+ *
+ * @returns A promise that resolves when the restart and short wait complete.
+ * @throws If the Docker restart command exits with a non-zero code; the error
+ * message includes the container's stderr output.
  */
 export async function restartRedisContainer(): Promise<void> {
   // If using docker-compose, get the actual container name
@@ -611,3 +642,50 @@ export async function restartRedisContainer(): Promise<void> {
     }
   }
 }
+
+/**
+ * @returns true or false with approximately equal probability
+ */
+export function randomCoinFlip(): boolean {
+  return Math.floor(Math.random() * 2) === 0;
+}
+
+/**
+ * Utility for creating a counter that can be awaited until it reaches a target value.
+ */
+export function awaitableCounter(timeoutMs: number = 1000) {
+  let activeResolvers: [number, NodeJS.Timeout, (value: number) => void][] = [];
+  let currentCount = 0;
+
+  return {
+    increment: () => {
+      currentCount++;
+
+      for (const [value, alarm, resolve] of activeResolvers) {
+        // Only cancel an alarm once its target value has been reached; cancelling
+        // every alarm on each increment would disable the timeout for waiters that
+        // are still pending.
+        if (currentCount >= value) {
+          alarm.close();
+          resolve(currentCount);
+        }
+      }
+
+      // Remove resolved promises
+      activeResolvers = activeResolvers.filter(([value]) => currentCount < value);
+    },
+    count: () => currentCount,
+
+    untilValue: (value: number) =>
+      new Promise<number>((resolve, reject) => {
+        if (currentCount >= value) {
+          resolve(currentCount);
+          return;
+        }
+
+        const alarm = setTimeout(() => {
+          reject(new Error(`Timeout waiting for counter to reach ${value}, current is ${currentCount}.`));
+        }, timeoutMs);
+
+        activeResolvers.push([value, alarm, resolve]);
+      }),
+  };
+}
diff --git a/test/js/valkey/valkey.failing-subscriber.ts b/test/js/valkey/valkey.failing-subscriber.ts
new file mode 100644
index 0000000000..872ea5a072
--- /dev/null
+++ b/test/js/valkey/valkey.failing-subscriber.ts
@@ -0,0 +1,48 @@
+// Program which sets up a subscriber outside the scope of the main Jest process.
+// Used within valkey.test.ts.
+//
+// DO NOT IMPORT FROM test-utils.ts. That import is janky and will have different state for
+// different importers.
+import {RedisClient} from "bun";
+
+function trySend(msg: any) {
+  if (process === undefined || process.send === undefined) {
+    throw new Error("process is undefined");
+  }
+
+  process.send(msg);
+}
+
+let redisUrlResolver: (url: string) => void;
+const redisUrl = new Promise<string>((resolve) => {
+  redisUrlResolver = resolve;
+});
+
+process.on("message", (msg: any) => {
+  if (msg.event === "start") {
+    redisUrlResolver(msg.url);
+  } else {
+    throw new Error("Unknown event " + msg.event);
+  }
+});
+
+const CHANNEL = "error-callback-channel";
+
+// We will wait for the parent process to tell us to start.
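+// (IPC protocol, for reference: the parent sends { event: "start", url }; we connect and
+// reply with { event: "ready" }; each delivered message posts { event: "message", index };
+// an uncaught callback error posts { event: "exception", exMsg }.)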
+const url = await redisUrl; +const subscriber = new RedisClient(url); +await subscriber.connect(); +trySend({ event: "ready" }); + +let counter = 0; +await subscriber.subscribe(CHANNEL, () => { + if ((counter++) === 1) { + throw new Error("Intentional callback error"); + } + + trySend({ event: "message", index: counter }); +}); + +process.on("uncaughtException", (e) => { + trySend({ event: "exception", exMsg: e.message }); +}); diff --git a/test/js/valkey/valkey.test.ts b/test/js/valkey/valkey.test.ts index a70c9659b8..11b1902fbd 100644 --- a/test/js/valkey/valkey.test.ts +++ b/test/js/valkey/valkey.test.ts @@ -1,12 +1,14 @@ -import { randomUUIDv7, RedisClient } from "bun"; +import { randomUUIDv7, RedisClient, spawn } from "bun"; import { beforeAll, beforeEach, describe, expect, test } from "bun:test"; import { + awaitableCounter, ConnectionType, createClient, ctx, DEFAULT_REDIS_URL, expectType, isEnabled, + randomCoinFlip, setupDockerContainer, } from "./test-utils"; @@ -26,6 +28,7 @@ describe.skipIf(!isEnabled)("Valkey Redis Client", () => { } // Flush all data for clean test state + await ctx.redis.connect(); await ctx.redis.send("FLUSHALL", ["SYNC"]); }); @@ -40,11 +43,11 @@ describe.skipIf(!isEnabled)("Valkey Redis Client", () => { expect(setResult).toMatchInlineSnapshot(`"OK"`); const setResult2 = await redis.set(testKey, testValue, "GET"); - expect(setResult2).toMatchInlineSnapshot(`"Hello from Bun Redis!"`); + expect(setResult2).toMatchInlineSnapshot(`"${testValue}"`); // GET should return the value we set const getValue = await redis.get(testKey); - expect(getValue).toMatchInlineSnapshot(`"Hello from Bun Redis!"`); + expect(getValue).toMatchInlineSnapshot(`"${testValue}"`); }); test("should test key existence", async () => { @@ -235,4 +238,572 @@ describe.skipIf(!isEnabled)("Valkey Redis Client", () => { expect(valueAfterStop).toBe(TEST_VALUE); }); }); + + describe("PUB/SUB", () => { + var i = 0; + const testChannel = () => { + return `test-channel-${i++}`; + }; + const testKey = () => { + return `test-key-${i++}`; + }; + const testValue = () => { + return `test-value-${i++}`; + }; + const testMessage = () => { + return `test-message-${i++}`; + }; + + beforeEach(async () => { + // The PUB/SUB tests expect that ctx.redis is connected but not in subscriber mode. 
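+      // (A client with an active subscription is in subscriber mode and rejects most
+      // normal commands, so tear down any subscribers left over from the previous test.)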
+      await ctx.cleanupSubscribers();
+    });
+
+    test("publishing to a channel does not fail", async () => {
+      expect(await ctx.redis.publish(testChannel(), testMessage())).toBe(0);
+    });
+
+    test("setting in subscriber mode gracefully fails", async () => {
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+
+      await subscriber.subscribe(testChannel(), () => {});
+
+      expect(() => subscriber.set(testKey(), testValue())).toThrow(
+        "RedisClient.prototype.set cannot be called while in subscriber mode",
+      );
+
+      await subscriber.unsubscribe(testChannel());
+    });
+
+    test("setting after unsubscribing works", async () => {
+      const channel = testChannel();
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+      await subscriber.subscribe(channel, () => {});
+      await subscriber.unsubscribe(channel);
+      await expect(ctx.redis.set(testKey(), testValue())).resolves.toEqual("OK");
+    });
+
+    test("subscribing to a channel receives messages", async () => {
+      const TEST_MESSAGE_COUNT = 128;
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+      const channel = testChannel();
+      const message = testMessage();
+
+      const counter = awaitableCounter();
+      // Distinct parameter names let the assertions compare each delivery against the
+      // outer channel/message instead of shadowing them.
+      await subscriber.subscribe(channel, (receivedMessage, receivedChannel) => {
+        counter.increment();
+        expect(receivedChannel).toBe(channel);
+        expect(receivedMessage).toBe(message);
+      });
+
+      Array.from({ length: TEST_MESSAGE_COUNT }).forEach(async () => {
+        expect(await ctx.redis.publish(channel, message)).toBe(1);
+      });
+
+      await counter.untilValue(TEST_MESSAGE_COUNT);
+      expect(counter.count()).toBe(TEST_MESSAGE_COUNT);
+    });
+
+    test("messages are received in order", async () => {
+      const channel = testChannel();
+
+      await ctx.redis.set("START-TEST", "1");
+      const TEST_MESSAGE_COUNT = 1024;
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+
+      const counter = awaitableCounter();
+      var receivedMessages: string[] = [];
+      await subscriber.subscribe(channel, message => {
+        receivedMessages.push(message);
+        counter.increment();
+      });
+
+      const sentMessages = Array.from({ length: TEST_MESSAGE_COUNT }).map(() => {
+        return randomUUIDv7();
+      });
+      await Promise.all(
+        sentMessages.map(async message => {
+          expect(await ctx.redis.publish(channel, message)).toBe(1);
+        }),
+      );
+
+      await counter.untilValue(TEST_MESSAGE_COUNT);
+      expect(receivedMessages.length).toBe(sentMessages.length);
+      expect(receivedMessages).toEqual(sentMessages);
+
+      await subscriber.unsubscribe(channel);
+
+      await ctx.redis.set("STOP-TEST", "1");
+    });
+
+    test("subscribing to multiple channels receives messages", async () => {
+      const TEST_MESSAGE_COUNT = 128;
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+
+      const channels = [testChannel(), testChannel()];
+      const counter = awaitableCounter();
+
+      var receivedMessages: { [channel: string]: string[] } = {};
+      await subscriber.subscribe(channels, (message, channel) => {
+        receivedMessages[channel] = receivedMessages[channel] || [];
+        receivedMessages[channel].push(message);
+        counter.increment();
+      });
+
+      var sentMessages: { [channel: string]: string[] } = {};
+      for (let i = 0; i < TEST_MESSAGE_COUNT; i++) {
+        const channel = channels[randomCoinFlip() ? 
0 : 1]; + const message = randomUUIDv7(); + + expect(await ctx.redis.publish(channel, message)).toBe(1); + + sentMessages[channel] = sentMessages[channel] || []; + sentMessages[channel].push(message); + } + + await counter.untilValue(TEST_MESSAGE_COUNT); + + // Check that we received messages on both channels + expect(Object.keys(receivedMessages).sort()).toEqual(Object.keys(sentMessages).sort()); + + // Check messages match for each channel + for (const channel of channels) { + if (sentMessages[channel]) { + expect(receivedMessages[channel]).toEqual(sentMessages[channel]); + } + } + + await subscriber.unsubscribe(channels); + }); + + test("unsubscribing from specific channels while remaining subscribed to others", async () => { + const channel1 = "channel-1"; + const channel2 = "channel-2"; + const channel3 = "channel-3"; + + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + let receivedMessages: { [channel: string]: string[] } = {}; + + // Total counter for all messages we expect to receive: 3 initial + 2 after unsubscribe = 5 total + const counter = awaitableCounter(); + + // Subscribe to three channels + await subscriber.subscribe([channel1, channel2, channel3], (message, channel) => { + receivedMessages[channel] = receivedMessages[channel] || []; + receivedMessages[channel].push(message); + counter.increment(); + }); + + // Send initial messages to all channels + expect(await ctx.redis.publish(channel1, "msg1-before")).toBe(1); + expect(await ctx.redis.publish(channel2, "msg2-before")).toBe(1); + expect(await ctx.redis.publish(channel3, "msg3-before")).toBe(1); + + // Wait for initial messages, then unsubscribe from channel2 + await counter.untilValue(3); + await subscriber.unsubscribe(channel2); + + // Send messages after unsubscribing from channel2 + expect(await ctx.redis.publish(channel1, "msg1-after")).toBe(1); + expect(await ctx.redis.publish(channel2, "msg2-after")).toBe(0); + expect(await ctx.redis.publish(channel3, "msg3-after")).toBe(1); + + await counter.untilValue(5); + + // Check we received messages only on subscribed channels + expect(receivedMessages[channel1]).toEqual(["msg1-before", "msg1-after"]); + expect(receivedMessages[channel2]).toEqual(["msg2-before"]); // No "msg2-after" + expect(receivedMessages[channel3]).toEqual(["msg3-before", "msg3-after"]); + + await subscriber.unsubscribe([channel1, channel3]); + }); + + test("subscribing to the same channel multiple times", async () => { + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + const channel = testChannel(); + + const counter = awaitableCounter(); + + let callCount = 0; + const listener = () => { + callCount++; + counter.increment(); + }; + + let callCount2 = 0; + const listener2 = () => { + callCount2++; + counter.increment(); + }; + + // Subscribe to the same channel twice + await subscriber.subscribe(channel, listener); + await subscriber.subscribe(channel, listener2); + + // Publish a single message + expect(await ctx.redis.publish(channel, "test-message")).toBe(1); + + await counter.untilValue(2); + + // Both listeners should have been called once. 
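+      // (A single publish is delivered once to each listener registered on the channel,
+      // so one publish advances the counter to 2.)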
+      expect(callCount).toBe(1);
+      expect(callCount2).toBe(1);
+
+      await subscriber.unsubscribe(channel);
+    });
+
+    test("empty string messages", async () => {
+      const channel = "empty-message-channel";
+      const subscriber = createClient(ConnectionType.TCP);
+      await subscriber.connect();
+
+      const counter = awaitableCounter();
+      let receivedMessage: string | undefined = undefined;
+      await subscriber.subscribe(channel, message => {
+        receivedMessage = message;
+        counter.increment();
+      });
+
+      expect(await ctx.redis.publish(channel, "")).toBe(1);
+      await counter.untilValue(1);
+
+      expect(receivedMessage).not.toBeUndefined();
+      expect(receivedMessage!).toBe("");
+
+      await subscriber.unsubscribe(channel);
+    });
+
+    test("special characters in channel names", async () => {
+      const subscriber = createClient(ConnectionType.TCP);
+      await subscriber.connect();
+
+      const specialChannels = [
+        "channel:with:colons",
+        "channel with spaces",
+        "channel-with-unicode-😀",
+        "channel[with]brackets",
+        "channel@with#special$chars",
+      ];
+
+      for (const channel of specialChannels) {
+        const counter = awaitableCounter();
+        let received = false;
+        await subscriber.subscribe(channel, () => {
+          received = true;
+          counter.increment();
+        });
+
+        expect(await ctx.redis.publish(channel, "test")).toBe(1);
+        await counter.untilValue(1);
+
+        expect(received).toBe(true);
+        await subscriber.unsubscribe(channel);
+      }
+    });
+
+    test("ping works in subscription mode", async () => {
+      const channel = "ping-test-channel";
+
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+      await subscriber.subscribe(channel, () => {});
+
+      // Ping should work in subscription mode
+      const pong = await subscriber.ping();
+      expect(pong).toBe("PONG");
+
+      const customPing = await subscriber.ping("hello");
+      expect(customPing).toBe("hello");
+    });
+
+    test("publish does not work from a subscribed client", async () => {
+      const channel = "self-publish-channel";
+
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+      await subscriber.subscribe(channel, () => {});
+
+      // Publishing from a client in subscriber mode should throw
+      expect(() => subscriber.publish(channel, "self-published")).toThrow();
+    });
+
+    test("complete unsubscribe restores normal command mode", async () => {
+      const channel = "restore-test-channel";
+      const testKey = "restore-test-key";
+
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+      await subscriber.subscribe(channel, () => {});
+
+      // Should fail in subscription mode
+      expect(() => subscriber.set(testKey, testValue())).toThrow(
+        "RedisClient.prototype.set cannot be called while in subscriber mode.",
+      );
+
+      // Unsubscribe from all channels
+      await subscriber.unsubscribe();
+
+      // Should work after unsubscribing
+      const result = await ctx.redis.set(testKey, "value");
+      expect(result).toBe("OK");
+
+      const value = await ctx.redis.get(testKey);
+      expect(value).toBe("value");
+    });
+
+    test("publishing without subscribers succeeds", async () => {
+      const channel = "no-subscribers-channel";
+
+      // Publishing without subscribers should not throw
+      expect(await ctx.redis.publish(channel, "message")).toBe(0);
+    });
+
+    test("unsubscribing from non-subscribed channels", async () => {
+      const channel = "never-subscribed-channel";
+
+      expect(() => ctx.redis.unsubscribe(channel)).toThrow(
+        "RedisClient.prototype.unsubscribe can only be called while in subscriber mode.",
+      );
+    });
+
+    test("callback errors don't crash the client", async () => {
+      const channel = 
"error-callback-channel"; + + const STEP_SUBSCRIBED = 1; + const STEP_FIRST_MESSAGE = 2; + const STEP_SECOND_MESSAGE = 3; + const STEP_THIRD_MESSAGE = 4; + + // stepCounter is a slight hack to track the progress of the subprocess. + const stepCounter = awaitableCounter(); + let currentMessage: any = {}; + + const subscriberProc = spawn({ + cmd: [self.process.execPath, "run", `${__dirname}/valkey.failing-subscriber.ts`], + stdout: "inherit", + stderr: "inherit", + ipc: msg => { + currentMessage = msg; + stepCounter.increment(); + }, + env: { + ...process.env, + NODE_ENV: "development", + }, + }); + + subscriberProc.send({ event: "start", url: DEFAULT_REDIS_URL }); + + try { + await stepCounter.untilValue(STEP_SUBSCRIBED); + expect(currentMessage.event).toBe("ready"); + + // Send multiple messages + expect(await ctx.redis.publish(channel, "message1")).toBe(1); + await stepCounter.untilValue(STEP_FIRST_MESSAGE); + expect(currentMessage.event).toBe("message"); + expect(currentMessage.index).toBe(1); + + // Now, the subscriber process will crash + expect(await ctx.redis.publish(channel, "message2")).toBe(1); + await stepCounter.untilValue(STEP_SECOND_MESSAGE); + expect(currentMessage.event).toBe("exception"); + //expect(currentMessage.index).toBe(2); + + // But it should recover and continue receiving messages + expect(await ctx.redis.publish(channel, "message3")).toBe(1); + await stepCounter.untilValue(STEP_THIRD_MESSAGE); + expect(currentMessage.event).toBe("message"); + expect(currentMessage.index).toBe(3); + } finally { + subscriberProc.kill(); + } + }); + + test("subscriptions return correct counts", async () => { + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + expect(await subscriber.subscribe("chan1", () => {})).toBe(1); + expect(await subscriber.subscribe("chan2", () => {})).toBe(2); + }); + + test("unsubscribing from listeners", async () => { + const channel = "error-callback-channel"; + + const subscriber = createClient(ConnectionType.TCP); + await subscriber.connect(); + + // First phase: both listeners should receive 1 message each (2 total) + const counter = awaitableCounter(); + let messageCount1 = 0; + const listener1 = () => { + messageCount1++; + counter.increment(); + }; + await subscriber.subscribe(channel, listener1); + + let messageCount2 = 0; + const listener2 = () => { + messageCount2++; + counter.increment(); + }; + await subscriber.subscribe(channel, listener2); + + await ctx.redis.publish(channel, "message1"); + await counter.untilValue(2); + + expect(messageCount1).toBe(1); + expect(messageCount2).toBe(1); + + console.log("Unsubscribing listener2"); + await subscriber.unsubscribe(channel, listener2); + + await ctx.redis.publish(channel, "message1"); + await counter.untilValue(3); + + expect(messageCount1).toBe(2); + expect(messageCount2).toBe(1); + }); + }); + + describe("duplicate()", () => { + test("should create duplicate of connected client that gets connected", async () => { + const duplicate = await ctx.redis.duplicate(); + + expect(duplicate.connected).toBe(true); + expect(duplicate).not.toBe(ctx.redis); + + // Both should work independently + await ctx.redis.set("test-original", "original-value"); + await duplicate.set("test-duplicate", "duplicate-value"); + + expect(await ctx.redis.get("test-duplicate")).toBe("duplicate-value"); + expect(await duplicate.get("test-original")).toBe("original-value"); + + duplicate.close(); + }); + + test("should preserve connection configuration in duplicate", async () => { + await 
ctx.redis.connect();
+
+      const duplicate = await ctx.redis.duplicate();
+
+      // Both clients should be able to perform the same operations
+      const testKey = `duplicate-config-test-${randomUUIDv7().substring(0, 8)}`;
+      const testValue = "test-value";
+
+      await ctx.redis.set(testKey, testValue);
+      const retrievedValue = await duplicate.get(testKey);
+
+      expect(retrievedValue).toBe(testValue);
+
+      duplicate.close();
+    });
+
+    test("should allow duplicate to work independently from original", async () => {
+      const duplicate = await ctx.redis.duplicate();
+
+      // Close the duplicate; the original client should keep working independently
+      duplicate.close();
+
+      const testKey = `independent-test-${randomUUIDv7().substring(0, 8)}`;
+      const testValue = "independent-value";
+
+      await ctx.redis.set(testKey, testValue);
+      const retrievedValue = await ctx.redis.get(testKey);
+
+      expect(retrievedValue).toBe(testValue);
+    });
+
+    test("should handle duplicate of client in subscriber mode", async () => {
+      const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP);
+
+      const testChannel = "test-subscriber-duplicate";
+
+      // Put original client in subscriber mode
+      await subscriber.subscribe(testChannel, () => {});
+
+      const duplicate = await subscriber.duplicate();
+
+      // Duplicate should not be in subscriber mode
+      expect(() => duplicate.set("test-key", "test-value")).not.toThrow();
+
+      await subscriber.unsubscribe(testChannel);
+    });
+
+    test("should create multiple duplicates from same client", async () => {
+      await ctx.redis.connect();
+
+      const duplicate1 = await ctx.redis.duplicate();
+      const duplicate2 = await ctx.redis.duplicate();
+      const duplicate3 = await ctx.redis.duplicate();
+
+      // All should be connected
+      expect(duplicate1.connected).toBe(true);
+      expect(duplicate2.connected).toBe(true);
+      expect(duplicate3.connected).toBe(true);
+
+      // All should work independently
+      const testKey = `multi-duplicate-test-${randomUUIDv7().substring(0, 8)}`;
+      await duplicate1.set(`${testKey}-1`, "value-1");
+      await duplicate2.set(`${testKey}-2`, "value-2");
+      await duplicate3.set(`${testKey}-3`, "value-3");
+
+      expect(await duplicate1.get(`${testKey}-1`)).toBe("value-1");
+      expect(await duplicate2.get(`${testKey}-2`)).toBe("value-2");
+      expect(await duplicate3.get(`${testKey}-3`)).toBe("value-3");
+
+      // Cross-check: each duplicate can read what others wrote
+      expect(await duplicate1.get(`${testKey}-2`)).toBe("value-2");
+      expect(await duplicate2.get(`${testKey}-3`)).toBe("value-3");
+      expect(await duplicate3.get(`${testKey}-1`)).toBe("value-1");
+
+      duplicate1.close();
+      duplicate2.close();
+      duplicate3.close();
+    });
+
+    test("should duplicate client that failed to connect", async () => {
+      // Create client with invalid credentials to force connection failure
+      const url = new URL(DEFAULT_REDIS_URL);
+      url.username = "invaliduser";
+      url.password = "invalidpassword";
+      const failedRedis = new RedisClient(url.toString());
+
+      // Try to connect and expect it to fail
+      let connectionFailed = false;
+      try {
+        await failedRedis.connect();
+      } catch {
+        connectionFailed = true;
+      }
+
+      expect(connectionFailed).toBe(true);
+      expect(failedRedis.connected).toBe(false);
+
+      // Duplicate should also remain unconnected
+      const duplicate = await failedRedis.duplicate();
+      expect(duplicate.connected).toBe(false);
+    });
+
+    test("should handle duplicate timing with concurrent operations", async () => {
+      await ctx.redis.connect();
+
+      // Start some operations on the original client
+      const testKey = 
`concurrent-test-${randomUUIDv7().substring(0, 8)}`; + const originalOperation = ctx.redis.set(testKey, "original-value"); + + // Create duplicate while operation is in flight + const duplicate = await ctx.redis.duplicate(); + + // Wait for original operation to complete + await originalOperation; + + // Duplicate should be able to read the value + expect(await duplicate.get(testKey)).toBe("original-value"); + + duplicate.close(); + }); + }); }); From 00490199f19ba329bb94ad5b04e70de2eca65cba Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 26 Sep 2025 03:47:26 -0700 Subject: [PATCH 25/43] `bun feedback` (#22710) ### What does this PR do? ### How did you verify your code works? --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- build.zig | 1 + src/cli.zig | 2 + src/cli/run_command.zig | 19 + src/codegen/bundle-modules.ts | 21 + src/js/eval/README.md | 1 + src/js/eval/feedback.ts | 761 ++++++++++++++++++++++++++++++++++ 6 files changed, 805 insertions(+) create mode 100644 src/js/eval/README.md create mode 100644 src/js/eval/feedback.ts diff --git a/build.zig b/build.zig index 239df85c53..148e4c1366 100644 --- a/build.zig +++ b/build.zig @@ -743,6 +743,7 @@ fn addInternalImports(b: *Build, mod: *Module, opts: *BunBuildOptions) void { .{ .file = "node-fallbacks/url.js", .enable = opts.shouldEmbedCode() }, .{ .file = "node-fallbacks/util.js", .enable = opts.shouldEmbedCode() }, .{ .file = "node-fallbacks/zlib.js", .enable = opts.shouldEmbedCode() }, + .{ .file = "eval/feedback.ts", .enable = opts.shouldEmbedCode() }, }) |entry| { if (!@hasField(@TypeOf(entry), "enable") or entry.enable) { const path = b.pathJoin(&.{ opts.codegen_path, entry.file }); diff --git a/src/cli.zig b/src/cli.zig index 4461e8f92c..e5fbda4598 100644 --- a/src/cli.zig +++ b/src/cli.zig @@ -187,6 +187,8 @@ pub const HelpCommand = struct { \\ init Start an empty Bun project from a built-in template \\ create {s:<16} Create a new project from a template (bun c) \\ upgrade Upgrade to latest version of Bun. + \\ feedback ./file1 ./file2 Provide feedback to the Bun team. + \\ \\ \ --help Print help text for command. \\ ; diff --git a/src/cli/run_command.zig b/src/cli/run_command.zig index 23fb3f5ff7..c72159c96c 100644 --- a/src/cli/run_command.zig +++ b/src/cli/run_command.zig @@ -1603,6 +1603,12 @@ pub const RunCommand = struct { return true; } + if (ctx.filters.len == 0 and !ctx.workspaces and CLI.Cli.cmd != null and CLI.Cli.cmd.? == .AutoCommand) { + if (bun.strings.eqlComptime(target_name, "feedback")) { + try @"bun feedback"(ctx); + } + } + if (log_errors) { const ext = std.fs.path.extension(target_name); const default_loader = options.defaultLoaders.get(ext); @@ -1665,6 +1671,19 @@ pub const RunCommand = struct { Global.exit(1); }; } + + fn @"bun feedback"(ctx: Command.Context) !noreturn { + const trigger = bun.pathLiteral("/[eval]"); + var entry_point_buf: [bun.MAX_PATH_BYTES + trigger.len]u8 = undefined; + const cwd = try std.posix.getcwd(&entry_point_buf); + @memcpy(entry_point_buf[cwd.len..][0..trigger.len], trigger); + ctx.runtime_options.eval.script = if (bun.Environment.codegen_embed) + @embedFile("eval/feedback.ts") + else + bun.runtimeEmbedFile(.codegen, "eval/feedback.ts"); + try Run.boot(ctx, entry_point_buf[0 .. 
cwd.len + trigger.len], null); + Global.exit(0); + } }; pub const BunXFastPath = struct { diff --git a/src/codegen/bundle-modules.ts b/src/codegen/bundle-modules.ts index 429b37c2b1..18a5941d02 100644 --- a/src/codegen/bundle-modules.ts +++ b/src/codegen/bundle-modules.ts @@ -541,6 +541,27 @@ declare module "module" { mark("Generate Code"); +const evalFiles = new Bun.Glob(path.join(BASE, "eval", "*.ts")).scanSync(); +for (const file of evalFiles) { + const { + outputs: [output], + } = await Bun.build({ + entrypoints: [file], + + // Shrink it. + minify: !debug, + + target: "bun", + format: "esm", + env: "disable", + define: { + "process.platform": JSON.stringify(process.platform), + "process.arch": JSON.stringify(process.arch), + }, + }); + writeIfNotChanged(path.join(CODEGEN_DIR, "eval", path.basename(file)), await output.text()); +} + if (!silent) { console.log(""); console.timeEnd(timeString); diff --git a/src/js/eval/README.md b/src/js/eval/README.md new file mode 100644 index 0000000000..b92df5e8f6 --- /dev/null +++ b/src/js/eval/README.md @@ -0,0 +1 @@ +These are not bundled as builtin modules and instead are minified. diff --git a/src/js/eval/feedback.ts b/src/js/eval/feedback.ts new file mode 100644 index 0000000000..0bfd1e136c --- /dev/null +++ b/src/js/eval/feedback.ts @@ -0,0 +1,761 @@ +import { spawnSync } from "node:child_process"; +import { closeSync, promises as fsp, openSync } from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import readline from "node:readline"; +import tty from "node:tty"; +import { parseArgs as nodeParseArgs } from "node:util"; + +const supportsAnsi = Boolean(process.stdout.isTTY && !("NO_COLOR" in process.env)); +const reset = supportsAnsi ? "\x1b[0m" : ""; +const bold = supportsAnsi ? "\x1b[1m" : ""; +const dim = supportsAnsi ? "\x1b[2m" : ""; +const red = supportsAnsi ? "\x1b[31m" : ""; +const green = supportsAnsi ? "\x1b[32m" : ""; +const cyan = supportsAnsi ? "\x1b[36m" : ""; +const gray = supportsAnsi ? "\x1b[90m" : ""; +const symbols = { + question: `${cyan}?${reset}`, + check: `${green}✔${reset}`, + cross: `${red}✖${reset}`, +}; +const inputPrefix = `${gray}> ${reset}`; +const thankYouBanner = ` +${supportsAnsi ? bold : ""}THANK YOU! ${reset}`; +const enum IPSupport { + ipv4 = "ipv4", + ipv6 = "ipv6", + ipv4_and_ipv6 = "ipv4_and_ipv6", + none = "none", +} + +type TerminalIO = { + input: tty.ReadStream; + output: tty.WriteStream; + cleanup: () => void; +}; + +function openTerminal(): TerminalIO | null { + if (process.stdin.isTTY && process.stdout.isTTY) { + return { + input: process.stdin as unknown as tty.ReadStream, + output: process.stdout as unknown as tty.WriteStream, + cleanup: () => {}, + }; + } + + const candidates = process.platform === "win32" ? 
["CON"] : ["/dev/tty"]; + + for (const candidate of candidates) { + try { + const fd = openSync(candidate, "r+"); + const input = new tty.ReadStream(fd); + const output = new tty.WriteStream(fd); + input.setEncoding("utf8"); + return { + input, + output, + cleanup: () => { + input.destroy(); + output.destroy(); + try { + closeSync(fd); + } catch {} + }, + }; + } catch {} + } + + return null; +} +const logError = (message: string) => { + process.stderr.write(`${symbols.cross} ${message}\n`); +}; +const logInfo = (message: string) => { + process.stdout.write(`${bold}${message}${reset}\n`); +}; + +const isValidEmail = (value: string | undefined): value is string => { + if (!value) return false; + const trimmed = value.trim(); + if (!trimmed.includes("@")) return false; + if (!trimmed.includes(".")) return false; + return true; +}; + +type ParsedArgs = { + email?: string; + help: boolean; + positionals: string[]; +}; + +function parseCliArgs(argv: string[]): ParsedArgs { + try { + const { values, positionals } = nodeParseArgs({ + args: argv, + allowPositionals: true, + strict: false, + options: { + email: { + type: "string", + short: "e", + }, + help: { + type: "boolean", + short: "h", + }, + }, + }); + + return { + email: values.email, + help: Boolean(values.help), + positionals, + }; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + logError(message); + process.exit(1); + return { email: undefined, help: false, positionals: [] }; + } +} + +function printHelp() { + const heading = `${bold}${cyan}bun feedback${reset}`; + const usage = `${bold}Usage${reset} + bun feedback [options] [feedback text ... | files ...]`; + const options = `${bold}Options${reset} + ${cyan}-e${reset}, ${cyan}--email${reset} Set the email address used for this submission + ${cyan}-h${reset}, ${cyan}--help${reset} Show this help message and exit`; + const examples = `${bold}Examples${reset} + bun feedback "Love the new release!" + bun feedback report.txt details.log + echo "please document X" | bun feedback --email you@example.com`; + + console.log([heading, "", usage, "", options, "", examples].join("\n")); +} + +async function readEmailFromBunInstall(): Promise { + const installRoot = process.env.BUN_INSTALL ?? path.join(os.homedir(), ".bun"); + const emailFile = path.join(installRoot, "feedback"); + try { + const data = await fsp.readFile(emailFile, "utf8"); + const trimmed = data.trim(); + return trimmed.length > 0 ? trimmed : undefined; + } catch (error) { + if ((error as NodeJS.ErrnoException).code !== "ENOENT") { + console.warn(`Unable to read ${emailFile}:`, (error as Error).message); + } + return undefined; + } +} + +async function persistEmailToBunInstall(email: string): Promise { + const installRoot = process.env.BUN_INSTALL; + if (!installRoot) return; + + const emailFile = path.join(installRoot, "feedback"); + try { + await fsp.mkdir(path.dirname(emailFile), { recursive: true }); + await fsp.writeFile(emailFile, `${email.trim()}\n`, "utf8"); + } catch (error) { + console.warn(`Unable to persist email to ${emailFile}:`, (error as Error).message); + } +} + +function readEmailFromGitConfig(): string | undefined { + const result = spawnSync("git", ["config", "user.email"], { + encoding: "utf8", + stdio: ["ignore", "pipe", "ignore"], + }); + if (result.status !== 0) { + return undefined; + } + const output = result.stdout.trim(); + return output.length > 0 ? 
output : undefined; +} + +async function promptForEmail(terminal: TerminalIO | null, defaultEmail?: string): Promise { + if (!terminal) { + return defaultEmail && isValidEmail(defaultEmail) ? defaultEmail : undefined; + } + + let currentDefault = defaultEmail; + + for (;;) { + const answer = await promptForEmailInteractive(terminal, currentDefault); + if (typeof answer === "string" && isValidEmail(answer)) { + return answer.trim(); + } + + terminal.output.write(`${symbols.cross} Please provide a valid email address containing "@" and ".".\n`); + currentDefault = undefined; + } +} + +async function promptForEmailInteractive(terminal: TerminalIO, defaultEmail?: string): Promise { + const input = terminal.input; + const output = terminal.output; + + readline.emitKeypressEvents(input); + const hadRawMode = typeof input.isRaw === "boolean" ? input.isRaw : undefined; + if (typeof input.setRawMode === "function") { + input.setRawMode(true); + } + if (typeof input.resume === "function") { + input.resume(); + } + + const placeholder = defaultEmail ?? ""; + let placeholderActive = placeholder.length > 0; + let value = ""; + let resolved = false; + + const render = () => { + output.write(`\r\x1b[2K${symbols.question} ${bold}Email${reset}: `); + if (placeholderActive && placeholder.length > 0) { + output.write(`${dim}<${placeholder}>${reset}`); + output.write(`\x1b[${placeholder.length + 2}D`); + } else { + output.write(value); + } + }; + + render(); + + return await new Promise(resolve => { + const cleanup = (result?: string) => { + if (resolved) return; + resolved = true; + input.removeListener("keypress", onKeypress); + if (typeof input.setRawMode === "function") { + if (typeof hadRawMode === "boolean") { + input.setRawMode(hadRawMode); + } else { + input.setRawMode(false); + } + } + if (typeof input.pause === "function") { + input.pause(); + } + output.write("\n"); + resolve(result); + }; + + const onKeypress = (str: string, key: readline.Key) => { + if (!key && str) { + if (placeholderActive) { + placeholderActive = false; + value = ""; + render(); + } + value += str; + output.write(str); + return; + } + + if (key && (key.sequence === "\u0003" || (key.ctrl && key.name === "c"))) { + cleanup(); + process.exit(130); + return; + } + + if (key?.name === "return") { + if (placeholderActive && placeholder.length > 0) { + cleanup(placeholder); + return; + } + const trimmed = value.trim(); + cleanup(trimmed.length > 0 ? trimmed : undefined); + return; + } + + if (key?.name === "backspace") { + if (placeholderActive) { + return; + } + if (value.length > 0) { + value = value.slice(0, -1); + render(); + } + return; + } + + if (!str) { + return; + } + + if (key && key.name && key.name.length > 1 && key.name !== "space") { + return; + } + + if (placeholderActive) { + placeholderActive = false; + value = ""; + render(); + } + + value += str; + output.write(str); + }; + + input.on("keypress", onKeypress); + }); +} + +async function promptForBody( + terminal: TerminalIO | null, + attachments: PositionalContent["files"], +): Promise { + if (!terminal) { + return undefined; + } + + const input = terminal.input; + const output = terminal.output; + + readline.emitKeypressEvents(input); + const hadRawMode = typeof input.isRaw === "boolean" ? 
input.isRaw : undefined; + if (typeof input.setRawMode === "function") { + input.setRawMode(true); + } + if (typeof input.resume === "function") { + input.resume(); + } + + const header = `${symbols.question} ${bold}Share your feedback with Bun's team${reset} ${dim}(Enter to send, Shift+Enter for a newline)${reset}`; + output.write(`${header}\n`); + if (attachments.length > 0) { + output.write(`${dim}+ ${attachments.map(file => file.filename).join(", ")}${reset}\n`); + } + output.write(`${inputPrefix}`); + + const lines: string[] = [""]; + let currentLine = 0; + let resolved = false; + + return await new Promise(resolve => { + const cleanup = (value?: string) => { + if (resolved) return; + resolved = true; + input.removeListener("keypress", onKeypress); + if (typeof input.setRawMode === "function") { + if (typeof hadRawMode === "boolean") { + input.setRawMode(hadRawMode); + } else { + input.setRawMode(false); + } + } + if (typeof input.pause === "function") { + input.pause(); + } + output.write("\n"); + resolve(value); + }; + + const onKeypress = (str: string, key: readline.Key) => { + if (!key) { + if (str) { + lines[currentLine] += str; + output.write(str); + } + return; + } + + if (key.sequence === "\u0003" || (key.ctrl && key.name === "c")) { + cleanup(); + process.exit(130); + return; + } + + if (key.name === "return") { + if (key.shift) { + lines.push(""); + currentLine += 1; + output.write(`\n${inputPrefix}`); + return; + } + const message = lines.join("\n"); + cleanup(message); + return; + } + + if (key.name === "backspace") { + const current = lines[currentLine]; + if (current.length > 0) { + lines[currentLine] = current.slice(0, -1); + output.write("\b \b"); + } else if (currentLine > 0) { + lines.pop(); + currentLine -= 1; + output.write("\r\x1b[2K"); + output.write("\x1b[F"); + output.write("\r\x1b[2K"); + output.write(`${inputPrefix}${lines[currentLine]}`); + } + return; + } + + if (key.name && key.name.length > 1 && key.name !== "space") { + return; + } + + if (str) { + lines[currentLine] += str; + output.write(str); + } + }; + + input.on("keypress", onKeypress); + }); +} + +async function readFromStdin(): Promise { + const stdin = process.stdin; + if (!stdin || stdin.isTTY) return undefined; + + if (typeof stdin.setEncoding === "function") { + stdin.setEncoding("utf8"); + } + + if (typeof stdin.resume === "function") { + stdin.resume(); + } + + const chunks: string[] = []; + for await (const chunk of stdin as AsyncIterable) { + chunks.push(typeof chunk === "string" ? chunk : chunk.toString("utf8")); + } + + const content = chunks.join(""); + return content.length > 0 ? 
content : undefined; +} + +type PositionalContent = { + messageParts: string[]; + files: { filename: string; content: Uint8Array }[]; +}; + +async function resolveFileCandidate(token: string): Promise { + const candidates = new Set(); + candidates.add(token); + + if (token.startsWith("~/")) { + candidates.add(path.join(os.homedir(), token.slice(2))); + } + + const resolved = path.join(process.cwd(), token); + candidates.add(resolved); + + for (const candidate of candidates) { + try { + const stat = await fsp.stat(candidate); + if (stat.isFile()) { + return candidate; + } + } catch (error) { + const code = (error as NodeJS.ErrnoException).code; + if (code && (code === "ENOENT" || code === "ENOTDIR")) { + continue; + } + console.warn(`Unable to inspect ${candidate}:`, (error as Error).message); + } + } + + return undefined; +} + +async function readFromPositionals(positionals: string[]): Promise { + const messageParts: string[] = []; + const files: PositionalContent["files"] = []; + let literalTokens: string[] = []; + + const flushTokens = () => { + if (literalTokens.length > 0) { + messageParts.push(literalTokens.join(" ")); + literalTokens = []; + } + }; + + for (const token of positionals) { + const filePath = await resolveFileCandidate(token); + + if (filePath) { + try { + let fileContents = await Bun.file(filePath).bytes(); + // Truncate to + if (fileContents.length > 1024 * 1024 * 10) { + fileContents = fileContents.slice(0, 1024 * 1024 * 10); + } + + flushTokens(); + files.push({ + filename: path.normalize(path.relative(process.cwd(), filePath)), + content: fileContents, + }); + continue; + } catch { + // Ignore read errors; treat token as part of the message instead. + } + } + + literalTokens.push(token); + } + + flushTokens(); + return { messageParts, files }; +} + +function getIPSupport(networkInterface: os.NetworkInterfaceInfo, original: IPSupport): IPSupport { + if (networkInterface.family === "IPv4") { + switch (original) { + case IPSupport.none: + return IPSupport.ipv4; + case IPSupport.ipv4: + return IPSupport.ipv4_and_ipv6; + case IPSupport.ipv6: + return IPSupport.ipv4_and_ipv6; + case IPSupport.ipv4_and_ipv6: + return IPSupport.ipv4_and_ipv6; + } + } else if (networkInterface.family === "IPv6") { + switch (original) { + case IPSupport.none: + return IPSupport.ipv6; + case IPSupport.ipv4: + return IPSupport.ipv4_and_ipv6; + case IPSupport.ipv6: + return IPSupport.ipv4_and_ipv6; + case IPSupport.ipv4_and_ipv6: + return IPSupport.ipv4_and_ipv6; + } + } + return original; +} + +function getOldestGitSha(): string | undefined { + const result = spawnSync("git", ["rev-list", "--max-parents=0", "HEAD"], { + encoding: "utf8", + stdio: ["ignore", "pipe", "ignore"], + }); + + if (result.status !== 0) { + return undefined; + } + + const firstLine = result.stdout.split(/\r?\n/).find(line => line.trim().length > 0); + return firstLine?.trim(); +} + +async function main() { + const rawArgv = [...process.argv.slice(1)]; + + let terminal: TerminalIO | null = null; + try { + const { email: emailFlag, help, positionals } = parseCliArgs(rawArgv); + if (help) { + printHelp(); + return; + } + + terminal = openTerminal(); + + const exit = (code: number): never => { + terminal?.cleanup(); + process.exit(code); + }; + + if (emailFlag && !isValidEmail(emailFlag)) { + logError("The provided email must include both '@' and '.'."); + exit(1); + } + + const storedEmailRaw = await readEmailFromBunInstall(); + const storedEmail = isValidEmail(storedEmailRaw) ? 
storedEmailRaw.trim() : undefined; + + const gitEmailRaw = readEmailFromGitConfig(); + const gitEmail = isValidEmail(gitEmailRaw) ? gitEmailRaw.trim() : undefined; + + const canPrompt = terminal !== null; + + let email = emailFlag?.trim() ?? storedEmail ?? gitEmail; + + if (canPrompt && !emailFlag && !storedEmail) { + email = await promptForEmail(terminal, email ?? gitEmail ?? undefined); + } + + if (!isValidEmail(email)) { + if (!canPrompt) { + logError("Unable to determine email automatically. Pass --email
."); + } else { + logError("An email address is required. Pass --email or configure git user.email."); + } + exit(1); + return; + } + + const normalizedEmail = email.trim(); + + if (process.env.BUN_INSTALL && !storedEmail) { + await persistEmailToBunInstall(normalizedEmail); + } + + const stdinContent = await readFromStdin(); + const positionalContent = await readFromPositionals(positionals); + const positionalParts = positionalContent.messageParts; + const pieces: string[] = []; + if (stdinContent && stdinContent.trim().length > 0) pieces.push(stdinContent); + for (const part of positionalParts) { + if (part.trim().length > 0) { + pieces.push(part); + } + } + + let message = pieces.length > 0 ? pieces.join(pieces.length > 1 ? "\n\n" : "") : ""; + + if (message.trim().length === 0 && terminal) { + const interactiveBody = await promptForBody(terminal, positionalContent.files); + if (interactiveBody && interactiveBody.trim().length > 0) { + message = interactiveBody; + } + } + + const normalizedMessage = message.trim(); + if (normalizedMessage.length === 0) { + logError("No feedback provided. Supply text, file paths, or pipe input."); + exit(1); + return; + } + + const messageBody = normalizedMessage; + + const projectId = getOldestGitSha(); + const endpoint = process.env.BUN_FEEDBACK_URL || "https://bun.report/v1/feedback"; + + const form = new FormData(); + form.append("email", normalizedEmail); + const fileList = positionalContent.files.map(file => file.filename); + form.append("message", messageBody); + for (const file of positionalContent.files) { + form.append("files[]", new Blob([file.content]), file.filename); + } + + const id = Bun.randomUUIDv7(); + + form.append("platform", process.platform); + form.append("arch", process.arch); + form.append("bunRevision", Bun.revision); + form.append("hardwareConcurrency", String(navigator.hardwareConcurrency)); + form.append("bunVersion", Bun.version); + form.append("bunBuild", path.basename(process.release.sourceUrl!, path.extname(process.release.sourceUrl!))); + form.append("availableMemory", String(process.availableMemory())); + form.append("totalMemory", String(os.totalmem())); + form.append("osVersion", String(os.version())); + form.append("osRelease", String(os.release())); + form.append("id", id); + + // Check if we're running in Docker + let inDocker = false; + if (process.platform === "linux") { + if (require("fs").existsSync("/.dockerenv")) { + inDocker = true; + } + } + + if (inDocker) { + form.append("docker", "true"); + } + + let remoteIP: IPSupport = IPSupport.none; + let localIP: IPSupport = IPSupport.none; + + try { + const networkInterfaces = Object.entries(os.networkInterfaces() || {}); + + for (const [name, interfaces] of networkInterfaces) { + for (const networkInterface of interfaces || []) { + if (networkInterface.family === "IPv4") { + if (networkInterface.internal) { + localIP = getIPSupport(networkInterface, localIP); + } else { + remoteIP = getIPSupport(networkInterface, remoteIP); + } + } else if (networkInterface.family === "IPv6") { + if (networkInterface.internal) { + localIP = getIPSupport(networkInterface, localIP); + } else { + remoteIP = getIPSupport(networkInterface, remoteIP); + } + } + } + } + } catch { + // Ignore errors; treat as no IP support. 
+ } + + form.append("localIPSupport", localIP); + form.append("remoteIPSupport", remoteIP); + + // Check if current working directory is on a remote filesystem + if (process.platform === "linux" || process.platform === "darwin") { + let isRemoteFilesystem = false; + try { + const cwd = process.cwd(); + const stats = await fsp.statfs(cwd); + + // Check filesystem type based on the type field + // Common remote filesystem types have specific type values + const remoteFsTypes = new Set([ + 0x6969, // NFS + 0xff534d42, // CIFS/SMB + 0x65735546, // FUSE (used by sshfs, etc.) + ]); + + if (remoteFsTypes.has(stats.type)) { + isRemoteFilesystem = true; + } + } catch { + // Ignore errors; treat as local filesystem + } + + if (isRemoteFilesystem) { + form.append("remoteFilesystem", "true"); + } + } + + if (projectId) { + form.append("projectId", projectId); + } + + const response = await fetch(endpoint, { + method: "POST", + body: form, + }); + + if (!response.ok || response.status !== 200) { + const bodyText = await response.text().catch(() => ""); + logError(`Failed to send feedback (${response.status} ${response.statusText}).`); + if (bodyText) { + process.stderr.write(`${bodyText}\n`); + } + exit(1); + } + + let IDBanner = ``; + if (supportsAnsi) { + IDBanner = `\n${dim}ID: ${id}${reset}`; + } else { + IDBanner = `\nID: ${id}`; + } + + process.stdout.write(`${symbols.check} Feedback sent.\n${IDBanner}${thankYouBanner}\n`); + } finally { + terminal?.cleanup(); + } +} + +await main().catch(error => { + const detail = error instanceof Error ? error.message : String(error); + logError(`Unexpected error while sending feedback: ${detail}`); + process.exit(1); +}); From f45900d7e6d1dfb2bb289b3e293c7d47efcb09a4 Mon Sep 17 00:00:00 2001 From: Filip Stevanovic <62512535+filipstev@users.noreply.github.com> Date: Fri, 26 Sep 2025 12:54:41 +0200 Subject: [PATCH 26/43] fix(fetch): print request body for application/x-www-form-urlencoded in curl logs (#22849) ### What does this PR do? fixes an issue where fetch requests with `Content-Type: application/x-www-form-urlencoded` would not include the request body in curl logs when `BUN_CONFIG_VERBOSE_FETCH=curl` is enabled previously, only JSON and text-based content types were recognized as safe-to-print in the curl formatter. This change updates the allow-list to also handle `application/x-www-form-urlencoded`, ensuring bodies for common form submissions are shown in logs ### How did you verify your code works? - added `Content-Type: application/x-www-form-urlencoded` to a fetch request and confirmed that `BUN_CONFIG_VERBOSE_FETCH=curl` now outputs a `--data-raw` section with the encoded body - verified the fix against the reproduction script provided in issue #12042 - created and ran a regression test - checked that existing content types (JSON, text, etc.) 
continue to print correctly fixes #12042 --- src/deps/picohttp.zig | 11 ++++++- test/regression/issue/12042.test.ts | 47 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 test/regression/issue/12042.test.ts diff --git a/src/deps/picohttp.zig b/src/deps/picohttp.zig index b674897fb9..a0507c6ffa 100644 --- a/src/deps/picohttp.zig +++ b/src/deps/picohttp.zig @@ -97,6 +97,15 @@ pub const Request = struct { ignore_insecure: bool = false, body: []const u8 = "", + fn isPrintableBody(content_type: []const u8) bool { + if (content_type.len == 0) return false; + + return bun.strings.hasPrefixComptime(content_type, "text/") or + bun.strings.hasPrefixComptime(content_type, "application/json") or + bun.strings.containsComptime(content_type, "json") or + bun.strings.hasPrefixComptime(content_type, "application/x-www-form-urlencoded"); + } + pub fn format(self: @This(), comptime _: []const u8, _: fmt.FormatOptions, writer: anytype) !void { const request = self.request; if (Output.enable_ansi_colors_stderr) { @@ -132,7 +141,7 @@ pub const Request = struct { } } - if (self.body.len > 0 and (content_type.len > 0 and bun.strings.hasPrefixComptime(content_type, "application/json") or bun.strings.hasPrefixComptime(content_type, "text/") or bun.strings.containsComptime(content_type, "json"))) { + if (self.body.len > 0 and isPrintableBody(content_type)) { _ = try writer.writeAll(" --data-raw "); try bun.js_printer.writeJSONString(self.body, @TypeOf(writer), writer, .utf8); } diff --git a/test/regression/issue/12042.test.ts b/test/regression/issue/12042.test.ts new file mode 100644 index 0000000000..f882aa78fe --- /dev/null +++ b/test/regression/issue/12042.test.ts @@ -0,0 +1,47 @@ +import { expect, test } from "bun:test"; +import { bunEnv, bunExe, normalizeBunSnapshot, tempDir } from "harness"; + +test("#12042 curl verbose fetch logs form-urlencoded body", async () => { + using dir = tempDir("issue-12042", { + "form.ts": ` +const server = Bun.serve({ + port: 0, + fetch() { + return new Response(JSON.stringify({ ok: true }), { + headers: { "Content-Type": "application/json" }, + }); + }, +}); + +const params = new URLSearchParams(); +params.set("grant_type", "client_credentials"); +params.set("client_id", "abc"); +params.set("client_secret", "xyz"); + +await fetch(String(server.url), { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: params, +}); + +await server.stop(); + `, + }); + + const dirPath = String(dir); + + await using proc = Bun.spawn({ + cmd: [bunExe(), "form.ts"], + env: { ...bunEnv, BUN_CONFIG_VERBOSE_FETCH: "curl" }, + cwd: dirPath, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr] = await Promise.all([proc.stdout.text(), proc.stderr.text()]); + + const output = stdout + stderr; + const normalized = normalizeBunSnapshot(output, dirPath); + + expect(normalized).toContain('--data-raw "grant_type=client_credentials&client_id=abc&client_secret=xyz'); +}); From a329da97f49d8ea020b5d89f8b2c6bad192993e1 Mon Sep 17 00:00:00 2001 From: robobun Date: Fri, 26 Sep 2025 04:59:07 -0700 Subject: [PATCH 27/43] Fix server stability issue with oversized requests (#22701) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Improves server stability when handling certain request edge cases. 
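Concretely, the guarantee is that a request whose self-reported `Content-Length` exceeds `maxRequestBodySize` is rejected with `413 Request Entity Too Large` before the handler runs, and the listener keeps serving subsequent requests. A minimal sketch of the observable behavior (mirroring the regression test below):

```ts
// Oversized requests get an early 413; the server stays healthy afterwards.
using server = Bun.serve({
  port: 0,
  maxRequestBodySize: 1024, // 1KB limit
  fetch: () => new Response("ok"),
});

// The Content-Length check happens up front, so this never reaches fetch().
const oversized = await fetch(server.url, { method: "POST", body: "A".repeat(2048) });
console.log(oversized.status); // 413

// New connections continue to be accepted and answered.
const normal = await fetch(server.url, { method: "POST", body: "hi" });
console.log(normal.status); // 200
```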
## Test plan - Added regression test in `test/regression/issue/22353.test.ts` - Test verifies server continues operating normally after handling edge case requests - All existing HTTP server tests pass Fixes #22353 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude Bot Co-authored-by: Claude Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jarred Sumner --- src/bun.js/api/server.zig | 49 ++++++++++++++++++----------- test/regression/issue/22353.test.ts | 37 ++++++++++++++++++++++ 2 files changed, 67 insertions(+), 19 deletions(-) create mode 100644 test/regression/issue/22353.test.ts diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index aa17c485e1..5fa67440e6 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -2143,7 +2143,36 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d pub fn prepareJsRequestContext(this: *ThisServer, req: *uws.Request, resp: *App.Response, should_deinit_context: ?*bool, create_js_request: bool, method: ?bun.http.Method) ?PreparedRequest { jsc.markBinding(@src()); + + // We need to register the handler immediately since uSockets will not buffer. + // + // We first validate the self-reported request body length so that + // we avoid needing to worry as much about what memory to free. + const request_body_length: ?usize = request_body_length: { + if ((HTTP.Method.which(req.method()) orelse HTTP.Method.OPTIONS).hasRequestBody()) { + const len: usize = brk: { + if (req.header("content-length")) |content_length| { + break :brk std.fmt.parseInt(usize, content_length, 10) catch 0; + } + + break :brk 0; + }; + + // Abort the request very early. + if (len > this.config.max_request_body_size) { + resp.writeStatus("413 Request Entity Too Large"); + resp.endWithoutBody(true); + return null; + } + + break :request_body_length len; + } + + break :request_body_length null; + }; + this.onPendingRequest(); + if (comptime Environment.isDebug) { this.vm.eventLoop().debug.enter(); } @@ -2193,25 +2222,7 @@ pub fn NewServer(protocol_enum: enum { http, https }, development_kind: enum { d }; } - // we need to do this very early unfortunately - // it seems to work fine for synchronous requests but anything async will take too long to register the handler - // we do this only for HTTP methods that support request bodies, so not GET, HEAD, OPTIONS, or CONNECT. 
- if ((HTTP.Method.which(req.method()) orelse HTTP.Method.OPTIONS).hasRequestBody()) { - const req_len: usize = brk: { - if (req.header("content-length")) |content_length| { - break :brk std.fmt.parseInt(usize, content_length, 10) catch 0; - } - - break :brk 0; - }; - - if (req_len > this.config.max_request_body_size) { - resp.writeStatus("413 Request Entity Too Large"); - resp.endWithoutBody(true); - this.finalize(); - return null; - } - + if (request_body_length) |req_len| { ctx.request_body_content_len = req_len; ctx.flags.is_transfer_encoding = req.header("transfer-encoding") != null; if (req_len > 0 or ctx.flags.is_transfer_encoding) { diff --git a/test/regression/issue/22353.test.ts b/test/regression/issue/22353.test.ts new file mode 100644 index 0000000000..5465d90519 --- /dev/null +++ b/test/regression/issue/22353.test.ts @@ -0,0 +1,37 @@ +import { expect, test } from "bun:test"; + +test("issue #22353 - server should handle oversized request without crashing", async () => { + using server = Bun.serve({ + port: 0, + maxRequestBodySize: 1024, // 1KB limit + async fetch(req) { + const body = await req.text(); + return new Response( + JSON.stringify({ + received: true, + size: body.length, + }), + { + headers: { "Content-Type": "application/json" }, + }, + ); + }, + }); + + const resp = await fetch(server.url, { + method: "POST", + body: "A".repeat(1025), + }); + expect(resp.status).toBe(413); + expect(await resp.text()).toBeEmpty(); + for (let i = 0; i < 100; i++) { + const resp2 = await fetch(server.url, { + method: "POST", + }); + expect(resp2.status).toBe(200); + expect(await resp2.json()).toEqual({ + received: true, + size: 0, + }); + } +}, 10000); From 9d01a7b91a3b7bf8a33ca9ddb605b9e7b30044ae Mon Sep 17 00:00:00 2001 From: pfg Date: Fri, 26 Sep 2025 13:47:24 -0700 Subject: [PATCH 28/43] Require deinit function for memory.deinit() (#22923) Co-authored-by: taylor.fish --- src/allocators.zig | 2 ++ src/allocators/MimallocArena.zig | 2 ++ src/js_parser.zig | 2 ++ src/memory.zig | 55 ++++++++++++++++++++++++++------ src/threading/Mutex.zig | 2 ++ 5 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/allocators.zig b/src/allocators.zig index a4d0992a75..6336b43244 100644 --- a/src/allocators.zig +++ b/src/allocators.zig @@ -919,6 +919,8 @@ pub const Default = struct { _ = self; return c_allocator; } + + pub const deinit = void; }; const basic = if (bun.use_mimalloc) diff --git a/src/allocators/MimallocArena.zig b/src/allocators/MimallocArena.zig index 59b81d9e4d..5fe5dccd7b 100644 --- a/src/allocators/MimallocArena.zig +++ b/src/allocators/MimallocArena.zig @@ -94,6 +94,8 @@ const BorrowedHeap = if (safety_checks) *DebugHeap else *mimalloc.Heap; const DebugHeap = struct { inner: *mimalloc.Heap, thread_lock: bun.safety.ThreadLock, + + pub const deinit = void; }; threadlocal var thread_heap: if (safety_checks) ?DebugHeap else void = if (safety_checks) null; diff --git a/src/js_parser.zig b/src/js_parser.zig index 14c8f7c9ea..9bba1c6028 100644 --- a/src/js_parser.zig +++ b/src/js_parser.zig @@ -545,6 +545,8 @@ pub const StringVoidMap = struct { allocator: Allocator, map: bun.StringHashMapUnmanaged(void) = bun.StringHashMapUnmanaged(void){}, + pub const deinit = void; + /// Returns true if the map already contained the given key. 
pub fn getOrPutContains(this: *StringVoidMap, key: string) bool {
        const entry = this.map.getOrPut(this.allocator, key) catch unreachable;
diff --git a/src/memory.zig b/src/memory.zig
index 5ed890ee3e..6f41f9fed9 100644
--- a/src/memory.zig
+++ b/src/memory.zig
@@ -42,6 +42,29 @@ pub fn initDefault(comptime T: type) T {
        .{};
 }
 
+/// Returns true if `T` should not be required to have a `deinit` method.
+///
+/// This method is primarily for external types where a `deinit` method can't be added.
+/// For other types, prefer adding a `deinit` method or adding `pub const deinit = void;` if
+/// possible.
+fn exemptedFromDeinit(comptime T: type) bool {
+    return switch (T) {
+        std.mem.Allocator => true,
+        else => {
+            _ = T.deinit; // no deinit method? add one, set to void, or add an exemption
+            return false;
+        },
+    };
+}
+
+fn deinitIsVoid(comptime T: type) bool {
+    return switch (@TypeOf(T.deinit)) {
+        type => T.deinit == void,
+        void => true,
+        else => false,
+    };
+}
+
 /// Calls `deinit` on `ptr_or_slice`, or on every element of `ptr_or_slice`, if such a `deinit`
 /// method exists.
 ///
@@ -60,18 +83,30 @@ pub fn deinit(ptr_or_slice: anytype) void {
     const ptr_info = @typeInfo(@TypeOf(ptr_or_slice));
     const Child = ptr_info.pointer.child;
     const mutable = !ptr_info.pointer.is_const;
-    if (comptime std.meta.hasFn(Child, "deinit")) {
-        switch (comptime ptr_info.pointer.size) {
-            .one => {
+
+    const needs_deinit = comptime switch (@typeInfo(Child)) {
+        .@"struct" => true,
+        .@"union" => |u| u.tag_type != null,
+        else => false,
+    };
+    const should_call_deinit = comptime needs_deinit and
+        !exemptedFromDeinit(Child) and
+        !deinitIsVoid(Child);
+
+    switch (comptime ptr_info.pointer.size) {
+        .one => {
+            if (comptime should_call_deinit) {
                 ptr_or_slice.deinit();
-                if (comptime mutable) ptr_or_slice.* = undefined;
-            },
-            .slice => for (ptr_or_slice) |*elem| {
+            }
+            if (comptime mutable) ptr_or_slice.* = undefined;
+        },
+        .slice => for (ptr_or_slice) |*elem| {
+            if (comptime should_call_deinit) {
                 elem.deinit();
-                if (comptime mutable) elem.* = undefined;
-            },
-            else => @compileError("unsupported pointer type"),
-        }
+            }
+            if (comptime mutable) elem.* = undefined;
+        },
+        else => @compileError("unsupported pointer type"),
     }
 }
 
diff --git a/src/threading/Mutex.zig b/src/threading/Mutex.zig
index 9cb40b4170..20828d3580 100644
--- a/src/threading/Mutex.zig
+++ b/src/threading/Mutex.zig
@@ -48,6 +48,8 @@ pub fn unlock(self: *Mutex) void {
     self.impl.unlock();
 }
 
+pub const deinit = void;
+
 const Impl = if (builtin.mode == .Debug and !builtin.single_threaded)
     DebugImpl
 else

From 266fca2e5cd6ce50fda7b5f6dfe45fbf3be9fec1 Mon Sep 17 00:00:00 2001
From: "taylor.fish"
Date: Fri, 26 Sep 2025 15:15:58 -0700
Subject: [PATCH 29/43] Add `ExternalShared` and `RawRefCount` (#23013)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add `bun.ptr.ExternalShared`, a shared pointer whose reference count is
managed externally; e.g., by extern functions. This can be used to work
with `RefCounted` C++ objects in Zig. For example:

```cpp
// C++:
struct MyType : RefCounted<MyType> { ... };

extern "C" void MyType__ref(MyType* self) { self->ref(); }
extern "C" void MyType__deref(MyType* self) { self->deref(); }
```

```zig
// Zig:
const MyType = opaque {
    extern fn MyType__ref(self: *MyType) void;
    extern fn MyType__deref(self: *MyType) void;

    pub const Ref = bun.ptr.ExternalShared(MyType);

    // This enables `ExternalShared` to work.
pub const external_shared_descriptor = struct { pub const ref = MyType__ref; pub const deref = MyType__deref; }; }; // Now `MyType.Ref` behaves just like `Ref` in C++: var some_ref: MyType.Ref = someFunctionReturningMyTypeRef(); const ptr: *MyType = some_ref.get(); // gets the inner pointer var some_other_ref = some_ref.clone(); // increments the ref count some_ref.deinit(); // decrements the ref count // decrements the ref count again; if no other refs exist, the object // is destroyed some_other_ref.deinit(); ``` This commit also adds `RawRefCount`, a simple wrapper around an integer reference count that can be used to implement the interface required by `ExternalShared`. Generally, for reference-counted Zig types, `bun.ptr.Shared` is preferred, but occasionally it is useful to have an “intrusive” reference-counted type where the ref count is stored in the type itself. For this purpose, `ExternalShared` + `RawRefCount` is more flexible and less error-prone than the deprecated `bun.ptr.RefCounted` type. (For internal tracking: fixes STAB-1287, STAB-1288) --- src/ptr.zig | 4 ++ src/ptr/external_shared.zig | 111 +++++++++++++++++++++++++++++++++++ src/ptr/raw_ref_count.zig | 74 +++++++++++++++++++++++ src/string/WTFStringImpl.zig | 8 +++ 4 files changed, 197 insertions(+) create mode 100644 src/ptr/external_shared.zig create mode 100644 src/ptr/raw_ref_count.zig diff --git a/src/ptr.zig b/src/ptr.zig index 608b0efc50..4df6fac14c 100644 --- a/src/ptr.zig +++ b/src/ptr.zig @@ -13,6 +13,7 @@ pub const DynamicOwned = owned.Dynamic; // owned pointer allocated with any `std pub const shared = @import("./ptr/shared.zig"); pub const Shared = shared.Shared; pub const AtomicShared = shared.AtomicShared; +pub const ExternalShared = @import("./ptr/external_shared.zig").ExternalShared; pub const ref_count = @import("./ptr/ref_count.zig"); /// Deprecated; use `Shared(*T)`. @@ -22,6 +23,9 @@ pub const ThreadSafeRefCount = ref_count.ThreadSafeRefCount; /// Deprecated; use `Shared(*T)`. pub const RefPtr = ref_count.RefPtr; +pub const raw_ref_count = @import("./ptr/raw_ref_count.zig"); +pub const RawRefCount = raw_ref_count.RawRefCount; + pub const TaggedPointer = @import("./ptr/tagged_pointer.zig").TaggedPointer; pub const TaggedPointerUnion = @import("./ptr/tagged_pointer.zig").TaggedPointerUnion; diff --git a/src/ptr/external_shared.zig b/src/ptr/external_shared.zig new file mode 100644 index 0000000000..de59780b7d --- /dev/null +++ b/src/ptr/external_shared.zig @@ -0,0 +1,111 @@ +/// A shared pointer whose reference count is managed externally; e.g., by extern functions. +/// +/// `T.external_shared_descriptor` must be a struct of the following form: +/// +/// pub const external_shared_descriptor = struct { +/// pub fn ref(T*) void; +/// pub fn deref(T*) void; +/// }; +pub fn ExternalShared(comptime T: type) type { + _ = T.external_shared_descriptor.ref; // must define a `ref` function + _ = T.external_shared_descriptor.deref; // must define a `deref` function + return struct { + const Self = @This(); + + #impl: *T, + + /// `incremented_raw` should have already had its ref count incremented by 1. + pub fn adopt(incremented_raw: *T) Self { + return .{ .#impl = incremented_raw }; + } + + /// Deinitializes the shared pointer, decrementing the ref count. + pub fn deinit(self: *Self) void { + T.external_shared_descriptor.deref(self.#impl); + self.* = undefined; + } + + /// Gets the underlying pointer. This pointer may not be valid after `self` is + /// deinitialized. 
+ pub fn get(self: Self) *T { + return self.#impl; + } + + /// Clones the shared pointer, incrementing the ref count. + pub fn clone(self: Self) Self { + T.external_shared_descriptor.ref(self.#impl); + return self; + } + + pub fn cloneFromRaw(raw: *T) Self { + T.external_shared_descriptor.ref(raw); + return .{ .#impl = raw }; + } + + /// Returns the raw pointer without decrementing the ref count. Invalidates `self`. + pub fn leak(self: *Self) *T { + defer self.* = undefined; + return self.#impl; + } + + const NonOptional = Self; + + pub const Optional = struct { + #impl: ?*T = null, + + pub fn initNull() Optional { + return .{}; + } + + /// `incremented_raw`, if non-null, should have already had its ref count incremented + /// by 1. + pub fn adopt(incremented_raw: ?*T) Optional { + return .{ .#impl = incremented_raw }; + } + + pub fn deinit(self: *Optional) void { + if (self.#impl) |impl| { + T.external_shared_descriptor.deref(impl); + } + self.* = undefined; + } + + pub fn get(self: Optional) ?*T { + return self.#impl; + } + + /// Sets `self` to null. + pub fn take(self: *Optional) ?NonOptional { + const result: NonOptional = .{ .#impl = self.#impl orelse return null }; + self.#impl = null; + return result; + } + + pub fn clone(self: Optional) Optional { + if (self.#impl) |impl| { + T.external_shared_descriptor.ref(impl); + } + return self; + } + + pub fn cloneFromRaw(raw: ?*T) Optional { + if (raw) |some_raw| { + T.external_shared_descriptor.ref(some_raw); + } + return .{ .#impl = raw }; + } + + /// Returns the raw pointer without decrementing the ref count. Invalidates `self`. + pub fn leak(self: *Optional) ?*T { + defer self.* = undefined; + return self.#impl; + } + }; + + /// Invalidates `self`. + pub fn intoOptional(self: *Self) Optional { + defer self.* = undefined; + return .{ .#impl = self.#impl }; + } + }; +} diff --git a/src/ptr/raw_ref_count.zig b/src/ptr/raw_ref_count.zig new file mode 100644 index 0000000000..7d65d98fb4 --- /dev/null +++ b/src/ptr/raw_ref_count.zig @@ -0,0 +1,74 @@ +pub const ThreadSafety = enum { + single_threaded, + thread_safe, +}; + +pub const DecrementResult = enum { + keep_alive, + should_destroy, +}; + +/// A simple wrapper around an integer reference count. This type doesn't do any memory management +/// itself. +/// +/// This type may be useful for implementing the interface required by `bun.ptr.ExternalShared`. +pub fn RawRefCount(comptime Int: type, comptime thread_safety: ThreadSafety) type { + return struct { + const Self = @This(); + + raw_value: if (thread_safety == .thread_safe) std.atomic.Value(Int) else Int, + #thread_lock: if (thread_safety == .single_threaded) bun.safety.ThreadLock else void, + + /// Usually the initial count should be 1. 
+ pub fn init(initial_count: Int) Self { + return .{ + .raw_value = switch (comptime thread_safety) { + .single_threaded => initial_count, + .thread_safe => .init(initial_count), + }, + .#thread_lock = switch (comptime thread_safety) { + .single_threaded => .initLockedIfNonComptime(), + .thread_safe => {}, + }, + }; + } + + pub fn increment(self: *Self) void { + switch (comptime thread_safety) { + .single_threaded => { + self.#thread_lock.lockOrAssert(); + self.raw_value += 1; + }, + .thread_safe => { + const old = self.raw_value.fetchAdd(1, .monotonic); + bun.assertf( + old != std.math.maxInt(Int), + "overflow of thread-safe ref count", + .{}, + ); + }, + } + } + + pub fn decrement(self: *Self) DecrementResult { + const new_count = blk: switch (comptime thread_safety) { + .single_threaded => { + self.#thread_lock.lockOrAssert(); + self.raw_value -= 1; + break :blk self.raw_value; + }, + .thread_safe => { + const old = self.raw_value.fetchSub(1, .acq_rel); + bun.assertf(old != 0, "underflow of thread-safe ref count", .{}); + break :blk old - 1; + }, + }; + return if (new_count == 0) .should_destroy else .keep_alive; + } + + pub const deinit = void; + }; +} + +const bun = @import("bun"); +const std = @import("std"); diff --git a/src/string/WTFStringImpl.zig b/src/string/WTFStringImpl.zig index 48810d2a3e..c0afcbed5d 100644 --- a/src/string/WTFStringImpl.zig +++ b/src/string/WTFStringImpl.zig @@ -217,8 +217,16 @@ pub const WTFStringImplStruct = extern struct { pub fn hasPrefix(self: WTFStringImpl, text: []const u8) bool { return bun.cpp.Bun__WTFStringImpl__hasPrefix(self, text.ptr, text.len); } + + pub const external_shared_descriptor = struct { + pub const ref = WTFStringImplStruct.ref; + pub const deref = WTFStringImplStruct.deref; + }; }; +/// Behaves like `WTF::Ref`. +pub const WTFString = bun.ptr.ExternalShared(WTFStringImplStruct); + pub const StringImplAllocator = struct { fn alloc(ptr: *anyopaque, len: usize, _: std.mem.Alignment, _: usize) ?[*]u8 { var this = bun.cast(WTFStringImpl, ptr); From c58d2e39116101935ab298534c0599a98e409b90 Mon Sep 17 00:00:00 2001 From: "taylor.fish" Date: Fri, 26 Sep 2025 15:19:45 -0700 Subject: [PATCH 30/43] Add generic-allocator `ArrayList` (#22917) Add a version of `ArrayList` that takes a generic `Allocator` type parameter. This matches the interface of smart pointers like `bun.ptr.Owned` and `bun.ptr.Shared`. This type behaves like a managed `ArrayList` but has no overhead if `Allocator` is a zero-sized type, like `bun.DefaultAllocator`. 
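A rough usage sketch (illustrative only, not part of this diff; the method names
match the new `ArrayList` below, and the `bun.collections` re-export path is
assumed):

```zig
const bun = @import("bun");

fn example() !void {
    // `ArrayListDefault` stores a zero-sized `bun.DefaultAllocator`, so the
    // managed list is no larger than an unmanaged one.
    var list = bun.collections.ArrayListDefault(u32).init();
    // `deinit` also deinits each item, if the element type has a `deinit`.
    defer list.deinit();

    try list.append(1);
    try list.appendSlice(&.{ 2, 3 });
}
```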
(For internal tracking: fixes STAB-1267) --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/collections.zig | 8 + src/collections/array_list.zig | 421 +++++++++++++++++++++++++++++++++ 2 files changed, 429 insertions(+) create mode 100644 src/collections/array_list.zig diff --git a/src/collections.zig b/src/collections.zig index 7215058899..de5c8238fb 100644 --- a/src/collections.zig +++ b/src/collections.zig @@ -7,3 +7,11 @@ pub const bit_set = @import("./collections/bit_set.zig"); pub const AutoBitSet = bit_set.AutoBitSet; pub const HiveArray = @import("./collections/hive_array.zig").HiveArray; pub const BoundedArray = @import("./collections/bounded_array.zig").BoundedArray; + +pub const array_list = @import("./collections/array_list.zig"); +pub const ArrayList = array_list.ArrayList; // any `std.mem.Allocator` +pub const ArrayListDefault = array_list.ArrayListDefault; // always default allocator (no overhead) +pub const ArrayListIn = array_list.ArrayListIn; // specific type of generic allocator +pub const ArrayListAligned = array_list.ArrayListAligned; +pub const ArrayListAlignedDefault = array_list.ArrayListAlignedDefault; +pub const ArrayListAlignedIn = array_list.ArrayListAlignedIn; diff --git a/src/collections/array_list.zig b/src/collections/array_list.zig new file mode 100644 index 0000000000..8653989abd --- /dev/null +++ b/src/collections/array_list.zig @@ -0,0 +1,421 @@ +/// Managed `ArrayList` using an arbitrary `std.mem.Allocator`. +/// Prefer using a concrete type, like `ArrayListDefault` or `ArrayListIn(MimallocArena)`. +/// +/// NOTE: Unlike `std.ArrayList`, this type's `deinit` method calls `deinit` on each of the items. +pub fn ArrayList(comptime T: type) type { + return ArrayListIn(T, std.mem.Allocator); +} + +/// Managed `ArrayList` using the default allocator. No overhead compared to an unmanaged +/// `ArrayList`. +/// +/// NOTE: Unlike `std.ArrayList`, this type's `deinit` method calls `deinit` on each of the items. +pub fn ArrayListDefault(comptime T: type) type { + return ArrayListIn(T, bun.DefaultAllocator); +} + +/// Managed `ArrayList` using a specific kind of allocator. No overhead if `Allocator` is a +/// zero-sized type. +/// +/// NOTE: Unlike `std.ArrayList`, this type's `deinit` method calls `deinit` on each of the items. +pub fn ArrayListIn(comptime T: type, comptime Allocator: type) type { + return ArrayListAlignedIn(T, Allocator, null); +} + +/// Managed `ArrayListAligned` using an arbitrary `std.mem.Allocator`. +/// Prefer using a concrete type, like `ArrayListAlignedDefault` or +/// `ArrayListAlignedIn(MimallocArena)`. +/// +/// NOTE: Unlike `std.ArrayList`, this type's `deinit` method calls `deinit` on each of the items. +pub fn ArrayListAligned(comptime T: type, comptime alignment: ?u29) type { + return ArrayListAlignedIn(T, std.mem.Allocator, alignment); +} + +/// Managed `ArrayListAligned` using the default allocator. No overhead compared to an unmanaged +/// `ArrayListAligned`. +/// +/// NOTE: Unlike `std.ArrayList`, this type's `deinit` method calls `deinit` on each of the items. +pub fn ArrayListAlignedDefault(comptime T: type, comptime alignment: ?u29) type { + return ArrayListAlignedIn(T, bun.DefaultAllocator, alignment); +} + +/// Managed `ArrayListAligned` using a specific kind of allocator. No overhead if `Allocator` is a +/// zero-sized type. 
+/// +/// NOTE: Unlike `std.ArrayList`, this type's `deinit` method calls `deinit` on each of the items. +pub fn ArrayListAlignedIn( + comptime T: type, + comptime Allocator: type, + comptime alignment: ?u29, +) type { + return struct { + const Self = @This(); + + #unmanaged: Unmanaged = .empty, + #allocator: Allocator, + + pub fn items(self: *const Self) Slice { + return self.#unmanaged.items; + } + + pub fn capacity(self: *const Self) usize { + return self.#unmanaged.capacity; + } + + pub const SentinelSlice = Unmanaged.SentinelSlice; + pub const Slice = Unmanaged.Slice; + pub const Unmanaged = std.ArrayListAlignedUnmanaged(T, alignment); + + pub fn init() Self { + return .initIn(bun.memory.initDefault(Allocator)); + } + + pub fn initIn(allocator_: Allocator) Self { + return .{ + .#unmanaged = .empty, + .#allocator = allocator_, + }; + } + + pub fn initCapacity(num: usize) AllocError!Self { + return .initCapacityIn(num, bun.memory.initDefault(Allocator)); + } + + pub fn initCapacityIn(num: usize, allocator_: Allocator) AllocError!Self { + return .{ + .#unmanaged = try .initCapacity(bun.allocators.asStd(allocator_), num), + .#allocator = allocator_, + }; + } + + /// NOTE: Unlike `std.ArrayList`, this method calls `deinit` on every item in the list, + /// if such a method exists. If you don't want that behavior, use `deinitShallow`. + pub fn deinit(self: *Self) void { + bun.memory.deinit(self.items()); + self.deinitShallow(); + } + + pub fn deinitShallow(self: *Self) void { + defer self.* = undefined; + self.#unmanaged.deinit(self.getStdAllocator()); + bun.memory.deinit(&self.#allocator); + } + + pub fn fromOwnedSlice(allocator_: Allocator, slice: Slice) Self { + return .{ + .#unmanaged = .fromOwnedSlice(slice), + .#allocator = allocator_, + }; + } + + pub fn fromOwnedSliceSentinel( + allocator_: Allocator, + comptime sentinel: T, + slice: [:sentinel]T, + ) Self { + return .{ + .#unmanaged = .fromOwnedSliceSentinel(sentinel, slice), + .#allocator = allocator_, + }; + } + + /// Returns a borrowed version of the allocator. + pub fn allocator(self: *const Self) bun.allocators.Borrowed(Allocator) { + return bun.allocators.borrow(self.#allocator); + } + + /// This method empties `self`. + pub fn moveToUnmanaged(self: *Self) Unmanaged { + defer self.#unmanaged = .empty; + return self.#unmanaged; + } + + /// Unlike `moveToUnmanaged`, this method *invalidates* `self`. + pub fn intoUnmanagedWithAllocator(self: *Self) struct { Unmanaged, Allocator } { + defer self.* = undefined; + return .{ self.#unmanaged, self.#allocator }; + } + + /// The contents of `unmanaged` must have been allocated by `allocator`. + /// This function invalidates `unmanaged`; don't call `deinit` on it. + pub fn fromUnmanaged(allocator_: Allocator, unmanaged: Unmanaged) Self { + return .{ + .#unmanaged = unmanaged, + .#allocator = allocator_, + }; + } + + pub fn toOwnedSlice(self: *Self) AllocError!Slice { + return self.#unmanaged.toOwnedSlice(self.getStdAllocator()); + } + + /// Creates a copy of this `ArrayList` with *shallow* copies of its items. + /// + /// The returned list uses a default-initialized `Allocator`. If `Allocator` cannot be + /// default-initialized, use `cloneIn` instead. + /// + /// Be careful with this method if `T` has a `deinit` method. You will have to use + /// `deinitShallow` on one of the `ArrayList`s to prevent `deinit` from being called twice + /// on each element. 
+        pub fn clone(self: *const Self) AllocError!Self {
+            return self.cloneIn(bun.memory.initDefault(Allocator));
+        }
+
+        /// Creates a copy of this `ArrayList` using the provided allocator, with *shallow* copies
+        /// of this list's items.
+        pub fn cloneIn(
+            self: *const Self,
+            allocator_: anytype,
+        ) AllocError!ArrayListAlignedIn(T, @TypeOf(allocator_), alignment) {
+            return .{
+                .#unmanaged = try self.#unmanaged.clone(bun.allocators.asStd(allocator_)),
+                .#allocator = allocator_,
+            };
+        }
+
+        pub fn insert(self: *Self, i: usize, item: T) AllocError!void {
+            return self.#unmanaged.insert(self.getStdAllocator(), i, item);
+        }
+
+        pub fn insertAssumeCapacity(self: *Self, i: usize, item: T) void {
+            self.#unmanaged.insertAssumeCapacity(i, item);
+        }
+
+        /// Note that this creates *shallow* copies of `value`.
+        pub fn addManyAt(self: *Self, index: usize, value: T, count: usize) AllocError![]T {
+            const result = try self.#unmanaged.addManyAt(self.getStdAllocator(), index, count);
+            @memset(result, value);
+            return result;
+        }
+
+        /// Note that this creates *shallow* copies of `value`.
+        pub fn addManyAtAssumeCapacity(self: *Self, index: usize, value: T, count: usize) []T {
+            const result = self.#unmanaged.addManyAtAssumeCapacity(index, count);
+            @memset(result, value);
+            return result;
+        }
+
+        /// This method takes ownership of all elements in `new_items`.
+        pub fn insertSlice(self: *Self, index: usize, new_items: []const T) AllocError!void {
+            return self.#unmanaged.insertSlice(self.getStdAllocator(), index, new_items);
+        }
+
+        /// This method `deinit`s the removed items.
+        /// This method takes ownership of all elements in `new_items`.
+        pub fn replaceRange(
+            self: *Self,
+            start: usize,
+            len: usize,
+            new_items: []const T,
+        ) AllocError!void {
+            bun.memory.deinit(self.items()[start .. start + len]);
+            return self.replaceRangeShallow(start, len, new_items);
+        }
+
+        /// This method does *not* `deinit` the removed items.
+        /// This method takes ownership of all elements in `new_items`.
+        pub fn replaceRangeShallow(
+            self: *Self,
+            start: usize,
+            len: usize,
+            new_items: []const T,
+        ) AllocError!void {
+            return self.#unmanaged.replaceRange(self.getStdAllocator(), start, len, new_items);
+        }
+
+        /// This method `deinit`s the removed items.
+        /// This method takes ownership of all elements in `new_items`.
+        pub fn replaceRangeAssumeCapacity(
+            self: *Self,
+            start: usize,
+            len: usize,
+            new_items: []const T,
+        ) void {
+            for (self.items()[start .. start + len]) |*item| {
+                bun.memory.deinit(item);
+            }
+            self.replaceRangeAssumeCapacityShallow(start, len, new_items);
+        }
+
+        /// This method does *not* `deinit` the removed items.
+        /// This method takes ownership of all elements in `new_items`.
+        pub fn replaceRangeAssumeCapacityShallow(
+            self: *Self,
+            start: usize,
+            len: usize,
+            new_items: []const T,
+        ) void {
+            self.#unmanaged.replaceRangeAssumeCapacity(start, len, new_items);
+        }
+
+        pub fn append(self: *Self, item: T) AllocError!void {
+            return self.#unmanaged.append(self.getStdAllocator(), item);
+        }
+
+        pub fn appendAssumeCapacity(self: *Self, item: T) void {
+            self.#unmanaged.appendAssumeCapacity(item);
+        }
+
+        pub fn orderedRemove(self: *Self, i: usize) T {
+            return self.#unmanaged.orderedRemove(i);
+        }
+
+        pub fn swapRemove(self: *Self, i: usize) T {
+            return self.#unmanaged.swapRemove(i);
+        }
+
+        /// This method takes ownership of all elements in `new_items`.
+ pub fn appendSlice(self: *Self, new_items: []const T) AllocError!void { + return self.#unmanaged.appendSlice(self.getStdAllocator(), new_items); + } + + /// This method takes ownership of all elements in `new_items`. + pub fn appendSliceAssumeCapacity(self: *Self, new_items: []const T) void { + self.#unmanaged.appendSliceAssumeCapacity(new_items); + } + + /// This method takes ownership of all elements in `new_items`. + pub fn appendUnalignedSlice(self: *Self, new_items: []align(1) const T) AllocError!void { + return self.#unmanaged.appendUnalignedSlice(self.getStdAllocator(), new_items); + } + + /// This method takes ownership of all elements in `new_items`. + pub fn appendUnalignedSliceAssumeCapacity(self: *Self, new_items: []align(1) const T) void { + self.#unmanaged.appendUnalignedSliceAssumeCapacity(new_items); + } + + /// Note that this creates *shallow* copies of `value`. + pub inline fn appendNTimes(self: *Self, value: T, n: usize) AllocError!void { + return self.#unmanaged.appendNTimes(self.getStdAllocator(), value, n); + } + + /// Note that this creates *shallow* copies of `value`. + pub inline fn appendNTimesAssumeCapacity(self: *Self, value: T, n: usize) void { + self.#unmanaged.appendNTimesAssumeCapacity(value, n); + } + + /// If `new_len` is less than the current length, this method will call `deinit` on the + /// removed items. + /// + /// If `new_len` is greater than the current length, note that this creates *shallow* copies + /// of `init_value`. + pub fn resize(self: *Self, init_value: T, new_len: usize) AllocError!void { + const len = self.items().len; + try self.resizeWithoutDeinit(init_value, new_len); + if (new_len < len) { + bun.memory.deinit(self.items().ptr[new_len..len]); + } + } + + /// If `new_len` is less than the current length, this method will *not* call `deinit` on + /// the removed items. + /// + /// If `new_len` is greater than the current length, note that this creates *shallow* copies + /// of `init_value`. + pub fn resizeWithoutDeinit(self: *Self, init_value: T, new_len: usize) AllocError!void { + const len = self.items().len; + try self.#unmanaged.resize(self.getStdAllocator(), new_len); + if (new_len > len) { + @memset(self.items()[len..], init_value); + } + } + + /// This method `deinit`s the removed items. + pub fn shrinkAndFree(self: *Self, new_len: usize) void { + self.prepareForDeepShrink(new_len); + self.shrinkAndFreeShallow(new_len); + } + + /// This method does *not* `deinit` the removed items. + pub fn shrinkAndFreeShallow(self: *Self, new_len: usize) void { + self.#unmanaged.shrinkAndFree(self.getStdAllocator(), new_len); + } + + /// This method `deinit`s the removed items. + pub fn shrinkRetainingCapacity(self: *Self, new_len: usize) void { + self.prepareForDeepShrink(new_len); + self.shrinkRetainingCapacityShallow(new_len); + } + + /// This method does *not* `deinit` the removed items. + pub fn shrinkRetainingCapacityShallow(self: *Self, new_len: usize) void { + self.#unmanaged.shrinkRetainingCapacity(new_len); + } + + /// This method `deinit`s all items. + pub fn clearRetainingCapacity(self: *Self) void { + bun.memory.deinit(self.items()); + self.clearRetainingCapacityShallow(); + } + + /// This method does *not* `deinit` any items. + pub fn clearRetainingCapacityShallow(self: *Self) void { + self.#unmanaged.clearRetainingCapacity(); + } + + /// This method `deinit`s all items. 
+ pub fn clearAndFree(self: *Self) void { + bun.memory.deinit(self.items()); + self.clearAndFreeShallow(); + } + + /// This method does *not* `deinit` any items. + pub fn clearAndFreeShallow(self: *Self) void { + self.#unmanaged.clearAndFree(self.getStdAllocator()); + } + + pub fn ensureTotalCapacity(self: *Self, new_capacity: usize) AllocError!void { + return self.#unmanaged.ensureTotalCapacity(self.getStdAllocator(), new_capacity); + } + + pub fn ensureTotalCapacityPrecise(self: *Self, new_capacity: usize) AllocError!void { + return self.#unmanaged.ensureTotalCapacityPrecise(self.getStdAllocator(), new_capacity); + } + + pub fn ensureUnusedCapacity(self: *Self, additional_count: usize) AllocError!void { + return self.#unmanaged.ensureUnusedCapacity(self.getStdAllocator(), additional_count); + } + + /// Note that this creates *shallow* copies of `init_value`. + pub fn expandToCapacity(self: *Self, init_value: T) void { + const len = self.items().len; + self.#unmanaged.expandToCapacity(); + @memset(self.items()[len..], init_value); + } + + pub fn pop(self: *Self) ?T { + return self.#unmanaged.pop(); + } + + pub fn getLast(self: *const Self) *T { + const items_ = self.items(); + return &items_[items_.len - 1]; + } + + pub fn getLastOrNull(self: *const Self) ?*T { + return if (self.isEmpty()) null else self.getLast(); + } + + pub fn isEmpty(self: *const Self) bool { + return self.items().len == 0; + } + + fn prepareForDeepShrink(self: *Self, new_len: usize) void { + const items_ = self.items(); + bun.assertf( + new_len <= items_.len, + "new_len ({d}) cannot exceed current len ({d})", + .{ new_len, items_.len }, + ); + bun.memory.deinit(items_[new_len..]); + } + + fn getStdAllocator(self: *const Self) std.mem.Allocator { + return bun.allocators.asStd(self.#allocator); + } + }; +} + +const bun = @import("bun"); +const std = @import("std"); +const AllocError = std.mem.Allocator.Error; From 0511fbf7b61c52d5fde0d64c2830b536988fd1f0 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 26 Sep 2025 15:18:28 -0700 Subject: [PATCH 31/43] Skip failing test on arm64 linux musl caused by third-party dependency --- test/cli/install/bun-repl.test.ts | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/cli/install/bun-repl.test.ts b/test/cli/install/bun-repl.test.ts index 7f95580d16..a906767e0b 100644 --- a/test/cli/install/bun-repl.test.ts +++ b/test/cli/install/bun-repl.test.ts @@ -1,7 +1,12 @@ import { expect, test } from "bun:test"; import "harness"; +import { isArm64, isMusl } from "harness"; // https://github.com/oven-sh/bun/issues/12070 -test("bun repl", () => { +test.skipIf( + // swc, which bun-repl uses, published a glibc build for arm64 musl + // and so it crashes on process.exit. + isMusl && isArm64, +)("bun repl", () => { expect(["repl", "-e", "process.exit(0)"]).toRun(); }); From 250d30eb7dc52ce8f302e3a903e1a41853d6714d Mon Sep 17 00:00:00 2001 From: pfg Date: Fri, 26 Sep 2025 16:39:08 -0700 Subject: [PATCH 32/43] Concurrent limit `--max-concurrency`, defaults to 20 (#22944) ### What does this PR do? Adds a max-concurrency flag to limit the amount of concurrent tests that run at once. Defaults to 20. Jest and Vitest both default to 5. ### How did you verify your code works? 
Tests --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Jarred Sumner --- docs/cli/test.md | 79 ++++++++++++++++++++++ src/bun.js/test/Execution.zig | 36 ++++++++-- src/bun.js/test/bun_test.zig | 4 +- src/bun.js/test/jest.zig | 1 + src/cli.zig | 1 + src/cli/Arguments.zig | 10 +++ src/cli/test_command.zig | 1 + test/js/bun/test/concurrent-max.fixture.ts | 49 ++++++++++++++ test/js/bun/test/concurrent.test.ts | 67 ++++++++++++++++++ test/js/sql/sql-mysql.test.ts | 3 + 10 files changed, 243 insertions(+), 8 deletions(-) create mode 100644 test/js/bun/test/concurrent-max.fixture.ts diff --git a/docs/cli/test.md b/docs/cli/test.md index e3729b8b62..627f48978b 100644 --- a/docs/cli/test.md +++ b/docs/cli/test.md @@ -109,6 +109,85 @@ Use the `--timeout` flag to specify a _per-test_ timeout in milliseconds. If a t $ bun test --timeout 20 ``` +## Concurrent test execution + +By default, Bun runs all tests sequentially within each test file. You can enable concurrent execution to run async tests in parallel, significantly speeding up test suites with independent tests. + +### `--concurrent` flag + +Use the `--concurrent` flag to run all tests concurrently within their respective files: + +```sh +$ bun test --concurrent +``` + +When this flag is enabled, all tests will run in parallel unless explicitly marked with `test.serial`. + +### `--max-concurrency` flag + +Control the maximum number of tests running simultaneously with the `--max-concurrency` flag: + +```sh +# Limit to 4 concurrent tests +$ bun test --concurrent --max-concurrency 4 + +# Default: 20 +$ bun test --concurrent +``` + +This helps prevent resource exhaustion when running many concurrent tests. The default value is 20. + +### `test.concurrent` + +Mark individual tests to run concurrently, even when the `--concurrent` flag is not used: + +```ts +import { test, expect } from "bun:test"; + +// These tests run in parallel with each other +test.concurrent("concurrent test 1", async () => { + await fetch("/api/endpoint1"); + expect(true).toBe(true); +}); + +test.concurrent("concurrent test 2", async () => { + await fetch("/api/endpoint2"); + expect(true).toBe(true); +}); + +// This test runs sequentially +test("sequential test", () => { + expect(1 + 1).toBe(2); +}); +``` + +### `test.serial` + +Force tests to run sequentially, even when the `--concurrent` flag is enabled: + +```ts +import { test, expect } from "bun:test"; + +let sharedState = 0; + +// These tests must run in order +test.serial("first serial test", () => { + sharedState = 1; + expect(sharedState).toBe(1); +}); + +test.serial("second serial test", () => { + // Depends on the previous test + expect(sharedState).toBe(1); + sharedState = 2; +}); + +// This test can run concurrently if --concurrent is enabled +test("independent test", () => { + expect(true).toBe(true); +}); +``` + ## Rerun tests Use the `--rerun-each` flag to run each test multiple times. This is useful for detecting flaky or non-deterministic test failures. 
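For example, to run each test 10 times (an illustrative invocation of the flag described above):

```sh
$ bun test --rerun-each 10
```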
diff --git a/src/bun.js/test/Execution.zig b/src/bun.js/test/Execution.zig index c57ae6414c..2c02034bd6 100644 --- a/src/bun.js/test/Execution.zig +++ b/src/bun.js/test/Execution.zig @@ -44,6 +44,8 @@ group_index: usize, pub const ConcurrentGroup = struct { sequence_start: usize, sequence_end: usize, + /// Index of the next sequence that has not been started yet + next_sequence_index: usize, executing: bool, remaining_incomplete_entries: usize, /// used by beforeAll to skip directly to afterAll if it fails @@ -56,6 +58,7 @@ pub const ConcurrentGroup = struct { .executing = false, .remaining_incomplete_entries = sequence_end - sequence_start, .failure_skip_to = next_index, + .next_sequence_index = 0, }; } pub fn tryExtend(this: *ConcurrentGroup, next_sequence_start: usize, next_sequence_end: usize) bool { @@ -243,11 +246,24 @@ pub fn step(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject bun.assert(sequence.active_index < sequence.entries(this).len); this.advanceSequence(sequence, group); - const sequence_result = try stepSequence(buntest_strong, globalThis, sequence, group, sequence_index, &now); + const sequence_result = try stepSequence(buntest_strong, globalThis, group, sequence_index, &now); switch (sequence_result) { .done => {}, .execute => |exec| return .{ .waiting = .{ .timeout = exec.timeout } }, } + // this sequence is complete; execute the next sequence + while (group.next_sequence_index < group.sequences(this).len) : (group.next_sequence_index += 1) { + const target_sequence = &group.sequences(this)[group.next_sequence_index]; + if (target_sequence.executing) continue; + const sequence_status = try stepSequence(buntest_strong, globalThis, group, group.next_sequence_index, &now); + switch (sequence_status) { + .done => continue, + .execute => |exec| { + return .{ .waiting = .{ .timeout = exec.timeout } }; + }, + } + } + // all sequences have started if (group.remaining_incomplete_entries == 0) { return try stepGroup(buntest_strong, globalThis, &now); } @@ -299,14 +315,21 @@ fn stepGroupOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalOb const buntest = buntest_strong.get(); const this = &buntest.execution; var final_status: AdvanceStatus = .done; - for (group.sequences(this), 0..) 
|*sequence, sequence_index| {
-        const sequence_status = try stepSequence(buntest_strong, globalThis, sequence, group, sequence_index, now);
+    const concurrent_limit = if (buntest.reporter) |reporter| reporter.jest.max_concurrency else blk: {
+        bun.assert(false); // probably can't get here because reporter is only set to null when the file is exited
+        break :blk 20;
+    };
+    var active_count: usize = 0;
+    for (0..group.sequences(this).len) |sequence_index| {
+        const sequence_status = try stepSequence(buntest_strong, globalThis, group, sequence_index, now);
         switch (sequence_status) {
             .done => {},
             .execute => |exec| {
                 const prev_timeout: bun.timespec = if (final_status == .execute) final_status.execute.timeout else .epoch;
                 const this_timeout = exec.timeout;
                 final_status = .{ .execute = .{ .timeout = prev_timeout.minIgnoreEpoch(this_timeout) } };
+                active_count += 1;
+                if (concurrent_limit != 0 and active_count >= concurrent_limit) break;
             },
         }
     }
@@ -320,18 +343,19 @@ const AdvanceSequenceStatus = union(enum) {
         timeout: bun.timespec = .epoch,
     },
 };
-fn stepSequence(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, sequence: *ExecutionSequence, group: *ConcurrentGroup, sequence_index: usize, now: *bun.timespec) !AdvanceSequenceStatus {
+fn stepSequence(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, group: *ConcurrentGroup, sequence_index: usize, now: *bun.timespec) !AdvanceSequenceStatus {
     while (true) {
-        return try stepSequenceOne(buntest_strong, globalThis, sequence, group, sequence_index, now) orelse continue;
+        return try stepSequenceOne(buntest_strong, globalThis, group, sequence_index, now) orelse continue;
     }
 }
 
 /// returns null if the while loop should continue
-fn stepSequenceOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, sequence: *ExecutionSequence, group: *ConcurrentGroup, sequence_index: usize, now: *bun.timespec) !?AdvanceSequenceStatus {
+fn stepSequenceOne(buntest_strong: bun_test.BunTestPtr, globalThis: *jsc.JSGlobalObject, group: *ConcurrentGroup, sequence_index: usize, now: *bun.timespec) !?AdvanceSequenceStatus {
     groupLog.begin(@src());
     defer groupLog.end();
 
     const buntest = buntest_strong.get();
     const this = &buntest.execution;
+    const sequence = &group.sequences(this)[sequence_index];
     if (sequence.executing) {
         const active_entry = sequence.activeEntry(this) orelse {
             bun.debugAssert(false); // sequence is executing with no active entry
diff --git a/src/bun.js/test/bun_test.zig b/src/bun.js/test/bun_test.zig
index f25c65fb06..d4a93b45c1 100644
--- a/src/bun.js/test/bun_test.zig
+++ b/src/bun.js/test/bun_test.zig
@@ -152,12 +152,12 @@ pub const BunTest = struct {
     arena_allocator: std.heap.ArenaAllocator,
     arena: std.mem.Allocator,
     file_id: jsc.Jest.TestRunner.File.ID,
-    /// null if the runner has moved on to the next file
+    /// null if the runner has moved on to the next file but a strong reference to BunTest is still keeping it alive
    reporter: ?*test_command.CommandLineReporter,
     timer: bun.api.Timer.EventLoopTimer = .{ .next = .epoch, .tag = .BunTest },
     result_queue: ResultQueue,
     /// Whether tests in this file should default to concurrent execution
-    default_concurrent: bool = false,
+    default_concurrent: bool,
 
     phase: enum {
         collection,
diff --git a/src/bun.js/test/jest.zig b/src/bun.js/test/jest.zig
index 8ec2dbdfd4..521c3d064d 100644
--- a/src/bun.js/test/jest.zig
+++ b/src/bun.js/test/jest.zig
@@ -59,6 +59,7 @@ pub const TestRunner = struct {
     concurrent_test_glob: ?[]const []const u8 = null,
     last_file: u64 = 0,
     bail: u32 = 0,
+    max_concurrency: u32,
 
     allocator: std.mem.Allocator,
diff --git a/src/cli.zig b/src/cli.zig
index e5fbda4598..0fada12099 100644
--- a/src/cli.zig
+++ b/src/cli.zig
@@ -348,6 +348,7 @@ pub const Command = struct {
         coverage: TestCommand.CodeCoverageOptions = .{},
         test_filter_pattern: ?[]const u8 = null,
         test_filter_regex: ?*RegularExpression = null,
+        max_concurrency: u32 = 20,
 
         file_reporter: ?TestCommand.FileReporter = null,
         reporter_outfile: ?[]const u8 = null,
diff --git a/src/cli/Arguments.zig b/src/cli/Arguments.zig
index 324c8165e2..0a3ac55964 100644
--- a/src/cli/Arguments.zig
+++ b/src/cli/Arguments.zig
@@ -206,6 +206,7 @@ pub const test_only_params = [_]ParamType{
     clap.parseParam("-t, --test-name-pattern <STR>  Run only tests with a name that matches the given regex.") catch unreachable,
     clap.parseParam("--reporter <STR>  Test output reporter format. Available: 'junit' (requires --reporter-outfile). Default: console output.") catch unreachable,
     clap.parseParam("--reporter-outfile <STR>  Output file path for the reporter format (required with --reporter).") catch unreachable,
+    clap.parseParam("--max-concurrency <NUMBER>  Maximum number of concurrent tests to execute at once. Default is 20.") catch unreachable,
 };
 pub const test_params = test_only_params ++ runtime_params_ ++ transpiler_params_ ++ base_params_;
 
@@ -417,6 +418,15 @@ pub fn parse(allocator: std.mem.Allocator, ctx: Command.Context, comptime cmd: C
         }
     }
 
+    if (args.option("--max-concurrency")) |max_concurrency| {
+        if (max_concurrency.len > 0) {
+            ctx.test_options.max_concurrency = std.fmt.parseInt(u32, max_concurrency, 10) catch {
+                Output.prettyErrorln("<r><red>error<r>: Invalid max-concurrency: \"{s}\"", .{max_concurrency});
+                Global.exit(1);
+            };
+        }
+    }
+
     if (!ctx.test_options.coverage.enabled) {
         ctx.test_options.coverage.enabled = args.flag("--coverage");
     }
diff --git a/src/cli/test_command.zig b/src/cli/test_command.zig
index e155123ab8..318cb518c7 100644
--- a/src/cli/test_command.zig
+++ b/src/cli/test_command.zig
@@ -1311,6 +1311,7 @@ pub const TestCommand = struct {
             .run_todo = ctx.test_options.run_todo,
             .only = ctx.test_options.only,
             .bail = ctx.test_options.bail,
+            .max_concurrency = ctx.test_options.max_concurrency,
             .filter_regex = ctx.test_options.test_filter_regex,
             .snapshots = Snapshots{
                 .allocator = ctx.allocator,
diff --git a/test/js/bun/test/concurrent-max.fixture.ts b/test/js/bun/test/concurrent-max.fixture.ts
new file mode 100644
index 0000000000..1838e9cdd7
--- /dev/null
+++ b/test/js/bun/test/concurrent-max.fixture.ts
@@ -0,0 +1,49 @@
+import { test, expect, afterAll, beforeEach, afterEach } from "bun:test";
+
+// Track concurrent executions
+let currentlyExecuting = 0;
+const executionLog: number[] = [];
+
+beforeEach(() => {
+  currentlyExecuting++;
+  executionLog.push(currentlyExecuting);
+});
+afterEach(() => currentlyExecuting--);
+
+function queue(fn: () => void) {
+  resolveQueue.push(fn);
+  if (!timeout) {
+    const set = () =>
+      setTimeout(() => {
+        const cb = resolveQueue.shift();
+        if (!cb) {
+          timeout = false;
+          return;
+        }
+        cb();
+        set();
+      }, 0);
+    set();
+    timeout = true;
+  } else {
+    timeout = true;
+  }
+}
+
+const resolveQueue: (() => void)[] = [];
+let timeout: boolean = false;
+
+test.concurrent.each(Array.from({ length: 100 }, (_, i) => i + 1))(`concurrent test %d`, (i, done) => {
+  console.log(`start test ${i}`);
+  // Small delay to ensure tests overlap
+  queue(() => {
+    console.log(`end test ${i}`);
+    done();
+  });
+});
+
+// afterAll to report the max concurrency observed
+afterAll(() => {
+  // Log execution pattern
+  console.log("Execution pattern: " + JSON.stringify(executionLog));
+});
diff --git a/test/js/bun/test/concurrent.test.ts b/test/js/bun/test/concurrent.test.ts
index 05d03eb532..df1fe530c2 100644
--- a/test/js/bun/test/concurrent.test.ts
+++ b/test/js/bun/test/concurrent.test.ts
@@ -57,3 +57,70 @@ test("concurrent order", async () => {
   }
   `);
 });
+
+test("max-concurrency limits concurrent tests", async () => {
+  // Test with max-concurrency=3
+  const result = await Bun.spawn({
+    cmd: [bunExe(), "test", "--max-concurrency", "3", import.meta.dir + "/concurrent-max.fixture.ts"],
+    stdout: "pipe",
+    stderr: "pipe",
+    env: bunEnv,
+  });
+  const exitCode = await result.exited;
+  const stdout = await result.stdout.text();
+
+  expect(exitCode).toBe(0);
+
+  // Extract max concurrent value from output
+  const maxMatch = stdout.match(/Execution pattern: ([^\n]+)/);
+  expect(maxMatch).toBeTruthy();
+  const executionPattern = JSON.parse(maxMatch![1]);
+
+  // Should be 1,2,3,3,3,3,3,...
+  const expected = Array.from({ length: 100 }, (_, i) => Math.min(i + 1, 3));
+  expect(executionPattern).toEqual(expected);
+});
+
+test("max-concurrency default is 20", async () => {
+  const result = await Bun.spawn({
+    cmd: [bunExe(), "test", import.meta.dir + "/concurrent-max.fixture.ts"],
+    stdout: "pipe",
+    stderr: "pipe",
+    env: bunEnv,
+  });
+  const exitCode = await result.exited;
+  const stdout = await result.stdout.text();
+
+  expect(exitCode).toBe(0);
+
+  // Extract max concurrent value from output
+  const maxMatch = stdout.match(/Execution pattern: ([^\n]+)/);
+  expect(maxMatch).toBeTruthy();
+  const executionPattern = JSON.parse(maxMatch![1]);
+
+  // Should be 1,2,3,...,18,19,20,20,20,20,20,20,...
+  const expected = Array.from({ length: 100 }, (_, i) => Math.min(i + 1, 20));
+  expect(executionPattern).toEqual(expected);
+});
+
+test("zero removes max-concurrency", async () => {
+  const result = await Bun.spawn({
+    cmd: [bunExe(), "test", "--max-concurrency", "0", import.meta.dir + "/concurrent-max.fixture.ts"],
+    stdout: "pipe",
+    stderr: "pipe",
+    env: bunEnv,
+  });
+  const exitCode = await result.exited;
+  const stdout = await result.stdout.text();
+
+  expect(exitCode).toBe(0);
+
+  // Extract max concurrent value from output
+  const maxMatch = stdout.match(/Execution pattern: ([^\n]+)/);
+  expect(maxMatch).toBeTruthy();
+  const executionPattern = JSON.parse(maxMatch![1]);
+
+  // No limit: should be 1,2,3,...,99,100
+ const expected = Array.from({ length: 100 }, (_, i) => i + 1); + expect(executionPattern).toEqual(expected); +}); diff --git a/test/js/sql/sql-mysql.test.ts b/test/js/sql/sql-mysql.test.ts index e3a6e6a4e4..ca60698759 100644 --- a/test/js/sql/sql-mysql.test.ts +++ b/test/js/sql/sql-mysql.test.ts @@ -287,11 +287,13 @@ if (isDockerEnabled()) { }); test("Create table", async () => { + await using sql = new SQL({ ...getOptions(), max: 1 }); await sql`create table test_my_table(id int)`; await sql`drop table test_my_table`; }); test("Drop table", async () => { + await using sql = new SQL({ ...getOptions(), max: 1 }); await sql`create table drop_table_test(id int)`; await sql`drop table drop_table_test`; // Verify that table is dropped @@ -498,6 +500,7 @@ if (isDockerEnabled()) { }); test("Prepared transaction", async () => { + await using sql = new SQL({ ...getOptions(), max: 1 }); await sql`create table test_prepared_transaction (a int)`; try { From eb04e4e6404bd5de869ae36d3414d90955e6728e Mon Sep 17 00:00:00 2001 From: "taylor.fish" Date: Fri, 26 Sep 2025 17:18:30 -0700 Subject: [PATCH 33/43] Make `bun.webcore.Blob` smaller and ref-counted (#23015) Reduce the size of `bun.webcore.Blob` from 120 bytes to 96. Also make it ref-counted: in-progress work on improving the bindings generator depends on this, as it means C++ can pass a pointer to the `Blob` to Zig without risking it being destroyed if the GC collects the associated `JSBlob`. Note that this PR depends on #23013. (For internal tracking: fixes STAB-1289, STAB-1290) --- src/StandaloneModuleGraph.zig | 1 - src/ast/Macro.zig | 15 +-- src/bun.js/api/BunObject.zig | 4 - src/bun.js/api/server/FileRoute.zig | 4 +- src/bun.js/api/server/StaticRoute.zig | 6 +- src/bun.js/bindings/blob.cpp | 6 +- src/bun.js/bindings/blob.h | 68 ++++++---- src/bun.js/webcore/Blob.zig | 160 ++++++++++++++--------- src/bun.js/webcore/Body.zig | 10 +- src/bun.js/webcore/ObjectURLRegistry.zig | 1 - src/bun.js/webcore/S3Client.zig | 1 - src/bun.js/webcore/S3File.zig | 4 +- src/bun.js/webcore/blob/read_file.zig | 3 +- 13 files changed, 165 insertions(+), 118 deletions(-) diff --git a/src/StandaloneModuleGraph.zig b/src/StandaloneModuleGraph.zig index 68b3006688..0da83a40e6 100644 --- a/src/StandaloneModuleGraph.zig +++ b/src/StandaloneModuleGraph.zig @@ -199,7 +199,6 @@ pub const StandaloneModuleGraph = struct { store.ref(); const b = bun.webcore.Blob.initWithStore(store, globalObject).new(); - b.allocator = bun.default_allocator; if (bun.http.MimeType.byExtensionNoDefault(bun.strings.trimLeadingChar(std.fs.path.extension(this.name), '.'))) |mime| { store.mime_type = mime; diff --git a/src/ast/Macro.zig b/src/ast/Macro.zig index 623d996554..97863921a4 100644 --- a/src/ast/Macro.zig +++ b/src/ast/Macro.zig @@ -325,7 +325,7 @@ pub const Runner = struct { return _entry.value_ptr.*; } - var blob_: ?jsc.WebCore.Blob = null; + var blob_: ?*const jsc.WebCore.Blob = null; const mime_type: ?MimeType = null; if (value.jsType() == .DOMWrapper) { @@ -334,30 +334,23 @@ pub const Runner = struct { } else if (value.as(jsc.WebCore.Request)) |resp| { return this.run(try resp.getBlobWithoutCallFrame(this.global)); } else if (value.as(jsc.WebCore.Blob)) |resp| { - blob_ = resp.*; - blob_.?.allocator = null; + blob_ = resp; } else if (value.as(bun.api.ResolveMessage) != null or value.as(bun.api.BuildMessage) != null) { _ = this.macro.vm.uncaughtException(this.global, value, false); return error.MacroFailed; } } - if (blob_) |*blob| { - const out_expr = Expr.fromBlob( + if (blob_) |blob| 
{ + return Expr.fromBlob( blob, this.allocator, mime_type, this.log, this.caller.loc, ) catch { - blob.deinit(); return error.MacroFailed; }; - if (out_expr.data == .e_string) { - blob.deinit(); - } - - return out_expr; } return Expr.init(E.String, E.String.empty, this.caller.loc); diff --git a/src/bun.js/api/BunObject.zig b/src/bun.js/api/BunObject.zig index a5af653677..edf05d293a 100644 --- a/src/bun.js/api/BunObject.zig +++ b/src/bun.js/api/BunObject.zig @@ -1337,7 +1337,6 @@ pub fn getEmbeddedFiles(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) bun.J // We call .dupe() on this to ensure that we don't return a blob that might get freed later. const input_blob = file.blob(globalThis); const blob = jsc.WebCore.Blob.new(input_blob.dupeWithContentType(true)); - blob.allocator = bun.default_allocator; blob.name = input_blob.name.dupeRef(); try array.putIndex(globalThis, i, blob.toJS(globalThis)); i += 1; @@ -2048,7 +2047,6 @@ pub fn createBunStdin(globalThis: *jsc.JSGlobalObject) callconv(.C) jsc.JSValue var blob = jsc.WebCore.Blob.new( jsc.WebCore.Blob.initWithStore(store, globalThis), ); - blob.allocator = bun.default_allocator; return blob.toJS(globalThis); } @@ -2059,7 +2057,6 @@ pub fn createBunStderr(globalThis: *jsc.JSGlobalObject) callconv(.C) jsc.JSValue var blob = jsc.WebCore.Blob.new( jsc.WebCore.Blob.initWithStore(store, globalThis), ); - blob.allocator = bun.default_allocator; return blob.toJS(globalThis); } @@ -2070,7 +2067,6 @@ pub fn createBunStdout(globalThis: *jsc.JSGlobalObject) callconv(.C) jsc.JSValue var blob = jsc.WebCore.Blob.new( jsc.WebCore.Blob.initWithStore(store, globalThis), ); - blob.allocator = bun.default_allocator; return blob.toJS(globalThis); } diff --git a/src/bun.js/api/server/FileRoute.zig b/src/bun.js/api/server/FileRoute.zig index 88b3eb03a1..ab4b57bfa8 100644 --- a/src/bun.js/api/server/FileRoute.zig +++ b/src/bun.js/api/server/FileRoute.zig @@ -68,7 +68,7 @@ pub fn fromJS(globalThis: *jsc.JSGlobalObject, argument: jsc.JSValue) bun.JSErro var blob = response.body.value.use(); blob.globalThis = globalThis; - blob.allocator = null; + bun.assertf(!blob.isHeapAllocated(), "expected blob not to be heap-allocated", .{}); response.body.value = .{ .Blob = blob.dupe() }; const headers = bun.handleOom(Headers.from(response.init.headers, bun.default_allocator, .{ .body = &.{ .Blob = blob } })); @@ -87,7 +87,7 @@ pub fn fromJS(globalThis: *jsc.JSGlobalObject, argument: jsc.JSValue) bun.JSErro if (blob.needsToReadFile()) { var b = blob.dupe(); b.globalThis = globalThis; - b.allocator = null; + bun.assertf(!b.isHeapAllocated(), "expected blob not to be heap-allocated", .{}); return bun.new(FileRoute, .{ .ref_count = .init(), .server = null, diff --git a/src/bun.js/api/server/StaticRoute.zig b/src/bun.js/api/server/StaticRoute.zig index 7fbfcd5758..9e931826a8 100644 --- a/src/bun.js/api/server/StaticRoute.zig +++ b/src/bun.js/api/server/StaticRoute.zig @@ -113,7 +113,11 @@ pub fn fromJS(globalThis: *jsc.JSGlobalObject, argument: jsc.JSValue) bun.JSErro } var blob = response.body.value.use(); blob.globalThis = globalThis; - blob.allocator = null; + bun.assertf( + !blob.isHeapAllocated(), + "expected blob not to be heap-allocated", + .{}, + ); response.body.value = .{ .Blob = blob.dupe() }; break :brk .{ .Blob = blob }; diff --git a/src/bun.js/bindings/blob.cpp b/src/bun.js/bindings/blob.cpp index 2e650ae479..315b5d0468 100644 --- a/src/bun.js/bindings/blob.cpp +++ b/src/bun.js/bindings/blob.cpp @@ -2,14 +2,14 @@ #include "ZigGeneratedClasses.h" extern "C" 
JSC::EncodedJSValue SYSV_ABI Blob__create(JSC::JSGlobalObject* globalObject, void* impl);
-extern "C" void* Blob__setAsFile(void* impl, BunString* filename);
+extern "C" void Blob__setAsFile(void* impl, BunString* filename);
 
 namespace WebCore {
 
 JSC::JSValue toJS(JSC::JSGlobalObject* lexicalGlobalObject, JSDOMGlobalObject* globalObject, WebCore::Blob& impl)
 {
     BunString filename = Bun::toString(impl.fileName());
-    impl.m_impl = Blob__setAsFile(impl.impl(), &filename);
+    Blob__setAsFile(impl.impl(), &filename);
 
     return JSC::JSValue::decode(Blob__create(lexicalGlobalObject, Blob__dupe(impl.impl())));
 }
@@ -28,7 +28,7 @@ JSC::JSValue toJSNewlyCreated(JSC::JSGlobalObject* lexicalGlobalObject, JSDOMGlo
 
 size_t Blob::memoryCost() const
 {
-    return sizeof(Blob) + JSBlob::memoryCost(m_impl);
+    return sizeof(Blob) + JSBlob::memoryCost(impl());
 }
 
 }
diff --git a/src/bun.js/bindings/blob.h b/src/bun.js/bindings/blob.h
index 9e59790070..5b3b923154 100644
--- a/src/bun.js/bindings/blob.h
+++ b/src/bun.js/bindings/blob.h
@@ -8,44 +8,60 @@ namespace WebCore {
 
 extern "C" void* Blob__dupeFromJS(JSC::EncodedJSValue impl);
 extern "C" void* Blob__dupe(void* impl);
-extern "C" void Blob__destroy(void* impl);
 extern "C" void* Blob__getDataPtr(JSC::EncodedJSValue blob);
 extern "C" size_t Blob__getSize(JSC::EncodedJSValue blob);
 extern "C" void* Blob__fromBytes(JSC::JSGlobalObject* globalThis, const void* ptr, size_t len);
+extern "C" void* Blob__ref(void* impl);
+extern "C" void* Blob__deref(void* impl);
+
+// Opaque type corresponding to `bun.webcore.Blob`.
+class BlobImpl;
+
+struct BlobImplRefDerefTraits {
+    static ALWAYS_INLINE BlobImpl* refIfNotNull(BlobImpl* ptr)
+    {
+        if (ptr) [[likely]]
+            Blob__ref(ptr);
+        return ptr;
+    }
+
+    static ALWAYS_INLINE BlobImpl& ref(BlobImpl& ref)
+    {
+        Blob__ref(&ref);
+        return ref;
+    }
+
+    static ALWAYS_INLINE void derefIfNotNull(BlobImpl* ptr)
+    {
+        if (ptr) [[likely]]
+            Blob__deref(ptr);
+    }
+};
+
+using BlobRef = Ref<BlobImpl, RawPtrTraits<BlobImpl>, BlobImplRefDerefTraits>;
+using BlobRefPtr = RefPtr<BlobImpl, RawPtrTraits<BlobImpl>, BlobImplRefDerefTraits>;
+
+// TODO: Now that `bun.webcore.Blob` is ref-counted, can `RefPtr` be replaced with `Blob`?
 class Blob : public RefCounted<Blob> {
 
 public:
-    void* impl()
+    BlobImpl* impl() const
     {
-        return m_impl;
+        return m_impl.get();
     }
 
     static RefPtr<Blob> create(JSC::JSValue impl)
     {
-        void* implPtr = Blob__dupeFromJS(JSValue::encode(impl));
-        if (!implPtr)
-            return nullptr;
-
-        return adoptRef(*new Blob(implPtr));
+        return createAdopted(Blob__dupeFromJS(JSValue::encode(impl)));
     }
 
     static RefPtr<Blob> create(std::span<const uint8_t> bytes, JSC::JSGlobalObject* globalThis)
     {
-        return adoptRef(*new Blob(Blob__fromBytes(globalThis, bytes.data(), bytes.size())));
+        return createAdopted(Blob__fromBytes(globalThis, bytes.data(), bytes.size()));
    }
 
     static RefPtr<Blob> create(void* ptr)
     {
-        void* implPtr = Blob__dupe(ptr);
-        if (!implPtr)
-            return nullptr;
-
-        return adoptRef(*new Blob(implPtr));
-    }
-
-    ~Blob()
-    {
-        Blob__destroy(m_impl);
+        return createAdopted(Blob__dupe(ptr));
     }
 
     String fileName()
@@ -57,17 +73,25 @@ public:
     {
         m_fileName = fileName;
     }
-    void* m_impl;
 
     size_t memoryCost() const;
 
 private:
     Blob(void* impl, String fileName = String())
+        : m_impl(adoptRef<BlobImpl, RawPtrTraits<BlobImpl>, BlobImplRefDerefTraits>(
+              static_cast<BlobImpl*>(impl)))
+        , m_fileName(std::move(fileName))
     {
-        m_impl = impl;
-        m_fileName = fileName;
     }
 
+    static RefPtr<Blob> createAdopted(void* ptr)
+    {
+        if (!ptr)
+            return nullptr;
+        return adoptRef(new Blob(ptr));
+    }
+
+    BlobRefPtr m_impl;
     String m_fileName;
 };
diff --git a/src/bun.js/webcore/Blob.zig b/src/bun.js/webcore/Blob.zig
index 4b8242bb5c..37e6ff6f25 100644
--- a/src/bun.js/webcore/Blob.zig
+++ b/src/bun.js/webcore/Blob.zig
@@ -13,7 +13,12 @@ pub const read_file = @import("./blob/read_file.zig");
 pub const write_file = @import("./blob/write_file.zig");
 pub const copy_file = @import("./blob/copy_file.zig");
 
-pub const new = bun.TrivialNew(@This());
+pub fn new(blob: Blob) *Blob {
+    const result = bun.new(Blob, blob);
+    result.#ref_count = .init(1);
+    return result;
+}
+
 pub const js = jsc.Codegen.JSBlob;
 pub const fromJS = js.fromJS;
 pub const fromJSDirect = js.fromJSDirect;
@@ -22,9 +27,6 @@
 reported_estimated_size: usize = 0,
 size: SizeType = 0,
 offset: SizeType = 0,
-/// When set, the blob will be freed on finalization callbacks
-/// If the blob is contained in Response or Request, this must be null
-allocator: ?std.mem.Allocator = null,
 store: ?*Store = null,
 content_type: string = "",
 content_type_allocated: bool = false,
@@ -32,11 +34,15 @@ content_type_was_set: bool = false,
 
 /// JavaScriptCore strings are either latin1 or UTF-16
 /// When UTF-16, they're nearly always due to non-ascii characters
-is_all_ascii: ?bool = null,
+charset: Charset = .unknown,
 
 /// Was it created via file constructor?
 is_jsdom_file: bool = false,
 
+/// Reference count, for use with `bun.ptr.ExternalShared`. If the reference count is 0, that means
+/// this blob is *not* heap-allocated, and will not be freed in `deinit`.
+#ref_count: bun.ptr.RawRefCount(u32, .single_threaded) = .init(0), + globalThis: *JSGlobalObject = undefined, last_modified: f64 = 0.0, @@ -45,6 +51,8 @@ last_modified: f64 = 0.0, /// https://github.com/oven-sh/bun/issues/10178 name: bun.String = .dead, +pub const Ref = bun.ptr.ExternalShared(Blob); + /// Max int of double precision /// 9 petabytes is probably enough for awhile /// We want to avoid coercing to a BigInt because that's a heap allocation @@ -72,7 +80,7 @@ pub fn getFormDataEncoding(this: *Blob) ?*bun.FormData.AsyncFormData { var content_type_slice: ZigString.Slice = this.getContentType() orelse return null; defer content_type_slice.deinit(); const encoding = bun.FormData.Encoding.get(content_type_slice.slice()) orelse return null; - return bun.handleOom(bun.FormData.AsyncFormData.init(this.allocator orelse bun.default_allocator, encoding)); + return bun.handleOom(bun.FormData.AsyncFormData.init(bun.default_allocator, encoding)); } pub fn hasContentTypeFromUser(this: *const Blob) bool { @@ -113,6 +121,7 @@ pub fn doReadFromS3(this: *Blob, comptime Function: anytype, global: *JSGlobalOb }; return S3BlobDownloadTask.init(global, this, WrappedFn.wrapped); } + pub fn doReadFile(this: *Blob, comptime Function: anytype, global: *JSGlobalObject) JSValue { debug("doReadFile", .{}); @@ -497,7 +506,7 @@ fn _onStructuredCloneDeserialize( if (version == 3) break :versions; } - blob.allocator = allocator; + bun.assertf(blob.isHeapAllocated(), "expected blob to be heap-allocated", .{}); blob.offset = @as(u52, @intCast(offset)); if (content_type.len > 0) { blob.content_type = content_type; @@ -612,7 +621,7 @@ export fn Blob__dupeFromJS(value: jsc.JSValue) ?*Blob { return Blob__dupe(this); } -export fn Blob__setAsFile(this: *Blob, path_str: *bun.String) *Blob { +export fn Blob__setAsFile(this: *Blob, path_str: *bun.String) void { this.is_jsdom_file = true; // This is not 100% correct... 
@@ -624,19 +633,10 @@ export fn Blob__setAsFile(this: *Blob, path_str: *bun.String) *Blob { } } } - - return this; } -export fn Blob__dupe(ptr: *anyopaque) *Blob { - const this = bun.cast(*Blob, ptr); - const new_ptr = new(this.dupeWithContentType(true)); - new_ptr.allocator = bun.default_allocator; - return new_ptr; -} - -export fn Blob__destroy(this: *Blob) void { - this.finalize(); +export fn Blob__dupe(this: *Blob) *Blob { + return new(this.dupeWithContentType(true)); } export fn Blob__getFileNameString(this: *Blob) callconv(.C) bun.String { @@ -649,7 +649,6 @@ export fn Blob__getFileNameString(this: *Blob) callconv(.C) bun.String { comptime { _ = Blob__dupeFromJS; - _ = Blob__destroy; _ = Blob__dupe; _ = Blob__setAsFile; _ = Blob__getFileNameString; @@ -1076,10 +1075,7 @@ pub fn writeFileWithSourceDestination(ctx: *jsc.JSGlobalObject, source_blob: *Bl // this is an edgecase // it will happen if someone did Bun.write(new Blob([123]), new Blob([456])) // eventually, this could be like Buffer.concat - var clone = source_blob.dupe(); - clone.allocator = bun.default_allocator; - const cloned = Blob.new(clone); - cloned.allocator = bun.default_allocator; + const cloned = Blob.new(source_blob.dupe()); return JSPromise.resolvedPromiseValue(ctx, cloned.toJS(ctx)); } else if (destination_type == .bytes and (source_type == .file or source_type == .s3)) { const blob_value = source_blob.getSliceFrom(ctx, 0, 0, "", false); @@ -1820,7 +1816,6 @@ pub fn JSDOMFile__construct_(globalThis: *jsc.JSGlobalObject, callframe: *jsc.Ca } var blob_ = Blob.new(blob); - blob_.allocator = allocator; blob_.is_jsdom_file = true; return blob_; } @@ -1908,7 +1903,6 @@ pub fn constructBunFile( } var ptr = Blob.new(blob); - ptr.allocator = bun.default_allocator; return ptr.toJS(globalObject); } @@ -2769,7 +2763,7 @@ pub fn getSliceFrom(this: *Blob, globalThis: *jsc.JSGlobalObject, relativeStart: const offset = this.offset +| @as(SizeType, @intCast(relativeStart)); const len = @as(SizeType, @intCast(@max(relativeEnd -| relativeStart, 0))); - // This copies over the is_all_ascii flag + // This copies over the charset field // which is okay because this will only be a <= slice var blob = this.dupe(); blob.offset = offset; @@ -2785,7 +2779,6 @@ pub fn getSliceFrom(this: *Blob, globalThis: *jsc.JSGlobalObject, relativeStart: blob.content_type_was_set = this.content_type_was_set or content_type_was_allocated; var blob_ = Blob.new(blob); - blob_.allocator = bun.default_allocator; return blob_.toJS(globalThis); } @@ -2806,7 +2799,6 @@ pub fn getSlice( if (this.size == 0) { const empty = Blob.initEmpty(globalThis); var ptr = Blob.new(empty); - ptr.allocator = allocator; return ptr.toJS(globalThis); } @@ -3059,15 +3051,12 @@ export fn Blob__getSize(value: jsc.JSValue) callconv(.C) usize { export fn Blob__fromBytes(globalThis: *jsc.JSGlobalObject, ptr: ?[*]const u8, len: usize) callconv(.C) *Blob { if (ptr == null or len == 0) { const blob = new(initEmpty(globalThis)); - blob.allocator = bun.default_allocator; return blob; } const bytes = bun.handleOom(bun.default_allocator.dupe(u8, ptr.?[0..len])); const store = Store.init(bytes, bun.default_allocator); - var blob = initWithStore(store, globalThis); - blob.allocator = bun.default_allocator; - return new(blob); + return new(initWithStore(store, globalThis)); } pub fn getStat(this: *Blob, globalThis: *jsc.JSGlobalObject, callback: *jsc.CallFrame) bun.JSError!jsc.JSValue { @@ -3237,14 +3226,17 @@ pub fn constructor(globalThis: *jsc.JSGlobalObject, callframe: *jsc.CallFrame) b } 
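        // `Blob.new` already initializes #ref_count to 1, so `return Blob.new(blob)` below hands
        // back a pointer that is owned by its JS wrapper.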
blob.calculateEstimatedByteSize(); - - var blob_ = Blob.new(blob); - blob_.allocator = allocator; - return blob_; + return Blob.new(blob); } pub fn finalize(this: *Blob) void { - this.deinit(); + bun.assertf( + this.isHeapAllocated(), + "`finalize` may only be called on a heap-allocated Blob", + .{}, + ); + var shared = Blob.Ref.adopt(this); + shared.deinit(); } pub fn initWithAllASCII(bytes: []u8, allocator: std.mem.Allocator, globalThis: *JSGlobalObject, is_all_ascii: bool) Blob { @@ -3257,10 +3249,9 @@ pub fn initWithAllASCII(bytes: []u8, allocator: std.mem.Allocator, globalThis: * return Blob{ .size = @as(SizeType, @truncate(bytes.len)), .store = store, - .allocator = null, .content_type = "", .globalThis = globalThis, - .is_all_ascii = is_all_ascii, + .charset = .fromIsAllASCII(is_all_ascii), }; } @@ -3272,7 +3263,6 @@ pub fn init(bytes: []u8, allocator: std.mem.Allocator, globalThis: *JSGlobalObje Blob.Store.init(bytes, allocator) else null, - .allocator = null, .content_type = "", .globalThis = globalThis, }; @@ -3290,7 +3280,6 @@ pub fn createWithBytesAndAllocator( Blob.Store.init(bytes, allocator) else null, - .allocator = null, .content_type = if (was_string) MimeType.text.value else "", .globalThis = globalThis, }; @@ -3343,7 +3332,6 @@ pub fn initWithStore(store: *Blob.Store, globalThis: *JSGlobalObject) Blob { return Blob{ .size = store.size(), .store = store, - .allocator = null, .content_type = if (store.data == .file) store.data.file.mime_type.value else @@ -3356,7 +3344,6 @@ pub fn initEmpty(globalThis: *JSGlobalObject) Blob { return Blob{ .size = 0, .store = null, - .allocator = null, .content_type = "", .globalThis = globalThis, }; @@ -3383,7 +3370,8 @@ pub fn dupe(this: *const Blob) Blob { pub fn dupeWithContentType(this: *const Blob, include_content_type: bool) Blob { if (this.store != null) this.store.?.ref(); var duped = this.*; - if (duped.content_type_allocated and duped.allocator != null and !include_content_type) { + duped.setNotHeapAllocated(); + if (duped.content_type_allocated and duped.isHeapAllocated() and !include_content_type) { // for now, we just want to avoid a use-after-free here if (jsc.VirtualMachine.get().mimeType(duped.content_type)) |mime| { @@ -3400,18 +3388,16 @@ pub fn dupeWithContentType(this: *const Blob, include_content_type: bool) Blob { if (this.content_type_was_set) { duped.content_type_was_set = duped.content_type.len > 0; } - } else if (duped.content_type_allocated and duped.allocator != null and include_content_type) { + } else if (duped.content_type_allocated and duped.isHeapAllocated() and include_content_type) { duped.content_type = bun.handleOom(bun.default_allocator.dupe(u8, this.content_type)); } duped.name = duped.name.dupeRef(); - - duped.allocator = null; return duped; } pub fn toJS(this: *Blob, globalObject: *jsc.JSGlobalObject) jsc.JSValue { // if (comptime Environment.allow_assert) { - // assert(this.allocator != null); + // assert(this.isHeapAllocated()); // } this.calculateEstimatedByteSize(); @@ -3427,10 +3413,7 @@ pub fn deinit(this: *Blob) void { this.name.deref(); this.name = .dead; - // TODO: remove this field, make it a boolean. 
- if (this.allocator) |alloc| { - this.allocator = null; - bun.debugAssert(alloc.vtable == bun.default_allocator.vtable); + if (this.isHeapAllocated()) { bun.destroy(this); } } @@ -3445,8 +3428,9 @@ pub fn sharedView(this: *const Blob) []const u8 { } pub const Lifetime = jsc.WebCore.Lifetime; + pub fn setIsASCIIFlag(this: *Blob, is_all_ascii: bool) void { - this.is_all_ascii = is_all_ascii; + this.charset = .fromIsAllASCII(is_all_ascii); // if this Blob represents the entire binary data // which will be pretty common // we can update the store's is_all_ascii flag @@ -3479,7 +3463,7 @@ pub fn toStringWithBytes(this: *Blob, global: *JSGlobalObject, raw_bytes: []cons // null == unknown // false == can't be - const could_be_all_ascii = this.is_all_ascii orelse this.store.?.is_all_ascii; + const could_be_all_ascii = this.isAllASCII() orelse this.store.?.is_all_ascii; if (could_be_all_ascii == null or !could_be_all_ascii.?) { // if toUTF16Alloc returns null, it means there are no non-ASCII characters @@ -3589,7 +3573,7 @@ pub fn toJSONWithBytes(this: *Blob, global: *JSGlobalObject, raw_bytes: []const } // null == unknown // false == can't be - const could_be_all_ascii = this.is_all_ascii orelse this.store.?.is_all_ascii; + const could_be_all_ascii = this.isAllASCII() orelse this.store.?.is_all_ascii; defer if (comptime lifetime == .temporary) bun.default_allocator.free(@constCast(buf)); if (could_be_all_ascii == null or !could_be_all_ascii.?) { @@ -3870,7 +3854,7 @@ fn fromJSWithoutDeferGC( if (top_value.as(Blob)) |blob| { if (comptime move) { var _blob = blob.*; - _blob.allocator = null; + _blob.setNotHeapAllocated(); blob.transfer(); return _blob; } else { @@ -3978,7 +3962,7 @@ fn fromJSWithoutDeferGC( .DOMWrapper => { if (item.as(Blob)) |blob| { - could_have_non_ascii = could_have_non_ascii or !(blob.is_all_ascii orelse false); + could_have_non_ascii = could_have_non_ascii or blob.charset != .all_ascii; joiner.pushStatic(blob.sharedView()); continue; } else if (current.toSliceClone(global)) |sliced| { @@ -3997,7 +3981,7 @@ fn fromJSWithoutDeferGC( .DOMWrapper => { if (current.as(Blob)) |blob| { - could_have_non_ascii = could_have_non_ascii or !(blob.is_all_ascii orelse false); + could_have_non_ascii = could_have_non_ascii or blob.charset != .all_ascii; joiner.pushStatic(blob.sharedView()); } else if (current.toSliceClone(global)) |sliced| { const allocator = sliced.allocator.get(); @@ -4144,7 +4128,6 @@ pub const Any = union(enum) { }, .blob => { const result = Blob.new(this.toBlob(globalThis)); - result.allocator = bun.default_allocator; result.globalThis = globalThis; return result.toJS(globalThis); }, @@ -4355,7 +4338,7 @@ pub const Any = union(enum) { pub fn wasString(self: *const @This()) bool { return switch (self.*) { - .Blob => self.Blob.is_all_ascii orelse false, + .Blob => self.Blob.charset == .all_ascii, .WTFStringImpl => true, // .InlineBlob => self.InlineBlob.was_string, .InternalBlob => self.InternalBlob.was_string, @@ -4761,6 +4744,61 @@ pub fn FileCloser(comptime This: type) type { }; } +/// This takes up less space than a `?bool`. +pub const Charset = enum { + unknown, + all_ascii, + non_ascii, + + pub fn fromIsAllASCII(is_all_ascii: ?bool) Charset { + return if (is_all_ascii orelse return .unknown) + .all_ascii + else + .non_ascii; + } +}; + +pub fn isAllASCII(self: *const Blob) ?bool { + return switch (self.charset) { + .unknown => null, + .all_ascii => true, + .non_ascii => false, + }; +} + +/// Takes ownership of `self` by value. Invalidates `self`. 
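+/// The returned copy is marked as not heap-allocated, so calling `deinit` on it will not attempt
+/// to free the original pointer.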
+pub fn takeOwnership(self: *Blob) Blob { + var result = self.*; + self.* = undefined; + result.setNotHeapAllocated(); + return result; +} + +pub fn isHeapAllocated(self: *const Blob) bool { + return self.#ref_count.raw_value != 0; +} + +fn setNotHeapAllocated(self: *Blob) void { + self.#ref_count = .init(0); +} + +pub const external_shared_descriptor = struct { + pub const ref = Blob__ref; + pub const deref = Blob__deref; +}; + +export fn Blob__ref(self: *Blob) void { + bun.assertf(self.isHeapAllocated(), "cannot ref: this Blob is not heap-allocated", .{}); + self.#ref_count.increment(); +} + +export fn Blob__deref(self: *Blob) void { + bun.assertf(self.isHeapAllocated(), "cannot deref: this Blob is not heap-allocated", .{}); + if (self.#ref_count.decrement() == .should_destroy) { + self.deinit(); + } +} + const WriteFilePromise = write_file.WriteFilePromise; const WriteFileWaitFromLockedValueTask = write_file.WriteFileWaitFromLockedValueTask; const NewReadFileHandler = read_file.NewReadFileHandler; diff --git a/src/bun.js/webcore/Body.zig b/src/bun.js/webcore/Body.zig index fc47b6570c..04dbf20f0f 100644 --- a/src/bun.js/webcore/Body.zig +++ b/src/bun.js/webcore/Body.zig @@ -717,7 +717,6 @@ pub const Value = union(Tag) { }, .none, .getBlob => { var blob = Blob.new(new.use()); - blob.allocator = bun.default_allocator; if (headers) |fetch_headers| { if (fetch_headers.fastGet(.ContentType)) |content_type| { var content_slice = content_type.toSlice(bun.default_allocator); @@ -761,7 +760,7 @@ pub const Value = union(Tag) { switch (this.*) { .Blob => { const new_blob = this.Blob; - assert(new_blob.allocator == null); // owned by Body + assert(!new_blob.isHeapAllocated()); // owned by Body this.* = .{ .Used = {} }; return new_blob; }, @@ -1080,7 +1079,7 @@ pub fn extract( body.value = try Value.fromJS(globalThis, value); if (body.value == .Blob) { - assert(body.value.Blob.allocator == null); // owned by Body + assert(!body.value.Blob.isHeapAllocated()); // owned by Body } return body; } @@ -1289,14 +1288,13 @@ pub fn Mixin(comptime Type: type) type { } var blob = Blob.new(value.use()); - blob.allocator = bun.default_allocator; if (blob.content_type.len == 0) { if (this.getFetchHeaders()) |fetch_headers| { if (fetch_headers.fastGet(.ContentType)) |content_type| { - var content_slice = content_type.toSlice(blob.allocator.?); + var content_slice = content_type.toSlice(bun.default_allocator); defer content_slice.deinit(); var allocated = false; - const mimeType = MimeType.init(content_slice.slice(), blob.allocator.?, &allocated); + const mimeType = MimeType.init(content_slice.slice(), bun.default_allocator, &allocated); blob.content_type = mimeType.value; blob.content_type_allocated = allocated; blob.content_type_was_set = true; diff --git a/src/bun.js/webcore/ObjectURLRegistry.zig b/src/bun.js/webcore/ObjectURLRegistry.zig index 82f1dbba5d..c7c99da422 100644 --- a/src/bun.js/webcore/ObjectURLRegistry.zig +++ b/src/bun.js/webcore/ObjectURLRegistry.zig @@ -65,7 +65,6 @@ pub fn resolveAndDupe(this: *ObjectURLRegistry, pathname: []const u8) ?jsc.WebCo pub fn resolveAndDupeToJS(this: *ObjectURLRegistry, pathname: []const u8, globalObject: *jsc.JSGlobalObject) ?jsc.JSValue { var blob = jsc.WebCore.Blob.new(this.resolveAndDupe(pathname) orelse return null); - blob.allocator = bun.default_allocator; return blob.toJS(globalObject); } diff --git a/src/bun.js/webcore/S3Client.zig b/src/bun.js/webcore/S3Client.zig index 5a6a239f5b..414caa1d1a 100644 --- a/src/bun.js/webcore/S3Client.zig +++ 
b/src/bun.js/webcore/S3Client.zig @@ -136,7 +136,6 @@ pub const S3Client = struct { errdefer path.deinit(); const options = args.nextEat(); var blob = Blob.new(try S3File.constructS3FileWithS3CredentialsAndOptions(globalThis, path, options, ptr.credentials, ptr.options, ptr.acl, ptr.storage_class)); - blob.allocator = bun.default_allocator; return blob.toJS(globalThis); } diff --git a/src/bun.js/webcore/S3File.zig b/src/bun.js/webcore/S3File.zig index b8dfbb73a9..945a5c9772 100644 --- a/src/bun.js/webcore/S3File.zig +++ b/src/bun.js/webcore/S3File.zig @@ -343,9 +343,7 @@ fn constructS3FileInternal( path: jsc.Node.PathLike, options: ?jsc.JSValue, ) bun.JSError!*Blob { - var ptr = Blob.new(try constructS3FileInternalStore(globalObject, path, options)); - ptr.allocator = bun.default_allocator; - return ptr; + return Blob.new(try constructS3FileInternalStore(globalObject, path, options)); } pub const S3BlobStatTask = struct { diff --git a/src/bun.js/webcore/blob/read_file.zig b/src/bun.js/webcore/blob/read_file.zig index ece0bcc56d..6dcacad14f 100644 --- a/src/bun.js/webcore/blob/read_file.zig +++ b/src/bun.js/webcore/blob/read_file.zig @@ -10,8 +10,7 @@ pub fn NewReadFileHandler(comptime Function: anytype) type { pub fn run(handler: *@This(), maybe_bytes: ReadFileResultType) void { var promise = handler.promise.swap(); - var blob = handler.context; - blob.allocator = null; + var blob = handler.context.takeOwnership(); const globalThis = handler.globalThis; bun.destroy(handler); switch (maybe_bytes) { From e14e42b4021f50ae5fd638872d61ac558847ecbf Mon Sep 17 00:00:00 2001 From: pfg Date: Fri, 26 Sep 2025 18:05:01 -0700 Subject: [PATCH 34/43] fix lint (#23019) Format action was failing --- src/js/eval/feedback.ts | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/js/eval/feedback.ts b/src/js/eval/feedback.ts index 0bfd1e136c..38f7b13161 100644 --- a/src/js/eval/feedback.ts +++ b/src/js/eval/feedback.ts @@ -71,9 +71,6 @@ function openTerminal(): TerminalIO | null { const logError = (message: string) => { process.stderr.write(`${symbols.cross} ${message}\n`); }; -const logInfo = (message: string) => { - process.stdout.write(`${bold}${message}${reset}\n`); -}; const isValidEmail = (value: string | undefined): value is string => { if (!value) return false; @@ -548,7 +545,7 @@ function getOldestGitSha(): string | undefined { } async function main() { - const rawArgv = [...process.argv.slice(1)]; + const rawArgv = process.argv.slice(1); let terminal: TerminalIO | null = null; try { @@ -634,7 +631,6 @@ async function main() { const form = new FormData(); form.append("email", normalizedEmail); - const fileList = positionalContent.files.map(file => file.filename); form.append("message", messageBody); for (const file of positionalContent.files) { form.append("files[]", new Blob([file.content]), file.filename); @@ -672,7 +668,7 @@ async function main() { try { const networkInterfaces = Object.entries(os.networkInterfaces() || {}); - for (const [name, interfaces] of networkInterfaces) { + for (const [_name, interfaces] of networkInterfaces) { for (const networkInterface of interfaces || []) { if (networkInterface.family === "IPv4") { if (networkInterface.internal) { From d3ce459f0ef8c87658b913f478e1703a714b34d2 Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Fri, 26 Sep 2025 18:57:06 -0700 Subject: [PATCH 35/43] fix(valkey/redis) fix tls (includes pub/sub) (#22981) ### What does this PR do? 
Fix the `tls` property not being properly set.

Fixes https://github.com/oven-sh/bun/issues/22186

### How did you verify your code works?

Tests, plus manual testing against Upstash using the `rediss` protocol and `tls: true` options.

---------

Co-authored-by: Marko Vejnovic 
Co-authored-by: Marko Vejnovic 
Co-authored-by: Jarred Sumner 
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
 src/bun.js/api/server/SSLConfig.zig         |   62 +
 src/deps/uws.zig                            |   10 +
 src/valkey/js_valkey.zig                    |  166 +-
 src/valkey/js_valkey_functions.zig          |    6 +-
 src/valkey/valkey.zig                       |   24 +-
 test/js/valkey/docker-tls/Dockerfile        |   10 +-
 test/js/valkey/docker-tls/server.crt        |   48 +-
 test/js/valkey/docker-tls/server.key        |   78 +-
 test/js/valkey/test-utils.ts                |   14 +-
 test/js/valkey/valkey.failing-subscriber.ts |   26 +-
 test/js/valkey/valkey.test.ts               | 1503 ++++++++++---------
 11 files changed, 1092 insertions(+), 855 deletions(-)

diff --git a/src/bun.js/api/server/SSLConfig.zig b/src/bun.js/api/server/SSLConfig.zig
index b304cbc5a1..fe309c8d4e 100644
--- a/src/bun.js/api/server/SSLConfig.zig
+++ b/src/bun.js/api/server/SSLConfig.zig
@@ -227,6 +227,68 @@ pub fn deinit(this: *SSLConfig) void {
         this.ca = null;
     }
 }
+
+pub fn clone(this: *const SSLConfig) SSLConfig {
+    var cloned: SSLConfig = .{
+        .secure_options = this.secure_options,
+        .request_cert = this.request_cert,
+        .reject_unauthorized = this.reject_unauthorized,
+        .client_renegotiation_limit = this.client_renegotiation_limit,
+        .client_renegotiation_window = this.client_renegotiation_window,
+        .requires_custom_request_ctx = this.requires_custom_request_ctx,
+        .is_using_default_ciphers = this.is_using_default_ciphers,
+        .low_memory_mode = this.low_memory_mode,
+        .protos_len = this.protos_len,
+    };
+    const fields_cloned_by_memcopy = .{
+        "server_name",
+        "key_file_name",
+        "cert_file_name",
+        "ca_file_name",
+        "dh_params_file_name",
+        "passphrase",
+        "protos",
+    };
+
+    if (!this.is_using_default_ciphers) {
+        if (this.ssl_ciphers) |slice_ptr| {
+            const slice = std.mem.span(slice_ptr);
+            if (slice.len > 0) {
+                cloned.ssl_ciphers = bun.handleOom(bun.default_allocator.dupeZ(u8, slice));
+            } else {
+                cloned.ssl_ciphers = null;
+            }
+        }
+    }
+
+    inline for (fields_cloned_by_memcopy) |field| {
+        if (@field(this, field)) |slice_ptr| {
+            const slice = std.mem.span(slice_ptr);
+            @field(cloned, field) = bun.handleOom(bun.default_allocator.dupeZ(u8, slice));
+        }
+    }
+
+    const array_fields_cloned_by_memcopy = .{
+        "cert",
+        "key",
+        "ca",
+    };
+    inline for (array_fields_cloned_by_memcopy) |field| {
+        if (@field(this, field)) |array| {
+            const cloned_array = bun.handleOom(bun.default_allocator.alloc([*c]const u8, @field(this, field ++ "_count")));
+            @field(cloned, field) = cloned_array;
+            @field(cloned, field ++ "_count") = @field(this, field ++ "_count");
+            for (0..@field(this, field ++ "_count")) |i| {
+                const slice = std.mem.span(array[i]);
+                if (slice.len > 0) {
+                    cloned_array[i] = bun.handleOom(bun.default_allocator.dupeZ(u8, slice));
+                } else {
+                    cloned_array[i] = "";
+                }
+            }
+        }
+    }
+    return cloned;
+}

 pub const zero = SSLConfig{};

diff --git a/src/deps/uws.zig b/src/deps/uws.zig
index 9050e98730..f25f48094e 100644
--- a/src/deps/uws.zig
+++ b/src/deps/uws.zig
@@ -71,6 +71,16 @@ pub const create_bun_socket_error_t = enum(c_int) {
     invalid_ca,
     invalid_ciphers,

+    pub fn message(this: create_bun_socket_error_t) ?[]const u8 {
+        return switch (this) {
+            .none => null,
+            .load_ca_file => "Failed to load CA file",
+            .invalid_ca_file => "Invalid CA file",
+            .invalid_ca => "Invalid CA",
.invalid_ciphers => "Invalid ciphers", + }; + } + pub fn toJS(this: create_bun_socket_error_t, globalObject: *jsc.JSGlobalObject) jsc.JSValue { return switch (this) { .none => brk: { diff --git a/src/valkey/js_valkey.zig b/src/valkey/js_valkey.zig index 65406c5b48..7594dbdb22 100644 --- a/src/valkey/js_valkey.zig +++ b/src/valkey/js_valkey.zig @@ -227,7 +227,7 @@ pub const JSValkeyClient = struct { poll_ref: bun.Async.KeepAlive = .{}, _subscription_ctx: ?SubscriptionCtx, - + _socket_ctx: ?*uws.SocketContext = null, timer: Timer.EventLoopTimer = .{ .tag = .ValkeyConnectionTimeout, .next = .{ @@ -362,6 +362,7 @@ pub const JSValkeyClient = struct { }, }, }, + .tls = if (options.tls != .none) options.tls else if (uri.isTLS()) .enabled else .none, .database = database, .allocator = this_allocator, .flags = .{ @@ -409,6 +410,8 @@ pub const JSValkeyClient = struct { const orig_hostname = this.client.address.hostname(); const hostname = bun.memory.rebaseSlice(orig_hostname, base_ptr, new_base); const new_alloc = this.client.allocator; + // TODO: we could ref count it instead of cloning it + const tls: valkey.TLS = this.client.tls.clone(); return JSValkeyClient.new(.{ .ref_count = .init(), @@ -438,6 +441,7 @@ pub const JSValkeyClient = struct { }, }, }, + .tls = tls, .database = this.client.database, .allocator = new_alloc, .flags = .{ @@ -537,7 +541,7 @@ pub const JSValkeyClient = struct { // If was manually closed, reset that flag this.client.flags.is_manually_closed = false; - this.this_value.setStrong(this_value, globalObject); + defer this.updatePollRef(); if (this.client.flags.needs_to_open_socket) { debug("Need to open socket, starting connection process.", .{}); @@ -565,7 +569,6 @@ pub const JSValkeyClient = struct { this.reconnect(); }, .failed => { - this.client.status = .disconnected; this.client.flags.is_reconnecting = true; this.client.retry_attempts = 0; this.reconnect(); @@ -735,8 +738,6 @@ pub const JSValkeyClient = struct { this.ref(); defer this.deref(); - this.client.status = .connecting; - // Ref the poll to keep event loop alive during connection this.poll_ref.disable(); this.poll_ref = .{}; @@ -755,14 +756,28 @@ pub const JSValkeyClient = struct { // Callback for when Valkey client connects pub fn onValkeyConnect(this: *JSValkeyClient, value: *protocol.RESPValue) void { bun.debugAssert(this.client.status == .connected); + // we should always have a strong reference to the object here + bun.debugAssert(this.this_value.isStrong()); + defer { + this.client.onWritable(); + // update again after running the callback + this.updatePollRef(); + } const globalObject = this.globalObject; const event_loop = this.client.vm.eventLoop(); event_loop.enter(); defer event_loop.exit(); if (this.this_value.tryGet()) |this_value| { - const hello_value: JSValue = value.toJS(globalObject) catch .js_undefined; + const hello_value: JSValue = js_hello: { + break :js_hello value.toJS(globalObject) catch |err| { + // TODO: how should we handle this? 
the old code ignored the exception instead of cleaning it up;
+                        // now we clean it up and otherwise behave the same as the old code
+                        _ = globalObject.takeException(err);
+                        break :js_hello .js_undefined;
+                    };
+                };
             js.helloSetCached(this_value, globalObject, hello_value);
             // Call onConnect callback if defined by the user
             if (js.onconnectGetCached(this_value)) |on_connect| {
@@ -776,8 +791,7 @@
                 const js_promise = promise.asPromise().?;
                 if (this.client.flags.connection_promise_returns_client) {
                     debug("Resolving connection promise with client instance", .{});
-                    const this_js = this.toJS(globalObject);
-                    js_promise.resolve(globalObject, this_js);
+                    js_promise.resolve(globalObject, this_value);
                 } else {
                     debug("Resolving connection promise with HELLO response", .{});
                     js_promise.resolve(globalObject, hello_value);
@@ -785,9 +799,6 @@
                 this.client.flags.connection_promise_returns_client = false;
             }
         }
-
-        this.client.onWritable();
-        this.updatePollRef();
     }

     /// Invoked when the Valkey client receives a new listener.
@@ -897,13 +908,15 @@
     // Callback for when Valkey client closes
     pub fn onValkeyClose(this: *JSValkeyClient) void {
         const globalObject = this.globalObject;
-        this.poll_ref.disable();
-        defer this.deref();
+
+        defer {
+            // Update poll reference to allow garbage collection of disconnected clients
+            this.updatePollRef();
+            this.deref();
+        }
         const this_jsvalue = this.this_value.tryGet() orelse return;
-        this.this_value.setWeak(this_jsvalue);
-        this.ref();
-        defer this.deref();
+        this_jsvalue.ensureStillAlive();

         // Create an error value
         const error_value = protocol.valkeyErrorToJS(globalObject, "Connection closed", protocol.RedisError.ConnectionClosed);
@@ -927,9 +940,6 @@
                 &[_]JSValue{error_value},
             ) catch |e| globalObject.reportActiveExceptionAsUnhandled(e);
         }
-
-        // Update poll reference to allow garbage collection of disconnected clients
-        this.updatePollRef();
     }

     // Callback for when Valkey client times out
@@ -956,15 +966,39 @@
         }
     }

-    pub fn finalize(this: *JSValkeyClient) void {
+    fn closeSocketNextTick(this: *JSValkeyClient) void {
+        if (this.client.socket.isClosed()) return;
+        this.ref();
+        // socket close can potentially call JS, so we need to enqueue the close
+        const Holder = struct {
+            ctx: *JSValkeyClient,
+            task: jsc.AnyTask,
+
+            pub fn run(self: *@This()) void {
+                defer bun.default_allocator.destroy(self);
+
+                self.ctx.client.close();
+                self.ctx.deref();
+            }
+        };
+        var holder = bun.handleOom(bun.default_allocator.create(Holder));
+        holder.* = .{
+            .ctx = this,
+            .task = undefined,
+        };
+        holder.task = jsc.AnyTask.New(Holder, Holder.run).init(holder);
+
+        this.client.vm.enqueueTask(jsc.Task.init(&holder.task));
+    }
+
+    pub fn finalize(this: *JSValkeyClient) void {
         defer this.deref();
         this.stopTimers();
         this.this_value.finalize();
         this.client.flags.finalized = true;
-        this.client.close();
-
+        this.closeSocketNextTick();
         // We do not need to free the subscription context here because we're
         // guaranteed to have freed it by virtue of the fact that we are
         // garbage collected now and the subscription context holds a reference
@@ -983,17 +1017,27 @@
         }
     }

+    fn failWithInvalidSocketContext(this: *JSValkeyClient) void {
+        // if the context is invalid, it is not worth retrying
+        this.client.flags.enable_auto_reconnect = false;
+        this.clientFail(if (this.client.tls == .none) "Failed to create TCP context"
else "Failed to create TLS context", protocol.RedisError.ConnectionClosed); + this.client.onValkeyClose(); + } + fn connect(this: *JSValkeyClient) !void { debug("Connecting to Redis.", .{}); this.client.flags.needs_to_open_socket = false; const vm = this.client.vm; - const ctx: *uws.SocketContext, const deinit_context: bool = + const ctx: *uws.SocketContext, const own_ctx: bool = switch (this.client.tls) { .none => .{ vm.rareData().valkey_context.tcp orelse brk_ctx: { // TCP socket - const ctx_ = uws.SocketContext.createNoSSLContext(vm.uwsLoop(), @sizeOf(*JSValkeyClient)).?; + const ctx_ = uws.SocketContext.createNoSSLContext(vm.uwsLoop(), @sizeOf(*JSValkeyClient)) orelse { + this.failWithInvalidSocketContext(); + return; + }; uws.NewSocketHandler(false).configure(ctx_, true, *JSValkeyClient, SocketHandler(false)); vm.rareData().valkey_context.tcp = ctx_; break :brk_ctx ctx_; @@ -1004,7 +1048,10 @@ pub const JSValkeyClient = struct { vm.rareData().valkey_context.tls orelse brk_ctx: { // TLS socket, default config var err: uws.create_bun_socket_error_t = .none; - const ctx_ = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*JSValkeyClient), uws.SocketContext.BunSocketContextOptions{}, &err).?; + const ctx_ = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*JSValkeyClient), uws.SocketContext.BunSocketContextOptions{}, &err) orelse { + this.failWithInvalidSocketContext(); + return; + }; uws.NewSocketHandler(true).configure(ctx_, true, *JSValkeyClient, SocketHandler(true)); vm.rareData().valkey_context.tls = ctx_; break :brk_ctx ctx_; @@ -1012,32 +1059,36 @@ pub const JSValkeyClient = struct { false, }, .custom => |*custom| brk_ctx: { + if (this._socket_ctx) |ctx| { + break :brk_ctx .{ ctx, true }; + } // TLS socket, custom config var err: uws.create_bun_socket_error_t = .none; const options = custom.asUSockets(); - const ctx_ = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*JSValkeyClient), options, &err).?; + + const ctx_ = uws.SocketContext.createSSLContext(vm.uwsLoop(), @sizeOf(*JSValkeyClient), options, &err) orelse { + this.failWithInvalidSocketContext(); + return; + }; uws.NewSocketHandler(true).configure(ctx_, true, *JSValkeyClient, SocketHandler(true)); break :brk_ctx .{ ctx_, true }; }, }; this.ref(); - defer { - if (deinit_context) { - // This is actually unref(). uws.Context is reference counted. 
-            ctx.deinit(true);
-        }
+        if (own_ctx) {
+            // save the context so we deinit it later (if we reconnect we can reuse the same context)
+            this._socket_ctx = ctx;
         }
+        this.client.status = .connecting;
+        this.updatePollRef();

         this.client.socket = try this.client.address.connect(&this.client, ctx, this.client.tls != .none);
     }

-    pub fn send(this: *JSValkeyClient, globalThis: *jsc.JSGlobalObject, this_jsvalue: JSValue, command: *const Command) !*jsc.JSPromise {
+    pub fn send(this: *JSValkeyClient, globalThis: *jsc.JSGlobalObject, _: JSValue, command: *const Command) !*jsc.JSPromise {
         if (this.client.flags.needs_to_open_socket) {
             @branchHint(.unlikely);
-            if (this.this_value != .strong)
-                this.this_value.setStrong(this_jsvalue, globalThis);
-
             this.connect() catch |err| {
                 this.client.flags.needs_to_open_socket = true;
                 const err_value = globalThis.ERR(.SOCKET_CLOSED_BEFORE_CONNECTION, " {s} connecting to Valkey", .{@errorName(err)}).toJS();
@@ -1073,14 +1124,38 @@
         return memory_cost;
     }

+    fn deinitSocketContextNextTick(this: *JSValkeyClient) void {
+        const ctx = this._socket_ctx orelse return;
+        this._socket_ctx = null;
+        // socket close can potentially call JS, so we need to enqueue the deinit;
+        // this should only be the case for a TLS socket with a custom config
+        const Holder = struct {
+            ctx: *uws.SocketContext,
+            task: jsc.AnyTask,
+
+            pub fn run(self: *@This()) void {
+                defer bun.default_allocator.destroy(self);
+                self.ctx.deinit(true);
+            }
+        };
+        var holder = bun.handleOom(bun.default_allocator.create(Holder));
+        holder.* = .{
+            .ctx = ctx,
+            .task = undefined,
+        };
+        holder.task = jsc.AnyTask.New(Holder, Holder.run).init(holder);
+
+        this.client.vm.enqueueTask(jsc.Task.init(&holder.task));
+    }
+
     fn deinit(this: *JSValkeyClient) void {
         bun.debugAssert(this.client.socket.isClosed());
-
+        this.deinitSocketContextNextTick();
         this.client.deinit(null);
         this.poll_ref.disable();
         this.stopTimers();
-        this.this_value.finalize();
         this.ref_count.assertNoRefs();
+
         bun.destroy(this);
     }

@@ -1103,7 +1178,7 @@
         else
             true;

-        const has_activity = has_pending_commands or !subs_deletable;
+        const has_activity = has_pending_commands or !subs_deletable or this.client.flags.is_reconnecting;

         // There's a couple cases to handle here:
         if (has_activity) {
@@ -1117,6 +1192,7 @@
         }

         if (this.this_value.isEmpty()) {
+            debug("this_value is empty, skipping updatePollRef", .{});
             return;
         }

@@ -1135,16 +1211,19 @@
                 //
                 // It is 100% safe to drop the strong reference there and let
                 // the object be GC'd, but we're not doing that now.
+                debug("upgrading this_value since we are connected/connecting", .{});
                 this.this_value.upgrade(this.globalObject);
             },
             .disconnected, .failed => {
                 // If we're disconnected or failed, we need to check if we have
                 // any pending activity.
                 if (has_activity) {
+                    debug("upgrading this_value since there is pending activity", .{});
                    // If we have pending activity, we need to keep the object
                    // alive.
                    this.this_value.upgrade(this.globalObject);
                } else {
+                    debug("downgrading this_value since there is no pending activity", .{});
                    // If we don't have any pending activity, we can drop the
                    // strong reference.
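                    // Once the reference is weak, the JS wrapper can be collected, and finalize()
                    // then tears down the native client.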
this.this_value.downgrade(); @@ -1244,10 +1323,16 @@ fn SocketHandler(comptime ssl: bool) type { } fn onHandshake_(this: *JSValkeyClient, _: anytype, success: i32, ssl_error: uws.us_bun_verify_error_t) void { - debug("onHandshake: {d} {d}", .{ success, ssl_error.error_no }); + debug("onHandshake: {d} error={d} reason={s} code={s}", .{ + success, + ssl_error.error_no, + if (ssl_error.reason != null) bun.span(ssl_error.reason[0..bun.len(ssl_error.reason) :0]) else "no reason", + if (ssl_error.code != null) bun.span(ssl_error.code[0..bun.len(ssl_error.code) :0]) else "no code", + }); const handshake_success = if (success == 1) true else false; this.ref(); defer this.deref(); + defer this.updatePollRef(); if (handshake_success) { const vm = this.client.vm; if (this.client.tls.rejectUnauthorized(vm)) { @@ -1262,14 +1347,15 @@ fn SocketHandler(comptime ssl: bool) type { const loop = vm.eventLoop(); loop.enter(); defer loop.exit(); - this.client.status = .failed; this.client.flags.is_manually_closed = true; this.client.failWithJSValue(this.globalObject, ssl_error.toJS(this.globalObject)); this.client.close(); + return; } } } } + this.client.start(); } } diff --git a/src/valkey/js_valkey_functions.zig b/src/valkey/js_valkey_functions.zig index 9361e404cf..7afeb4a19d 100644 --- a/src/valkey/js_valkey_functions.zig +++ b/src/valkey/js_valkey_functions.zig @@ -954,12 +954,8 @@ pub fn duplicate( var new_client: *JSValkeyClient = try this.cloneWithoutConnecting(globalObject); const new_client_js = new_client.toJS(globalObject); - new_client.this_value = - if (this.client.status == .connected and !this.client.flags.is_manually_closed) - jsc.JSRef.initStrong(new_client_js, globalObject) - else - jsc.JSRef.initWeak(new_client_js); + new_client.this_value = jsc.JSRef.initWeak(new_client_js); // If the original client is already connected and not manually closed, start connecting the new client. 
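    // A freshly duplicated client holds only a weak JS reference, so an unused duplicate can still be GC'd.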
    if (this.client.status == .connected and !this.client.flags.is_manually_closed) {
        // Use strong reference during connection to prevent premature GC
diff --git a/src/valkey/valkey.zig b/src/valkey/valkey.zig
index f2c22bb2c6..ad23100e2c 100644
--- a/src/valkey/valkey.zig
+++ b/src/valkey/valkey.zig
@@ -87,6 +87,13 @@ pub const TLS = union(enum) {
     enabled,
     custom: jsc.API.ServerConfig.SSLConfig,

+    pub fn clone(this: *const TLS) TLS {
+        return switch (this.*) {
+            .custom => |*ssl_config| .{ .custom = ssl_config.clone() },
+            else => this.*,
+        };
+    }
+
     pub fn deinit(this: *TLS) void {
         switch (this.*) {
             .custom => |*ssl_config| ssl_config.deinit(),
@@ -457,6 +464,9 @@ pub const ValkeyClient = struct {

     /// Handle connection closed event
     pub fn onClose(this: *ValkeyClient) void {
+        this.socket = .{ .SocketTCP = .detached };
+        this.status = .disconnected;
+
         this.unregisterAutoFlusher();

         this.write_buffer.clearAndFree(this.allocator);
@@ -489,7 +499,6 @@

         debug("reconnect in {d}ms (attempt {d}/{d})", .{ delay_ms, this.retry_attempts, this.max_retries });

-        this.status = .disconnected;
         this.flags.is_reconnecting = true;
         this.flags.is_authenticated = false;
         this.flags.is_selecting_db_internal = false;
@@ -945,11 +954,15 @@
         this.socket = socket;
         this.write_buffer.clearAndFree(this.allocator);
         this.read_buffer.clearAndFree(this.allocator);
-        this.start();
+        if (this.socket == .SocketTCP) {
+            // if it's TCP, we need to start the connection process;
+            // if it's TLS, we wait for the handshake to complete first
+            this.start();
+        }
     }

     /// Start the connection process
-    fn start(this: *ValkeyClient) void {
+    pub fn start(this: *ValkeyClient) void {
         this.authenticate();
         _ = this.flushData();
     }
@@ -1064,7 +1077,7 @@
         const js_promise = promise.promise.get();
         // Handle disconnected state with offline queue
         switch (this.status) {
-            .connecting, .connected => {
+            .connected => {
                 try this.enqueue(command, &promise);

                 // Schedule auto-flushing to process this command if pipelining is enabled
@@ -1076,7 +1089,7 @@
                     this.registerAutoFlusher(this.vm);
                 }
             },
-            .disconnected => {
+            .connecting, .disconnected => {
                 // Only queue if offline queue is enabled
                 if (this.flags.enable_offline_queue) {
                     try this.enqueue(command, &promise);
@@ -1104,7 +1117,6 @@
         this.flags.is_manually_closed = true;
         this.unregisterAutoFlusher();
         if (this.status == .connected or this.status == .connecting) {
-            this.status = .disconnected;
             this.close();
         }
     }
diff --git a/test/js/valkey/docker-tls/Dockerfile b/test/js/valkey/docker-tls/Dockerfile
index 6b043ca1d6..3733f03c7b 100644
--- a/test/js/valkey/docker-tls/Dockerfile
+++ b/test/js/valkey/docker-tls/Dockerfile
@@ -16,7 +16,7 @@ RUN echo '#!/bin/bash\n\
 set -e\n\
 \n\
 # Wait for Redis to start\n\
-until redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key --cacert /etc/redis/certs/server.crt ping; do\n\
+until redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key ping; do\n\
 echo "Waiting for Redis TLS to start..."\n\
 sleep 1\n\
 done\n\
@@ -24,16 +24,16 @@ done\n\
 echo "Redis TLS is ready!"\n\
 \n\
 # Set up some test data for persistence tests\n\
-redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key --cacert /etc/redis/certs/server.crt set bun_valkey_tls_test_init "initialization_successful"\n\
+redis-cli --tls --cert /etc/redis/certs/server.crt 
--key /etc/redis/certs/server.key set bun_valkey_tls_test_init "initialization_successful"\n\ \n\ # Create test hash\n\ -redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key --cacert /etc/redis/certs/server.crt hset bun_valkey_tls_test_hash name "test_user" age "25" active "true"\n\ +redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key hset bun_valkey_tls_test_hash name "test_user" age "25" active "true"\n\ \n\ # Create test set\n\ -redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key --cacert /etc/redis/certs/server.crt sadd bun_valkey_tls_test_set "red" "green" "blue"\n\ +redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key sadd bun_valkey_tls_test_set "red" "green" "blue"\n\ \n\ # Create test list\n\ -redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key --cacert /etc/redis/certs/server.crt lpush bun_valkey_tls_test_list "first" "second" "third"\n\ +redis-cli --tls --cert /etc/redis/certs/server.crt --key /etc/redis/certs/server.key lpush bun_valkey_tls_test_list "first" "second" "third"\n\ ' > /docker-entrypoint-initdb.d/init-redis.sh # Make the script executable diff --git a/test/js/valkey/docker-tls/server.crt b/test/js/valkey/docker-tls/server.crt index 4e79fbc049..49ea0e1e8a 100644 --- a/test/js/valkey/docker-tls/server.crt +++ b/test/js/valkey/docker-tls/server.crt @@ -1,19 +1,33 @@ -----BEGIN CERTIFICATE----- -MIIDHzCCAgegAwIBAgIUOvkvGE7rI3OXABlz71VQMatWElgwDQYJKoZIhvcNAQEL -BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI1MDQwNzEwMjgzN1oXDTI2MDQw -NzEwMjgzN1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF -AAOCAQ8AMIIBCgKCAQEA88KqRAdx13qQcROKeSotdpfUPzPekpbNfetNZjBsmf6N -t4mtAhnAaJpPKkWvs1pDA5/qD3ZxAcLEa31y9AY76TvgZKq0yiD2MTOYFFTcstx5 -Voi2MLrSYN8Xobq7K7r5zQD7TrHEu0S3sSdA8GDtyrx2W8owuNtqUt1oBDRYRZoT -Nu3/bwjzuBGUrrdYwzBQvr5XOA3v2yAexgffOeSz8cZtvR+BL0sxu6SDN1VpQe// -KHQy1jZEHX0mOvRoB+95MfxHEgC7O8fYcrxsHkpjvacjh7TrOllbkcEAmr/exOCw -MLnZl57Xi7bQMVAPM1TR41mSmvHessPuPXCrzVKn0QIDAQABo2kwZzAdBgNVHQ4E -FgQUJszPLUfpqnggGY7NuVuGl388G44wHwYDVR0jBBgwFoAUJszPLUfpqnggGY7N -uVuGl388G44wDwYDVR0TAQH/BAUwAwEB/zAUBgNVHREEDTALgglsb2NhbGhvc3Qw -DQYJKoZIhvcNAQELBQADggEBAG4R3o6EZnODINfNIrM+Cag9ATmyEqm4MNMTyH9e -58ltgU+k5RQKBywdxlC/71BW2I4lsbMz8qS+fcFTOC5a87rEO2qCWFw9Ew4mKJkA -4gz1RBS4xShNyQewYV2U+ZhrDqp5tnwn+ZXGgMN5Jl0EwNeL6Q5U0zERfDbaE4xZ -BHrGvnHh9Rm7nkSG9uAIITJ71uKjO5ogPgzzPe++47Xug0o4e3gn3De7WATaSuYa -Oe9sIYB1YuZoQQoa1u+74+sguKV8/RdkP0rxaSKuGl8KUooNH6MLPnT8n+y+7mQS -gIAFeezbuqGrFPL2P6ZXmEX39Tlz9f9OmqpmzruUZd1lvBs= +MIIFxjCCA66gAwIBAgIUDfpkxHY/sHFNJv/Zn6XgYDg+Y98wDQYJKoZIhvcNAQEL +BQAwYjELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1TYW4gRnJh +bmNpc2NvMQwwCgYDVQQKDANCdW4xDDAKBgNVBAsMA0J1bjESMBAGA1UEAwwJbG9j +YWxob3N0MB4XDTI1MDEyMzAxMjA1OFoXDTM1MDEyMTAxMjA1OFowYjELMAkGA1UE +BhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1TYW4gRnJhbmNpc2NvMQwwCgYD +VQQKDANCdW4xDDAKBgNVBAsMA0J1bjESMBAGA1UEAwwJbG9jYWxob3N0MIICIjAN +BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAycLMJ6rxyy8uxoOmYeOH1VQNmSXD +KgQhRvkbd+CtOHUke8sW5WrZiV6aVHYCd7P+Phbyt1SXdvy0ZPiS+umfVrSt6QWV +s6H8Aw1gcDX7aoaCoqFpx6/PZpbnZ4HSTqZTdwbrwaJTCS9zRornVaB0yyhQ1VOL +XNqQxN74Fa3mh02Q2gaacEIRwAmGM/Lfbu3zKzaHtoJcH+IIRZ2nk05WzBOjmtQR +CDI7nAHFr69MFH+lUF7gYrl2FF1gIl2xxAFA3x7CeTPVqfM4qhzICdeacjN48jrM +1V+gZKp3JpDxOygUDJkR4tufpHHllKreDnw0SJxCzWEj9V1PhaTyN4+hFAmWj1ia +90ZlQQcMVceIEwFeW2goRKCh690y3PYqZBeHaOKi48Uyvd3betnv8NaCofbJ95oM +l+744nWpIcMTVi12Aszps8uAWONbO+1eyrjSnx/Bl8ZcnXrRB2S7C2PJnBIBzQIG 
+a75i7St7L5qW2In+y6a4F2qe2zRNTWuGssnhmWEt4ZKIfv1Mfqr+q67xl8VWii7k +7DT+1lv8wF9vJiieJuL9gYkmtFcj+XgbYW1auEtyKL/Liz/Dny54PoJ3bQeOqo20 +VgkcPfXwxUj6CpRJ8l2xi2Jfmt75EQFuTvGo3zNUmQYbqLRocfkYjxL3kVjzAggX +OqXfPxw5ngA0yIcCAwEAAaN0MHIwHQYDVR0OBBYEFGk2RthCDB9NGIJKHa9gP9s6 +bgr8MB8GA1UdIwQYMBaAFGk2RthCDB9NGIJKHa9gP9s6bgr8MA8GA1UdEwEB/wQF +MAMBAf8wHwYDVR0RBBgwFoIJbG9jYWxob3N0ggkxMjcuMC4wLjEwDQYJKoZIhvcN +AQELBQADggIBACDkcgDj9w6tY9q/LkGFBT2gWRQnb/3AaXWFv0cWMO7iFGdaUesP +dT7KuOweIZAz5f7PToOWwUN5Y5W774OzY8Fy6WIfo+fUzut3vO5M3FSTqM4Yrm/d +Vapfoa0fNMwKrnO5RyKZjUqeLUtwownFY67qCbg5xdlImb1GXtBplnJKZN50cQqL +08aZWUPEwpzGqPMNZWFufA9A/bx6SY8n3JJVnpvXq5P4ndK5Slq129QUcbCk89r9 +6Iog+1dTTifIaHIJ5suKbgSTBoRSs8J/xgnqcaBrwpLkpvg21QvlRjvxGxwQ5ybR +2Z5KCWa+QzpLYlYV0OfPKsKQRQ5TuCYd6y9n8zQtjzjINuZysw/YMvlSKuiR53Wk +2vjjuL91ICtV0Ye6Mj7GzPBdmBdthyLRCTKn5TVWFPBm/pAANus8v3mCgiFBPl/Y +G4cC1yaXKGiD9jvQOSkZTNP0kvtOLVI75cHiGap13XF8MeOsv4AhnUgDp7Ow3XPG +AJhs37tweYTsW8sAQinLpFM63xU9xZgutKggopftRzvQe5flfKhxV0D91WZgcjyE +vHmM8/DpU4/udEPFrqYb9NcYsCEdwVuFT1TC5ZuOqFfQZUuCco3sUvBFAqYqfxoq +LCjHe/xxbnhU7PBRHgoo7oKGldlvIqkIB9pTlIolXL0XaOMoqoGAmWKC -----END CERTIFICATE----- diff --git a/test/js/valkey/docker-tls/server.key b/test/js/valkey/docker-tls/server.key index 8185ea8b0c..e2ee5a9ea1 100644 --- a/test/js/valkey/docker-tls/server.key +++ b/test/js/valkey/docker-tls/server.key @@ -1,28 +1,52 @@ -----BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDzwqpEB3HXepBx -E4p5Ki12l9Q/M96Sls19601mMGyZ/o23ia0CGcBomk8qRa+zWkMDn+oPdnEBwsRr -fXL0BjvpO+BkqrTKIPYxM5gUVNyy3HlWiLYwutJg3xehursruvnNAPtOscS7RLex -J0DwYO3KvHZbyjC422pS3WgENFhFmhM27f9vCPO4EZSut1jDMFC+vlc4De/bIB7G -B9855LPxxm29H4EvSzG7pIM3VWlB7/8odDLWNkQdfSY69GgH73kx/EcSALs7x9hy -vGweSmO9pyOHtOs6WVuRwQCav97E4LAwudmXnteLttAxUA8zVNHjWZKa8d6yw+49 -cKvNUqfRAgMBAAECggEAVuBq6asTll4+66YwxKVVJb7QLSRx76HipD3ATKr2kd3p -KWBesnB2JHHWxDSo/c2uM7UDaTZn6V4+viasWS99m8801vwGSkH8LKX8Tka+j9rH -PiGkeXKkN1VbqU8RlXDixf9TEgWGnc3MgE2Ctgl9xrNrpaRGwCOnXdg+Ub1MNqWU -nmxMGP+G2b0ZgbirwECpdGIOvdeygQ8/Jo6brbGaPcewo1/5545wWHPpb+zXrBra -E//3i6geb9NJUAMKl0cTURoZWY3pF4yElV2ZZcE+cH5/fesYN7ZTZAccvo1lBPjy -OiC/fEuhSAH5iobH1KPrQbOby2YdUUOxg3dXR6qV5wKBgQD88jFV/a4togqaEv9x -yoCOtSE6CQmRZmCGhZSk60yR6JdFWQ1Uvc8brJdc2MRC/WRWr+nI6jGgobEbMZFO -KjHKALwOokFiSzj8I9u+IxGLM7TuGJFtKAT3AEwbQrWp3CxCrKsrIYwVkSwtViG0 -cOVXQCkmuw3ewneHdnnWxiJq/wKBgQD2tBTQRWc/fOeR9TucGp6yQpFQaYL8qA4m -FpihxbX2JcTS8ge28ggc+c1seMqYfSBtnIRCWWdDB7BjquQT/lnQqZq5TEW1/FbD -9MML0VCM5TCAH1v6rcgMRRpsRYhB+vhbLcenSZGQoojozxEQts8/1HCSHTw0PzjK -dRVClb39LwKBgHGdY7WpPZw3paVxFRYKjFYNW8BSoN6TapXh2FN/cSQ0ogW/KzK+ -ExHuIwrMPtOMN46Mc2kQcHwjRIbfa9H9N+HxFIdKMC4zdYQjoyczX0T0U7eCh4fN -KvW7R3QTMb/7KlJEdpnn9qEVVQ+EGZ2P/COFqTZBXMiK9t98wttKodkHAoGAfGV1 -kUdNto+u3MRBWId7ufsi9t8dM3UyHTaLpBbjl8iXpJ5yEWede67iTG3kClwdu+eO -MT6PeRcpdDg5ZXN9ql+7KvAwvoEM5yZGK3FSIpl2iURGxvJVywoVNr8g49Q+4wsE -f2/zPHEYg/vVaQ4lFtRyJtsi/l1ar4u2Oqry7/UCgYEAh9+0SIBhdVny48EqQXmi -5WFiKhb5BcEUuLFlZji/z+y6a7LopCisDoegvboiDByxPizbKGzIChxLqO3SQKXb -kKiGsAITmMZ2Kl6jRhMUDgq4/DsjBo/h3guk+xQZ2DHtT9v1FLvBqCi8poP4XUMy -9BbnrT03dl7N2+9fIdjnC9Y= ------END PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQDJwswnqvHLLy7G +g6Zh44fVVA2ZJcMqBCFG+Rt34K04dSR7yxblatmJXppUdgJ3s/4+FvK3VJd2/LRk ++JL66Z9WtK3pBZWzofwDDWBwNftqhoKioWnHr89mludngdJOplN3BuvBolMJL3NG +iudVoHTLKFDVU4tc2pDE3vgVreaHTZDaBppwQhHACYYz8t9u7fMrNoe2glwf4ghF +naeTTlbME6Oa1BEIMjucAcWvr0wUf6VQXuBiuXYUXWAiXbHEAUDfHsJ5M9Wp8ziq +HMgJ15pyM3jyOszVX6BkqncmkPE7KBQMmRHi25+kceWUqt4OfDRInELNYSP1XU+F 
+pPI3j6EUCZaPWJr3RmVBBwxVx4gTAV5baChEoKHr3TLc9ipkF4do4qLjxTK93dt6 +2e/w1oKh9sn3mgyX7vjidakhwxNWLXYCzOmzy4BY41s77V7KuNKfH8GXxlydetEH +ZLsLY8mcEgHNAgZrvmLtK3svmpbYif7LprgXap7bNE1Na4ayyeGZYS3hkoh+/Ux+ +qv6rrvGXxVaKLuTsNP7WW/zAX28mKJ4m4v2BiSa0VyP5eBthbVq4S3Iov8uLP8Of +Lng+gndtB46qjbRWCRw99fDFSPoKlEnyXbGLYl+a3vkRAW5O8ajfM1SZBhuotGhx ++RiPEveRWPMCCBc6pd8/HDmeADTIhwIDAQABAoICAEiGc2iW9E+7aC8Hx9lMNtmi +Wzj/8AW8clHW3d7brqiqwzCUsmhJXmUY0pUlzoFE/FFJYnowODoXYKkjCYKUVCiQ +zisDTOrDgZl/R3lOjk+ehnr7VtDnC8Cu4gO9EOIgu8P/guOZ/AtDOUbUS4/mG9Wj +alskquX30y5RkBAK8OEWKsmUshNETKkhQ1KNLW/srQqNkX8zoPX9BEgyAbjb4it9 +q8POE0lE9VSA9pTOiKSdtckMMdCLJjzvy8zOrUXtxWnu3q0+ysFKosXTjryq+eOv +SPyZ0mOo+jj1ZdtBItXG9F4K7/kCRYKRRpuISEYgs5KeSQ0WrBxZLGq3/jGmuZmb ++knLcL2iWf9tC93TcQxlYVyz8v4p8/cjW1elCe1JRYkDEriLvAwbMt+WZCIdPvSz +p2SK3x979vbRPDbhvc0gLjpKGGpW1yBgnh+Il+V4Nnl27IxY8kC2BSwENb2+ikTI +EDo+VfmfvZswKrSYcwWj2ml2WF09qUksvNeam075HbZ3AUOgXMrxr1jedaMD6M0O +hhLOOPoGBttmoowlD6wfkWmEfUU8xuxAtfJdnZkBF2Kh5MACN8YcYwmYu3WY0eUL +QM2zC4ReL+E9coWtDcSb9zg+om91wxk6ZqwClIJ7H4hUE1+yEnSAKSRa+vvtY2qt +bO3v109W2g19sqx2zP3xAoIBAQD9FzEGdqk8PF9gQbFcN9r3BQe6LafBQsZ5Ktvd ++gkC2urcG0XQtIFVTfiov9Y19/UdSjvuXMKGUTv9AFFDe+2URkX1RFUSVzMSIXKD +7RfcZ8eHv03DihfqNmZ6YhLfaA3WJpzGP4nPT14CD2712ne2dfqBav4Yb0tlGYR0 +4uVJSePJNRoQJ6tjAZzvpiswV3xQnmUCUIy8rnbTmqnY4tHgwAMfIEEKwPV54fHV +l2ZfClscBDxxkElWmZwYvu0k2LgS5st2d5R48iWCitbt9sP1+aMhV5gsirn6GcFR +Uj1sKOC5TQCOx3W9zb3563lYioUgklj4ku94GAdv0oNLzJBlAoIBAQDMFIzgjmWF +lrm3L7c57NU8HxoHIiqsiQ7s8puHfcRupFbPvgU01v+JEFCEYxt1sXLQdO3qdQTG +tod/sJ2TuyajGqEVxlA8LThsjN9mBDRC+pHmk2P3Z9tjSm5kO29wtkfQOHGlP2VR +Cb9N3oqDqVawXnGj25a+zfgFjs01HTB+hT2Hi0zkdRb+Tq3bF86F6A4ebLcXG/HF +BiMvH7SC5h6bZR2Bw4tHTREWIfB4uOUvNt+dzJ4N2+MKuuNr6Gk0VOarb9qHQsLO +H8zNrp4kNOtGZzblTQoM1f9095VPCrEX8NdAderzfcTrXzZww1dQ8DABnPphHOTm +Fe7NrNLso0h7AoIBADNTv7qK2BmCOOmBiSGlpj+QgpesaKgWDcBHA94Jtkgg8559 +3XPNF6mgLXyzoxLA3bH5+xuFLmIlGWBe7xwbhvwaIFf0arhUfOQBaoL802j8lwed +sXylheIW9EN/nko2hQ/YNtUxz5X+h5ctYBh2HO8hEBOtCikUcRroyOcXmN57ILoO +jeGW2fgzPIuRjJK6O1jyNpP4mAIv86NIa4ezwFKvPjLSzL4MkfwM6Ymisb02kXGm +Hkf9thHdBz4xglCFrxcOPVciOzcoDJlj5ODPucApx36ckB0AaWUiUgVXA2PrCmAq +EKHkK6m5jvyfV7WwKf2IEIkg63XUkbWI4N2/d80CggEBAJwbQCPpaMkGIat5qWt6 +uSXTGKLKROBTuwIPFl9PGfoUZX9leDASIcfjneOWuAOQKCZCu1b0CiJCr2VCYVcG ++qgbD4tLdkaBxL5sB9rObnepmf9JUVeHry7FWan8OON723TwKCZiVwrlLNvQ1h2e +Y/xnUgAoUahEf2so79moKVcuboGHUdsTofIHlz+Xd1fAyUQGnwrjSk4Ows0iMH9M +ra7qaua/AIQa9G38qih+LnmuPOFFCsXJJGQpzxrU3dy08PnEhuGedMsdUhkncDp7 +7FifTUObaYumClGbrS+YGx0YEl9xk7aLxxzQaSFamykDgYVKYc/1PTavIktb3sA6 +qo8CggEANSBmEGXRAecktzHvl1FhSKcqjdgpPwrQbqknhyjpAHCUkfTOolVe1BQB +4HJJAnwfVm3hP4zWsYJmE4H8TfdVdayZY2tN8ECU7X/WgGci6VIChMu0nXS2uAu0 +B/3pdOoChyaf25kIeZfB+NB2QRhYGU5VMtSW6VID9PbXTZ6U7MopYE9lY/sUTjIR +wRi2MkiNkjTalllqZnAJQV1EjG2SsrlxyPRRPPjqumqW6/cRiOLCCdiLbbYykfDV +AwfXoIFiYo5Cljm6bGjDKGDTaFjQzEmFUcAzs6QjG+BzFOLwFuCQoNOF8FZ1y4y3 +AWDbBPL8WN2F2/Q0QBxC2BECKSVxhg== +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/test/js/valkey/test-utils.ts b/test/js/valkey/test-utils.ts index 04deb19046..ee09d55884 100644 --- a/test/js/valkey/test-utils.ts +++ b/test/js/valkey/test-utils.ts @@ -80,10 +80,16 @@ export const DEFAULT_REDIS_OPTIONS = { export const TLS_REDIS_OPTIONS = { ...DEFAULT_REDIS_OPTIONS, db: 1, - tls: true, - tls_cert_file: path.join(import.meta.dir, "docker-unified", "server.crt"), - tls_key_file: path.join(import.meta.dir, "docker-unified", "server.key"), - tls_ca_file: path.join(import.meta.dir, "docker-unified", "server.crt"), + tls: { + cert: Bun.file(path.join(import.meta.dir, "docker-unified", 
"server.crt")), + key: Bun.file(path.join(import.meta.dir, "docker-unified", "server.key")), + ca: Bun.file(path.join(import.meta.dir, "docker-unified", "server.crt")), + }, + tlsPaths: { + cert: path.join(import.meta.dir, "docker-unified", "server.crt"), + key: path.join(import.meta.dir, "docker-unified", "server.key"), + ca: path.join(import.meta.dir, "docker-unified", "server.crt"), + }, }; export const UNIX_REDIS_OPTIONS = { diff --git a/test/js/valkey/valkey.failing-subscriber.ts b/test/js/valkey/valkey.failing-subscriber.ts index 872ea5a072..379718e43c 100644 --- a/test/js/valkey/valkey.failing-subscriber.ts +++ b/test/js/valkey/valkey.failing-subscriber.ts @@ -3,7 +3,7 @@ // // DO NOT IMPORT FROM test-utils.ts. That import is janky and will have different state at different from different // importers. -import {RedisClient} from "bun"; +import { RedisClient } from "bun"; function trySend(msg: any) { if (process === undefined || process.send === undefined) { @@ -13,14 +13,18 @@ function trySend(msg: any) { process.send(msg); } -let redisUrlResolver: (url: string) => void; -const redisUrl = new Promise((resolve) => { +export interface RedisTestStartMessage { + tlsPaths?: { cert: string; key: string; ca: string }; + url: string; +} +let redisUrlResolver: (msg: RedisTestStartMessage) => void; +const redisUrl = new Promise(resolve => { redisUrlResolver = resolve; }); process.on("message", (msg: any) => { if (msg.event === "start") { - redisUrlResolver(msg.url); + redisUrlResolver(msg); } else { throw new Error("Unknown event " + msg.event); } @@ -29,8 +33,16 @@ process.on("message", (msg: any) => { const CHANNEL = "error-callback-channel"; // We will wait for the parent process to tell us to start. -const url = await redisUrl; -const subscriber = new RedisClient(url); +const { url, tlsPaths } = await redisUrl; +const subscriber = new RedisClient(url, { + tls: tlsPaths + ? 
{ + cert: Bun.file(tlsPaths.cert), + key: Bun.file(tlsPaths.key), + ca: Bun.file(tlsPaths.ca), + } + : undefined, +}); await subscriber.connect(); trySend({ event: "ready" }); @@ -43,6 +55,6 @@ await subscriber.subscribe(CHANNEL, () => { trySend({ event: "message", index: counter }); }); -process.on("uncaughtException", (e) => { +process.on("uncaughtException", e => { trySend({ event: "exception", exMsg: e.message }); }); diff --git a/test/js/valkey/valkey.test.ts b/test/js/valkey/valkey.test.ts index 11b1902fbd..1cd2de2fd0 100644 --- a/test/js/valkey/valkey.test.ts +++ b/test/js/valkey/valkey.test.ts @@ -1,809 +1,824 @@ import { randomUUIDv7, RedisClient, spawn } from "bun"; import { beforeAll, beforeEach, describe, expect, test } from "bun:test"; import { + ctx as _ctx, awaitableCounter, ConnectionType, createClient, - ctx, DEFAULT_REDIS_URL, expectType, isEnabled, randomCoinFlip, setupDockerContainer, + TLS_REDIS_OPTIONS, + TLS_REDIS_URL, } from "./test-utils"; +import type { RedisTestStartMessage } from "./valkey.failing-subscriber"; -describe.skipIf(!isEnabled)("Valkey Redis Client", () => { - beforeAll(async () => { - // Ensure container is ready before tests run - await setupDockerContainer(); - if (!ctx.redis) { - ctx.redis = createClient(ConnectionType.TCP); - } - }); - - beforeEach(async () => { - // Don't create a new client, just ensure we have one - if (!ctx.redis) { - ctx.redis = createClient(ConnectionType.TCP); - } - - // Flush all data for clean test state - await ctx.redis.connect(); - await ctx.redis.send("FLUSHALL", ["SYNC"]); - }); - - describe("Basic Operations", () => { - test("should set and get strings", async () => { - const redis = ctx.redis; - const testKey = "greeting"; - const testValue = "Hello from Bun Redis!"; - - // Using direct set and get methods - const setResult = await redis.set(testKey, testValue); - expect(setResult).toMatchInlineSnapshot(`"OK"`); - - const setResult2 = await redis.set(testKey, testValue, "GET"); - expect(setResult2).toMatchInlineSnapshot(`"${testValue}"`); - - // GET should return the value we set - const getValue = await redis.get(testKey); - expect(getValue).toMatchInlineSnapshot(`"${testValue}"`); - }); - - test("should test key existence", async () => { - const redis = ctx.redis; - // Let's set a key first - await redis.set("greeting", "test existence"); - - // EXISTS in Redis normally returns integer 1 if key exists, 0 if not - // The current implementation doesn't transform exists correctly yet - const exists = await redis.exists("greeting"); - expect(exists).toBeDefined(); - // Should be true for existing keys (fixed in special handling for EXISTS) - expect(exists).toBe(true); - - // For non-existent keys - const randomKey = "nonexistent-key-" + randomUUIDv7(); - const notExists = await redis.exists(randomKey); - expect(notExists).toBeDefined(); - // Should be false for non-existing keys - expect(notExists).toBe(false); - }); - - test("should increment and decrement counters", async () => { - const redis = ctx.redis; - const counterKey = "counter"; - // First set a counter value - await redis.set(counterKey, "10"); - - // INCR should increment and return the new value - const incrementedValue = await redis.incr(counterKey); - expect(incrementedValue).toBeDefined(); - expect(typeof incrementedValue).toBe("number"); - expect(incrementedValue).toBe(11); - - // DECR should decrement and return the new value - const decrementedValue = await redis.decr(counterKey); - expect(decrementedValue).toBeDefined(); - expect(typeof 
decrementedValue).toBe("number"); - expect(decrementedValue).toBe(10); - }); - - test("should manage key expiration", async () => { - const redis = ctx.redis; - // Set a key first - const tempKey = "temporary"; - await redis.set(tempKey, "will expire"); - - // EXPIRE should return 1 if the timeout was set, 0 otherwise - const result = await redis.expire(tempKey, 60); - // Using native expire command instead of send() - expect(result).toMatchInlineSnapshot(`1`); - - // Use the TTL command directly - const ttl = await redis.ttl(tempKey); - expectType(ttl, "number"); - expect(ttl).toBeGreaterThan(0); - expect(ttl).toBeLessThanOrEqual(60); // Should be positive and not exceed our set time - }); - - test("should implement TTL command correctly for different cases", async () => { - const redis = ctx.redis; - // 1. Key with expiration - const tempKey = "ttl-test-key"; - await redis.set(tempKey, "ttl test value"); - await redis.expire(tempKey, 60); - - // Use native ttl command - const ttl = await redis.ttl(tempKey); - expectType(ttl, "number"); - expect(ttl).toBeGreaterThan(0); - expect(ttl).toBeLessThanOrEqual(60); - - // 2. Key with no expiration - const permanentKey = "permanent-key"; - await redis.set(permanentKey, "no expiry"); - const noExpiry = await redis.ttl(permanentKey); - expect(noExpiry).toMatchInlineSnapshot(`-1`); // -1 indicates no expiration - - // 3. Non-existent key - const nonExistentKey = "non-existent-" + randomUUIDv7(); - const noKey = await redis.ttl(nonExistentKey); - expect(noKey).toMatchInlineSnapshot(`-2`); // -2 indicates key doesn't exist - }); - }); - - describe("Connection State", () => { - test("should have a connected property", () => { - const redis = ctx.redis; - // The client should expose a connected property - expect(typeof redis.connected).toBe("boolean"); - }); - }); - - describe("RESP3 Data Types", () => { - test("should handle hash maps (dictionaries) as command responses", async () => { - const redis = ctx.redis; - // HSET multiple fields - const userId = "user:" + randomUUIDv7().substring(0, 8); - const setResult = await redis.send("HSET", [userId, "name", "John", "age", "30", "active", "true"]); - expect(setResult).toBeDefined(); - - // HGETALL returns object with key-value pairs - const hash = await redis.send("HGETALL", [userId]); - expect(hash).toBeDefined(); - - // Proper structure checking when RESP3 maps are fixed - if (typeof hash === "object" && hash !== null) { - expect(hash).toHaveProperty("name"); - expect(hash).toHaveProperty("age"); - expect(hash).toHaveProperty("active"); - - expect(hash.name).toBe("John"); - expect(hash.age).toBe("30"); - expect(hash.active).toBe("true"); +for (const connectionType of [ConnectionType.TLS, ConnectionType.TCP]) { + const ctx = { ..._ctx, redis: connectionType ? 
_ctx.redis : _ctx.redisTLS }; + describe.skipIf(!isEnabled)(`Valkey Redis Client (${connectionType})`, () => { + beforeAll(async () => { + // Ensure container is ready before tests run + await setupDockerContainer(); + if (!ctx.redis) { + ctx.redis = createClient(connectionType); } }); - test("should handle sets as command responses", async () => { - const redis = ctx.redis; - // Add items to a set - const setKey = "colors:" + randomUUIDv7().substring(0, 8); - const addResult = await redis.send("SADD", [setKey, "red", "blue", "green"]); - expect(addResult).toBeDefined(); - - // Get set members - const setMembers = await redis.send("SMEMBERS", [setKey]); - expect(setMembers).toBeDefined(); - - // Check if the response is an array - expect(Array.isArray(setMembers)).toBe(true); - - // Should contain our colors - expect(setMembers).toContain("red"); - expect(setMembers).toContain("blue"); - expect(setMembers).toContain("green"); - }); - }); - - describe("Connection Options", () => { - test("connection errors", async () => { - const url = new URL(DEFAULT_REDIS_URL); - url.username = "badusername"; - url.password = "secretpassword"; - const customRedis = new RedisClient(url.toString()); - - expect(async () => { - await customRedis.get("test"); - }).toThrowErrorMatchingInlineSnapshot(`"WRONGPASS invalid username-password pair or user is disabled."`); - }); - - const testKeyUniquePerDb = crypto.randomUUID(); - test.each([...Array(16).keys()])("Connecting to database with url $url succeeds", async (dbId: number) => { - const redis = createClient(ConnectionType.TCP, {}, dbId); - - // Ensure the value is not in the database. - const testValue = await redis.get(testKeyUniquePerDb); - expect(testValue).toBeNull(); - - redis.close(); - }); - }); - - describe("Reconnections", () => { - test.skip("should automatically reconnect after connection drop", async () => { - // NOTE: This test was already broken before the Docker Compose migration. - // It times out after 31 seconds with "Max reconnection attempts reached" - // This appears to be an issue with the Redis client's automatic reconnection - // behavior, not related to the Docker infrastructure changes. - const TEST_KEY = "test-key"; - const TEST_VALUE = "test-value"; - - // Ensure we have a working client to start - if (!ctx.redis || !ctx.redis.connected) { - ctx.redis = createClient(ConnectionType.TCP); - } - - const valueBeforeStart = await ctx.redis.get(TEST_KEY); - expect(valueBeforeStart).toBeNull(); - - // Set some value - await ctx.redis.set(TEST_KEY, TEST_VALUE); - const valueAfterSet = await ctx.redis.get(TEST_KEY); - expect(valueAfterSet).toBe(TEST_VALUE); - - await ctx.restartServer(); - - const valueAfterStop = await ctx.redis.get(TEST_KEY); - expect(valueAfterStop).toBe(TEST_VALUE); - }); - }); - - describe("PUB/SUB", () => { - var i = 0; - const testChannel = () => { - return `test-channel-${i++}`; - }; - const testKey = () => { - return `test-key-${i++}`; - }; - const testValue = () => { - return `test-value-${i++}`; - }; - const testMessage = () => { - return `test-message-${i++}`; - }; - beforeEach(async () => { - // The PUB/SUB tests expect that ctx.redis is connected but not in subscriber mode. 
- await ctx.cleanupSubscribers(); - }); - - test("publishing to a channel does not fail", async () => { - expect(await ctx.redis.publish(testChannel(), testMessage())).toBe(0); - }); - - test("setting in subscriber mode gracefully fails", async () => { - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - - await subscriber.subscribe(testChannel(), () => {}); - - expect(() => subscriber.set(testKey(), testValue())).toThrow( - "RedisClient.prototype.set cannot be called while in subscriber mode", - ); - - await subscriber.unsubscribe(testChannel()); - }); - - test("setting after unsubscribing works", async () => { - const channel = testChannel(); - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - await subscriber.subscribe(channel, () => {}); - await subscriber.unsubscribe(channel); - expect(ctx.redis.set(testKey(), testValue())).resolves.toEqual("OK"); - }); - - test("subscribing to a channel receives messages", async () => { - const TEST_MESSAGE_COUNT = 128; - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - const channel = testChannel(); - const message = testMessage(); - - const counter = awaitableCounter(); - await subscriber.subscribe(channel, (message, channel) => { - counter.increment(); - expect(channel).toBe(channel); - expect(message).toBe(message); - }); - - Array.from({ length: TEST_MESSAGE_COUNT }).forEach(async () => { - expect(await ctx.redis.publish(channel, message)).toBe(1); - }); - - await counter.untilValue(TEST_MESSAGE_COUNT); - expect(counter.count()).toBe(TEST_MESSAGE_COUNT); - }); - - test("messages are received in order", async () => { - const channel = testChannel(); - - await ctx.redis.set("START-TEST", "1"); - const TEST_MESSAGE_COUNT = 1024; - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - - const counter = awaitableCounter(); - var receivedMessages: string[] = []; - await subscriber.subscribe(channel, message => { - receivedMessages.push(message); - counter.increment(); - }); - - const sentMessages = Array.from({ length: TEST_MESSAGE_COUNT }).map(() => { - return randomUUIDv7(); - }); - await Promise.all( - sentMessages.map(async message => { - expect(await ctx.redis.publish(channel, message)).toBe(1); - }), - ); - - await counter.untilValue(TEST_MESSAGE_COUNT); - expect(receivedMessages.length).toBe(sentMessages.length); - expect(receivedMessages).toEqual(sentMessages); - - await subscriber.unsubscribe(channel); - - await ctx.redis.set("STOP-TEST", "1"); - }); - - test("subscribing to multiple channels receives messages", async () => { - const TEST_MESSAGE_COUNT = 128; - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - - const channels = [testChannel(), testChannel()]; - const counter = awaitableCounter(); - - var receivedMessages: { [channel: string]: string[] } = {}; - await subscriber.subscribe(channels, (message, channel) => { - receivedMessages[channel] = receivedMessages[channel] || []; - receivedMessages[channel].push(message); - counter.increment(); - }); - - var sentMessages: { [channel: string]: string[] } = {}; - for (let i = 0; i < TEST_MESSAGE_COUNT; i++) { - const channel = channels[randomCoinFlip() ? 
0 : 1]; - const message = randomUUIDv7(); - - expect(await ctx.redis.publish(channel, message)).toBe(1); - - sentMessages[channel] = sentMessages[channel] || []; - sentMessages[channel].push(message); + // Don't create a new client, just ensure we have one + if (!ctx.redis) { + ctx.redis = createClient(connectionType); } - await counter.untilValue(TEST_MESSAGE_COUNT); + // Flush all data for clean test state + await ctx.redis.connect(); + await ctx.redis.send("FLUSHALL", ["SYNC"]); + }); - // Check that we received messages on both channels - expect(Object.keys(receivedMessages).sort()).toEqual(Object.keys(sentMessages).sort()); + describe("Basic Operations", () => { + test("should set and get strings", async () => { + const redis = ctx.redis; + const testKey = "greeting"; + const testValue = "Hello from Bun Redis!"; - // Check messages match for each channel - for (const channel of channels) { - if (sentMessages[channel]) { - expect(receivedMessages[channel]).toEqual(sentMessages[channel]); + // Using direct set and get methods + const setResult = await redis.set(testKey, testValue); + expect(setResult).toMatchInlineSnapshot(`"OK"`); + + const setResult2 = await redis.set(testKey, testValue, "GET"); + expect(setResult2).toMatchInlineSnapshot(`"${testValue}"`); + + // GET should return the value we set + const getValue = await redis.get(testKey); + expect(getValue).toMatchInlineSnapshot(`"${testValue}"`); + }); + + test("should test key existence", async () => { + const redis = ctx.redis; + // Let's set a key first + await redis.set("greeting", "test existence"); + + // EXISTS in Redis normally returns integer 1 if key exists, 0 if not + // The current implementation doesn't transform exists correctly yet + const exists = await redis.exists("greeting"); + expect(exists).toBeDefined(); + // Should be true for existing keys (fixed in special handling for EXISTS) + expect(exists).toBe(true); + + // For non-existent keys + const randomKey = "nonexistent-key-" + randomUUIDv7(); + const notExists = await redis.exists(randomKey); + expect(notExists).toBeDefined(); + // Should be false for non-existing keys + expect(notExists).toBe(false); + }); + + test("should increment and decrement counters", async () => { + const redis = ctx.redis; + const counterKey = "counter"; + // First set a counter value + await redis.set(counterKey, "10"); + + // INCR should increment and return the new value + const incrementedValue = await redis.incr(counterKey); + expect(incrementedValue).toBeDefined(); + expect(typeof incrementedValue).toBe("number"); + expect(incrementedValue).toBe(11); + + // DECR should decrement and return the new value + const decrementedValue = await redis.decr(counterKey); + expect(decrementedValue).toBeDefined(); + expect(typeof decrementedValue).toBe("number"); + expect(decrementedValue).toBe(10); + }); + + test("should manage key expiration", async () => { + const redis = ctx.redis; + // Set a key first + const tempKey = "temporary"; + await redis.set(tempKey, "will expire"); + + // EXPIRE should return 1 if the timeout was set, 0 otherwise + const result = await redis.expire(tempKey, 60); + // Using native expire command instead of send() + expect(result).toMatchInlineSnapshot(`1`); + + // Use the TTL command directly + const ttl = await redis.ttl(tempKey); + expectType(ttl, "number"); + expect(ttl).toBeGreaterThan(0); + expect(ttl).toBeLessThanOrEqual(60); // Should be positive and not exceed our set time + }); + + test("should implement TTL command correctly for different cases", async 
() => { + const redis = ctx.redis; + // 1. Key with expiration + const tempKey = "ttl-test-key"; + await redis.set(tempKey, "ttl test value"); + await redis.expire(tempKey, 60); + + // Use native ttl command + const ttl = await redis.ttl(tempKey); + expectType(ttl, "number"); + expect(ttl).toBeGreaterThan(0); + expect(ttl).toBeLessThanOrEqual(60); + + // 2. Key with no expiration + const permanentKey = "permanent-key"; + await redis.set(permanentKey, "no expiry"); + const noExpiry = await redis.ttl(permanentKey); + expect(noExpiry).toMatchInlineSnapshot(`-1`); // -1 indicates no expiration + + // 3. Non-existent key + const nonExistentKey = "non-existent-" + randomUUIDv7(); + const noKey = await redis.ttl(nonExistentKey); + expect(noKey).toMatchInlineSnapshot(`-2`); // -2 indicates key doesn't exist + }); + }); + + describe("Connection State", () => { + test("should have a connected property", () => { + const redis = ctx.redis; + // The client should expose a connected property + expect(typeof redis.connected).toBe("boolean"); + }); + }); + + describe("RESP3 Data Types", () => { + test("should handle hash maps (dictionaries) as command responses", async () => { + const redis = ctx.redis; + // HSET multiple fields + const userId = "user:" + randomUUIDv7().substring(0, 8); + const setResult = await redis.send("HSET", [userId, "name", "John", "age", "30", "active", "true"]); + expect(setResult).toBeDefined(); + + // HGETALL returns object with key-value pairs + const hash = await redis.send("HGETALL", [userId]); + expect(hash).toBeDefined(); + + // Proper structure checking when RESP3 maps are fixed + if (typeof hash === "object" && hash !== null) { + expect(hash).toHaveProperty("name"); + expect(hash).toHaveProperty("age"); + expect(hash).toHaveProperty("active"); + + expect(hash.name).toBe("John"); + expect(hash.age).toBe("30"); + expect(hash.active).toBe("true"); } - } - - await subscriber.unsubscribe(channels); - }); - - test("unsubscribing from specific channels while remaining subscribed to others", async () => { - const channel1 = "channel-1"; - const channel2 = "channel-2"; - const channel3 = "channel-3"; - - const subscriber = createClient(ConnectionType.TCP); - await subscriber.connect(); - - let receivedMessages: { [channel: string]: string[] } = {}; - - // Total counter for all messages we expect to receive: 3 initial + 2 after unsubscribe = 5 total - const counter = awaitableCounter(); - - // Subscribe to three channels - await subscriber.subscribe([channel1, channel2, channel3], (message, channel) => { - receivedMessages[channel] = receivedMessages[channel] || []; - receivedMessages[channel].push(message); - counter.increment(); }); - // Send initial messages to all channels - expect(await ctx.redis.publish(channel1, "msg1-before")).toBe(1); - expect(await ctx.redis.publish(channel2, "msg2-before")).toBe(1); - expect(await ctx.redis.publish(channel3, "msg3-before")).toBe(1); + test("should handle sets as command responses", async () => { + const redis = ctx.redis; + // Add items to a set + const setKey = "colors:" + randomUUIDv7().substring(0, 8); + const addResult = await redis.send("SADD", [setKey, "red", "blue", "green"]); + expect(addResult).toBeDefined(); - // Wait for initial messages, then unsubscribe from channel2 - await counter.untilValue(3); - await subscriber.unsubscribe(channel2); + // Get set members + const setMembers = await redis.send("SMEMBERS", [setKey]); + expect(setMembers).toBeDefined(); - // Send messages after unsubscribing from channel2 - expect(await 
ctx.redis.publish(channel1, "msg1-after")).toBe(1); - expect(await ctx.redis.publish(channel2, "msg2-after")).toBe(0); - expect(await ctx.redis.publish(channel3, "msg3-after")).toBe(1); + // Check if the response is an array + expect(Array.isArray(setMembers)).toBe(true); - await counter.untilValue(5); - - // Check we received messages only on subscribed channels - expect(receivedMessages[channel1]).toEqual(["msg1-before", "msg1-after"]); - expect(receivedMessages[channel2]).toEqual(["msg2-before"]); // No "msg2-after" - expect(receivedMessages[channel3]).toEqual(["msg3-before", "msg3-after"]); - - await subscriber.unsubscribe([channel1, channel3]); + // Should contain our colors + expect(setMembers).toContain("red"); + expect(setMembers).toContain("blue"); + expect(setMembers).toContain("green"); + }); }); - test("subscribing to the same channel multiple times", async () => { - const subscriber = createClient(ConnectionType.TCP); - await subscriber.connect(); - const channel = testChannel(); + describe("Connection Options", () => { + test("connection errors", async () => { + const url = new URL(connectionType === ConnectionType.TLS ? TLS_REDIS_URL : DEFAULT_REDIS_URL); + url.username = "badusername"; + url.password = "secretpassword"; + const customRedis = new RedisClient(url.toString(), { + tls: connectionType === ConnectionType.TLS ? TLS_REDIS_OPTIONS.tls : false, + }); - const counter = awaitableCounter(); - - let callCount = 0; - const listener = () => { - callCount++; - counter.increment(); - }; - - let callCount2 = 0; - const listener2 = () => { - callCount2++; - counter.increment(); - }; - - // Subscribe to the same channel twice - await subscriber.subscribe(channel, listener); - await subscriber.subscribe(channel, listener2); - - // Publish a single message - expect(await ctx.redis.publish(channel, "test-message")).toBe(1); - - await counter.untilValue(2); - - // Both listeners should have been called once. - expect(callCount).toBe(1); - expect(callCount2).toBe(1); - - await subscriber.unsubscribe(channel); - }); - - test("empty string messages", async () => { - const channel = "empty-message-channel"; - const subscriber = createClient(ConnectionType.TCP); - await subscriber.connect(); - - const counter = awaitableCounter(); - let receivedMessage: string | undefined = undefined; - await subscriber.subscribe(channel, message => { - receivedMessage = message; - counter.increment(); + expect(async () => { + await customRedis.get("test"); + }).toThrowErrorMatchingInlineSnapshot(`"WRONGPASS invalid username-password pair or user is disabled."`); }); - expect(await ctx.redis.publish(channel, "")).toBe(1); - await counter.untilValue(1); + const testKeyUniquePerDb = crypto.randomUUID(); + test.each([...Array(16).keys()])("Connecting to database with url $url succeeds", async (dbId: number) => { + const redis = createClient(connectionType, {}, dbId); - expect(receivedMessage).not.toBeUndefined(); - expect(receivedMessage!).toBe(""); + // Ensure the value is not in the database. + const testValue = await redis.get(testKeyUniquePerDb); + expect(testValue).toBeNull(); - await subscriber.unsubscribe(channel); + redis.close(); + }); }); - test("special characters in channel names", async () => { - const subscriber = createClient(ConnectionType.TCP); - await subscriber.connect(); + describe("Reconnections", () => { + test.skip("should automatically reconnect after connection drop", async () => { + // NOTE: This test was already broken before the Docker Compose migration. 
+ // It times out after 31 seconds with "Max reconnection attempts reached" + // This appears to be an issue with the Redis client's automatic reconnection + // behavior, not related to the Docker infrastructure changes. + const TEST_KEY = "test-key"; + const TEST_VALUE = "test-value"; - const specialChannels = [ - "channel:with:colons", - "channel with spaces", - "channel-with-unicode-😀", - "channel[with]brackets", - "channel@with#special$chars", - ]; + // Ensure we have a working client to start + if (!ctx.redis || !ctx.redis.connected) { + ctx.redis = createClient(connectionType); + } + + const valueBeforeStart = await ctx.redis.get(TEST_KEY); + expect(valueBeforeStart).toBeNull(); + + // Set some value + await ctx.redis.set(TEST_KEY, TEST_VALUE); + const valueAfterSet = await ctx.redis.get(TEST_KEY); + expect(valueAfterSet).toBe(TEST_VALUE); + + await ctx.restartServer(); + + const valueAfterStop = await ctx.redis.get(TEST_KEY); + expect(valueAfterStop).toBe(TEST_VALUE); + }); + }); + + describe("PUB/SUB", () => { + var i = 0; + const testChannel = () => { + return `test-channel-${i++}`; + }; + const testKey = () => { + return `test-key-${i++}`; + }; + const testValue = () => { + return `test-value-${i++}`; + }; + const testMessage = () => { + return `test-message-${i++}`; + }; + + beforeEach(async () => { + // The PUB/SUB tests expect that ctx.redis is connected but not in subscriber mode. + await ctx.cleanupSubscribers(); + }); + + test("publishing to a channel does not fail", async () => { + expect(await ctx.redis.publish(testChannel(), testMessage())).toBe(0); + }); + + test("setting in subscriber mode gracefully fails", async () => { + const subscriber = await ctx.newSubscriberClient(connectionType); + + await subscriber.subscribe(testChannel(), () => {}); + + expect(() => subscriber.set(testKey(), testValue())).toThrow( + "RedisClient.prototype.set cannot be called while in subscriber mode", + ); + + await subscriber.unsubscribe(testChannel()); + }); + + test("setting after unsubscribing works", async () => { + const channel = testChannel(); + const subscriber = await ctx.newSubscriberClient(connectionType); + await subscriber.subscribe(channel, () => {}); + await subscriber.unsubscribe(channel); + expect(ctx.redis.set(testKey(), testValue())).resolves.toEqual("OK"); + }); + + test("subscribing to a channel receives messages", async () => { + const TEST_MESSAGE_COUNT = 128; + const subscriber = await ctx.newSubscriberClient(connectionType); + const channel = testChannel(); + const message = testMessage(); - for (const channel of specialChannels) { const counter = awaitableCounter(); - let received = false; - await subscriber.subscribe(channel, () => { - received = true; + await subscriber.subscribe(channel, (message, channel) => { + counter.increment(); + expect(channel).toBe(channel); + expect(message).toBe(message); + }); + + Array.from({ length: TEST_MESSAGE_COUNT }).forEach(async () => { + expect(await ctx.redis.publish(channel, message)).toBe(1); + }); + + await counter.untilValue(TEST_MESSAGE_COUNT); + expect(counter.count()).toBe(TEST_MESSAGE_COUNT); + }); + + test("messages are received in order", async () => { + const channel = testChannel(); + + await ctx.redis.set("START-TEST", "1"); + const TEST_MESSAGE_COUNT = 1024; + const subscriber = await ctx.newSubscriberClient(connectionType); + + const counter = awaitableCounter(); + var receivedMessages: string[] = []; + await subscriber.subscribe(channel, message => { + receivedMessages.push(message); counter.increment(); 
}); - expect(await ctx.redis.publish(channel, "test")).toBe(1); - await counter.untilValue(1); + const sentMessages = Array.from({ length: TEST_MESSAGE_COUNT }).map(() => { + return randomUUIDv7(); + }); + await Promise.all( + sentMessages.map(async message => { + expect(await ctx.redis.publish(channel, message)).toBe(1); + }), + ); + + await counter.untilValue(TEST_MESSAGE_COUNT); + expect(receivedMessages.length).toBe(sentMessages.length); + expect(receivedMessages).toEqual(sentMessages); - expect(received).toBe(true); await subscriber.unsubscribe(channel); - } - }); - test("ping works in subscription mode", async () => { - const channel = "ping-test-channel"; - - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - await subscriber.subscribe(channel, () => {}); - - // Ping should work in subscription mode - const pong = await subscriber.ping(); - expect(pong).toBe("PONG"); - - const customPing = await subscriber.ping("hello"); - expect(customPing).toBe("hello"); - }); - - test("publish does not work from a subscribed client", async () => { - const channel = "self-publish-channel"; - - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - await subscriber.subscribe(channel, () => {}); - - // Publishing from the same client should work - expect(async () => subscriber.publish(channel, "self-published")).toThrow(); - }); - - test("complete unsubscribe restores normal command mode", async () => { - const channel = "restore-test-channel"; - const testKey = "restore-test-key"; - - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - await subscriber.subscribe(channel, () => {}); - - // Should fail in subscription mode - expect(() => subscriber.set(testKey, testValue())).toThrow( - "RedisClient.prototype.set cannot be called while in subscriber mode.", - ); - - // Unsubscribe from all channels - await subscriber.unsubscribe(); - - // Should work after unsubscribing - const result = await ctx.redis.set(testKey, "value"); - expect(result).toBe("OK"); - - const value = await ctx.redis.get(testKey); - expect(value).toBe("value"); - }); - - test("publishing without subscribers succeeds", async () => { - const channel = "no-subscribers-channel"; - - // Publishing without subscribers should not throw - expect(await ctx.redis.publish(channel, "message")).toBe(0); - }); - - test("unsubscribing from non-subscribed channels", async () => { - const channel = "never-subscribed-channel"; - - expect(() => ctx.redis.unsubscribe(channel)).toThrow( - "RedisClient.prototype.unsubscribe can only be called while in subscriber mode.", - ); - }); - - test("callback errors don't crash the client", async () => { - const channel = "error-callback-channel"; - - const STEP_SUBSCRIBED = 1; - const STEP_FIRST_MESSAGE = 2; - const STEP_SECOND_MESSAGE = 3; - const STEP_THIRD_MESSAGE = 4; - - // stepCounter is a slight hack to track the progress of the subprocess. 
- const stepCounter = awaitableCounter(); - let currentMessage: any = {}; - - const subscriberProc = spawn({ - cmd: [self.process.execPath, "run", `${__dirname}/valkey.failing-subscriber.ts`], - stdout: "inherit", - stderr: "inherit", - ipc: msg => { - currentMessage = msg; - stepCounter.increment(); - }, - env: { - ...process.env, - NODE_ENV: "development", - }, + await ctx.redis.set("STOP-TEST", "1"); }); - subscriberProc.send({ event: "start", url: DEFAULT_REDIS_URL }); + test("subscribing to multiple channels receives messages", async () => { + const TEST_MESSAGE_COUNT = 128; + const subscriber = await ctx.newSubscriberClient(connectionType); - try { - await stepCounter.untilValue(STEP_SUBSCRIBED); - expect(currentMessage.event).toBe("ready"); + const channels = [testChannel(), testChannel()]; + const counter = awaitableCounter(); - // Send multiple messages - expect(await ctx.redis.publish(channel, "message1")).toBe(1); - await stepCounter.untilValue(STEP_FIRST_MESSAGE); - expect(currentMessage.event).toBe("message"); - expect(currentMessage.index).toBe(1); + var receivedMessages: { [channel: string]: string[] } = {}; + await subscriber.subscribe(channels, (message, channel) => { + receivedMessages[channel] = receivedMessages[channel] || []; + receivedMessages[channel].push(message); + counter.increment(); + }); - // Now, the subscriber process will crash - expect(await ctx.redis.publish(channel, "message2")).toBe(1); - await stepCounter.untilValue(STEP_SECOND_MESSAGE); - expect(currentMessage.event).toBe("exception"); - //expect(currentMessage.index).toBe(2); + var sentMessages: { [channel: string]: string[] } = {}; + for (let i = 0; i < TEST_MESSAGE_COUNT; i++) { + const channel = channels[randomCoinFlip() ? 0 : 1]; + const message = randomUUIDv7(); - // But it should recover and continue receiving messages - expect(await ctx.redis.publish(channel, "message3")).toBe(1); - await stepCounter.untilValue(STEP_THIRD_MESSAGE); - expect(currentMessage.event).toBe("message"); - expect(currentMessage.index).toBe(3); - } finally { - subscriberProc.kill(); - } + expect(await ctx.redis.publish(channel, message)).toBe(1); + + sentMessages[channel] = sentMessages[channel] || []; + sentMessages[channel].push(message); + } + + await counter.untilValue(TEST_MESSAGE_COUNT); + + // Check that we received messages on both channels + expect(Object.keys(receivedMessages).sort()).toEqual(Object.keys(sentMessages).sort()); + + // Check messages match for each channel + for (const channel of channels) { + if (sentMessages[channel]) { + expect(receivedMessages[channel]).toEqual(sentMessages[channel]); + } + } + + await subscriber.unsubscribe(channels); + }); + + test("unsubscribing from specific channels while remaining subscribed to others", async () => { + const channel1 = "channel-1"; + const channel2 = "channel-2"; + const channel3 = "channel-3"; + + const subscriber = createClient(connectionType); + await subscriber.connect(); + + let receivedMessages: { [channel: string]: string[] } = {}; + + // Total counter for all messages we expect to receive: 3 initial + 2 after unsubscribe = 5 total + const counter = awaitableCounter(); + + // Subscribe to three channels + await subscriber.subscribe([channel1, channel2, channel3], (message, channel) => { + receivedMessages[channel] = receivedMessages[channel] || []; + receivedMessages[channel].push(message); + counter.increment(); + }); + + // Send initial messages to all channels + expect(await ctx.redis.publish(channel1, "msg1-before")).toBe(1); + expect(await 
ctx.redis.publish(channel2, "msg2-before")).toBe(1); + expect(await ctx.redis.publish(channel3, "msg3-before")).toBe(1); + + // Wait for initial messages, then unsubscribe from channel2 + await counter.untilValue(3); + await subscriber.unsubscribe(channel2); + + // Send messages after unsubscribing from channel2 + expect(await ctx.redis.publish(channel1, "msg1-after")).toBe(1); + expect(await ctx.redis.publish(channel2, "msg2-after")).toBe(0); + expect(await ctx.redis.publish(channel3, "msg3-after")).toBe(1); + + await counter.untilValue(5); + + // Check we received messages only on subscribed channels + expect(receivedMessages[channel1]).toEqual(["msg1-before", "msg1-after"]); + expect(receivedMessages[channel2]).toEqual(["msg2-before"]); // No "msg2-after" + expect(receivedMessages[channel3]).toEqual(["msg3-before", "msg3-after"]); + + await subscriber.unsubscribe([channel1, channel3]); + }); + + test("subscribing to the same channel multiple times", async () => { + const subscriber = createClient(connectionType); + await subscriber.connect(); + const channel = testChannel(); + + const counter = awaitableCounter(); + + let callCount = 0; + const listener = () => { + callCount++; + counter.increment(); + }; + + let callCount2 = 0; + const listener2 = () => { + callCount2++; + counter.increment(); + }; + + // Subscribe to the same channel twice + await subscriber.subscribe(channel, listener); + await subscriber.subscribe(channel, listener2); + + // Publish a single message + expect(await ctx.redis.publish(channel, "test-message")).toBe(1); + + await counter.untilValue(2); + + // Both listeners should have been called once. + expect(callCount).toBe(1); + expect(callCount2).toBe(1); + + await subscriber.unsubscribe(channel); + }); + + test("empty string messages", async () => { + const channel = "empty-message-channel"; + const subscriber = createClient(connectionType); + await subscriber.connect(); + + const counter = awaitableCounter(); + let receivedMessage: string | undefined = undefined; + await subscriber.subscribe(channel, message => { + receivedMessage = message; + counter.increment(); + }); + + expect(await ctx.redis.publish(channel, "")).toBe(1); + await counter.untilValue(1); + + expect(receivedMessage).not.toBeUndefined(); + expect(receivedMessage!).toBe(""); + + await subscriber.unsubscribe(channel); + }); + + test("special characters in channel names", async () => { + const subscriber = createClient(connectionType); + await subscriber.connect(); + + const specialChannels = [ + "channel:with:colons", + "channel with spaces", + "channel-with-unicode-😀", + "channel[with]brackets", + "channel@with#special$chars", + ]; + + for (const channel of specialChannels) { + const counter = awaitableCounter(); + let received = false; + await subscriber.subscribe(channel, () => { + received = true; + counter.increment(); + }); + + expect(await ctx.redis.publish(channel, "test")).toBe(1); + await counter.untilValue(1); + + expect(received).toBe(true); + await subscriber.unsubscribe(channel); + } + }); + + test("ping works in subscription mode", async () => { + const channel = "ping-test-channel"; + + const subscriber = await ctx.newSubscriberClient(connectionType); + await subscriber.subscribe(channel, () => {}); + + // Ping should work in subscription mode + const pong = await subscriber.ping(); + expect(pong).toBe("PONG"); + + const customPing = await subscriber.ping("hello"); + expect(customPing).toBe("hello"); + }); + + test("publish does not work from a subscribed client", async () => { + 
const channel = "self-publish-channel"; + + const subscriber = await ctx.newSubscriberClient(connectionType); + await subscriber.subscribe(channel, () => {}); + + // Publishing from the same client should work + expect(async () => subscriber.publish(channel, "self-published")).toThrow(); + }); + + test("complete unsubscribe restores normal command mode", async () => { + const channel = "restore-test-channel"; + const testKey = "restore-test-key"; + + const subscriber = await ctx.newSubscriberClient(connectionType); + await subscriber.subscribe(channel, () => {}); + + // Should fail in subscription mode + expect(() => subscriber.set(testKey, testValue())).toThrow( + "RedisClient.prototype.set cannot be called while in subscriber mode.", + ); + + // Unsubscribe from all channels + await subscriber.unsubscribe(); + + // Should work after unsubscribing + const result = await ctx.redis.set(testKey, "value"); + expect(result).toBe("OK"); + + const value = await ctx.redis.get(testKey); + expect(value).toBe("value"); + }); + + test("publishing without subscribers succeeds", async () => { + const channel = "no-subscribers-channel"; + + // Publishing without subscribers should not throw + expect(await ctx.redis.publish(channel, "message")).toBe(0); + }); + + test("unsubscribing from non-subscribed channels", async () => { + const channel = "never-subscribed-channel"; + + expect(() => ctx.redis.unsubscribe(channel)).toThrow( + "RedisClient.prototype.unsubscribe can only be called while in subscriber mode.", + ); + }); + + test("callback errors don't crash the client", async () => { + const channel = "error-callback-channel"; + + const STEP_SUBSCRIBED = 1; + const STEP_FIRST_MESSAGE = 2; + const STEP_SECOND_MESSAGE = 3; + const STEP_THIRD_MESSAGE = 4; + + // stepCounter is a slight hack to track the progress of the subprocess. + const stepCounter = awaitableCounter(); + let currentMessage: any = {}; + + const subscriberProc = spawn({ + cmd: [self.process.execPath, "run", `${__dirname}/valkey.failing-subscriber.ts`], + stdout: "inherit", + stderr: "inherit", + ipc: msg => { + currentMessage = msg; + stepCounter.increment(); + }, + env: { + ...process.env, + NODE_ENV: "development", + }, + }); + + subscriberProc.send({ + event: "start", + url: connectionType === ConnectionType.TLS ? TLS_REDIS_URL : DEFAULT_REDIS_URL, + tlsPaths: connectionType === ConnectionType.TLS ? 
TLS_REDIS_OPTIONS.tlsPaths : undefined, + } as RedisTestStartMessage); + + try { + await stepCounter.untilValue(STEP_SUBSCRIBED); + expect(currentMessage.event).toBe("ready"); + + // Send multiple messages + expect(await ctx.redis.publish(channel, "message1")).toBeGreaterThanOrEqual(1); + await stepCounter.untilValue(STEP_FIRST_MESSAGE); + expect(currentMessage.event).toBe("message"); + expect(currentMessage.index).toBe(1); + + // Now, the subscriber process will crash + expect(await ctx.redis.publish(channel, "message2")).toBeGreaterThanOrEqual(1); + await stepCounter.untilValue(STEP_SECOND_MESSAGE); + expect(currentMessage.event).toBe("exception"); + //expect(currentMessage.index).toBe(2); + + // But it should recover and continue receiving messages + expect(await ctx.redis.publish(channel, "message3")).toBeGreaterThanOrEqual(1); + await stepCounter.untilValue(STEP_THIRD_MESSAGE); + expect(currentMessage.event).toBe("message"); + expect(currentMessage.index).toBe(3); + } finally { + subscriberProc.kill(); + await subscriberProc.exited; + } + }); + + test("subscriptions return correct counts", async () => { + const subscriber = createClient(connectionType); + await subscriber.connect(); + + expect(await subscriber.subscribe("chan1", () => {})).toBe(1); + expect(await subscriber.subscribe("chan2", () => {})).toBe(2); + }); + + test("unsubscribing from listeners", async () => { + const channel = "error-callback-channel"; + + const subscriber = createClient(connectionType); + await subscriber.connect(); + + // First phase: both listeners should receive 1 message each (2 total) + const counter = awaitableCounter(); + let messageCount1 = 0; + const listener1 = () => { + messageCount1++; + counter.increment(); + }; + await subscriber.subscribe(channel, listener1); + + let messageCount2 = 0; + const listener2 = () => { + messageCount2++; + counter.increment(); + }; + await subscriber.subscribe(channel, listener2); + + await ctx.redis.publish(channel, "message1"); + await counter.untilValue(2); + + expect(messageCount1).toBe(1); + expect(messageCount2).toBe(1); + + console.log("Unsubscribing listener2"); + await subscriber.unsubscribe(channel, listener2); + + await ctx.redis.publish(channel, "message1"); + await counter.untilValue(3); + + expect(messageCount1).toBe(2); + expect(messageCount2).toBe(1); + }); }); - test("subscriptions return correct counts", async () => { - const subscriber = createClient(ConnectionType.TCP); - await subscriber.connect(); + describe("duplicate()", () => { + test("should create duplicate of connected client that gets connected", async () => { + const duplicate = await ctx.redis.duplicate(); - expect(await subscriber.subscribe("chan1", () => {})).toBe(1); - expect(await subscriber.subscribe("chan2", () => {})).toBe(2); - }); + expect(duplicate.connected).toBe(true); + expect(duplicate).not.toBe(ctx.redis); - test("unsubscribing from listeners", async () => { - const channel = "error-callback-channel"; + // Both should work independently + await ctx.redis.set("test-original", "original-value"); + await duplicate.set("test-duplicate", "duplicate-value"); - const subscriber = createClient(ConnectionType.TCP); - await subscriber.connect(); + expect(await ctx.redis.get("test-duplicate")).toBe("duplicate-value"); + expect(await duplicate.get("test-original")).toBe("original-value"); - // First phase: both listeners should receive 1 message each (2 total) - const counter = awaitableCounter(); - let messageCount1 = 0; - const listener1 = () => { - messageCount1++; - 
counter.increment(); - }; - await subscriber.subscribe(channel, listener1); + duplicate.close(); + }); - let messageCount2 = 0; - const listener2 = () => { - messageCount2++; - counter.increment(); - }; - await subscriber.subscribe(channel, listener2); + test("should preserve connection configuration in duplicate", async () => { + await ctx.redis.connect(); - await ctx.redis.publish(channel, "message1"); - await counter.untilValue(2); + const duplicate = await ctx.redis.duplicate(); - expect(messageCount1).toBe(1); - expect(messageCount2).toBe(1); + // Both clients should be able to perform the same operations + const testKey = `duplicate-config-test-${randomUUIDv7().substring(0, 8)}`; + const testValue = "test-value"; - console.log("Unsubscribing listener2"); - await subscriber.unsubscribe(channel, listener2); + await ctx.redis.set(testKey, testValue); + const retrievedValue = await duplicate.get(testKey); - await ctx.redis.publish(channel, "message1"); - await counter.untilValue(3); + expect(retrievedValue).toBe(testValue); - expect(messageCount1).toBe(2); - expect(messageCount2).toBe(1); + duplicate.close(); + }); + + test("should allow duplicate to work independently from original", async () => { + const duplicate = await ctx.redis.duplicate(); + + // Close original, duplicate should still work + duplicate.close(); + + const testKey = `independent-test-${randomUUIDv7().substring(0, 8)}`; + const testValue = "independent-value"; + + await ctx.redis.set(testKey, testValue); + const retrievedValue = await ctx.redis.get(testKey); + + expect(retrievedValue).toBe(testValue); + }); + + test("should handle duplicate of client in subscriber mode", async () => { + const subscriber = await ctx.newSubscriberClient(connectionType); + + const testChannel = "test-subscriber-duplicate"; + + // Put original client in subscriber mode + await subscriber.subscribe(testChannel, () => {}); + + const duplicate = await subscriber.duplicate(); + + // Duplicate should not be in subscriber mode + expect(() => duplicate.set("test-key", "test-value")).not.toThrow(); + + await subscriber.unsubscribe(testChannel); + }); + + test("should create multiple duplicates from same client", async () => { + await ctx.redis.connect(); + + const duplicate1 = await ctx.redis.duplicate(); + const duplicate2 = await ctx.redis.duplicate(); + const duplicate3 = await ctx.redis.duplicate(); + + // All should be connected + expect(duplicate1.connected).toBe(true); + expect(duplicate2.connected).toBe(true); + expect(duplicate3.connected).toBe(true); + + // All should work independently + const testKey = `multi-duplicate-test-${randomUUIDv7().substring(0, 8)}`; + await duplicate1.set(`${testKey}-1`, "value-1"); + await duplicate2.set(`${testKey}-2`, "value-2"); + await duplicate3.set(`${testKey}-3`, "value-3"); + + expect(await duplicate1.get(`${testKey}-1`)).toBe("value-1"); + expect(await duplicate2.get(`${testKey}-2`)).toBe("value-2"); + expect(await duplicate3.get(`${testKey}-3`)).toBe("value-3"); + + // Cross-check: each duplicate can read what others wrote + expect(await duplicate1.get(`${testKey}-2`)).toBe("value-2"); + expect(await duplicate2.get(`${testKey}-3`)).toBe("value-3"); + expect(await duplicate3.get(`${testKey}-1`)).toBe("value-1"); + + duplicate1.close(); + duplicate2.close(); + duplicate3.close(); + }); + + test("should duplicate client that failed to connect", async () => { + // Create client with invalid credentials to force connection failure + const url = new URL(connectionType === ConnectionType.TLS ? 
TLS_REDIS_URL : DEFAULT_REDIS_URL); + url.username = "invaliduser"; + url.password = "invalidpassword"; + const failedRedis = new RedisClient(url.toString(), { + tls: connectionType === ConnectionType.TLS ? TLS_REDIS_OPTIONS.tls : false, + }); + + // Try to connect and expect it to fail + let connectionFailed = false; + try { + await failedRedis.connect(); + } catch { + connectionFailed = true; + } + + expect(connectionFailed).toBe(true); + expect(failedRedis.connected).toBe(false); + + // Duplicate should also remain unconnected + const duplicate = await failedRedis.duplicate(); + expect(duplicate.connected).toBe(false); + }); + + test("should handle duplicate timing with concurrent operations", async () => { + await ctx.redis.connect(); + + // Start some operations on the original client + const testKey = `concurrent-test-${randomUUIDv7().substring(0, 8)}`; + const originalOperation = ctx.redis.set(testKey, "original-value"); + + // Create duplicate while operation is in flight + const duplicate = await ctx.redis.duplicate(); + + // Wait for original operation to complete + await originalOperation; + + // Duplicate should be able to read the value + expect(await duplicate.get(testKey)).toBe("original-value"); + + duplicate.close(); + }); }); }); - - describe("duplicate()", () => { - test("should create duplicate of connected client that gets connected", async () => { - const duplicate = await ctx.redis.duplicate(); - - expect(duplicate.connected).toBe(true); - expect(duplicate).not.toBe(ctx.redis); - - // Both should work independently - await ctx.redis.set("test-original", "original-value"); - await duplicate.set("test-duplicate", "duplicate-value"); - - expect(await ctx.redis.get("test-duplicate")).toBe("duplicate-value"); - expect(await duplicate.get("test-original")).toBe("original-value"); - - duplicate.close(); - }); - - test("should preserve connection configuration in duplicate", async () => { - await ctx.redis.connect(); - - const duplicate = await ctx.redis.duplicate(); - - // Both clients should be able to perform the same operations - const testKey = `duplicate-config-test-${randomUUIDv7().substring(0, 8)}`; - const testValue = "test-value"; - - await ctx.redis.set(testKey, testValue); - const retrievedValue = await duplicate.get(testKey); - - expect(retrievedValue).toBe(testValue); - - duplicate.close(); - }); - - test("should allow duplicate to work independently from original", async () => { - const duplicate = await ctx.redis.duplicate(); - - // Close original, duplicate should still work - duplicate.close(); - - const testKey = `independent-test-${randomUUIDv7().substring(0, 8)}`; - const testValue = "independent-value"; - - await ctx.redis.set(testKey, testValue); - const retrievedValue = await ctx.redis.get(testKey); - - expect(retrievedValue).toBe(testValue); - }); - - test("should handle duplicate of client in subscriber mode", async () => { - const subscriber = await ctx.newSubscriberClient(ConnectionType.TCP); - - const testChannel = "test-subscriber-duplicate"; - - // Put original client in subscriber mode - await subscriber.subscribe(testChannel, () => {}); - - const duplicate = await subscriber.duplicate(); - - // Duplicate should not be in subscriber mode - expect(() => duplicate.set("test-key", "test-value")).not.toThrow(); - - await subscriber.unsubscribe(testChannel); - }); - - test("should create multiple duplicates from same client", async () => { - await ctx.redis.connect(); - - const duplicate1 = await ctx.redis.duplicate(); - const duplicate2 = await 
ctx.redis.duplicate(); - const duplicate3 = await ctx.redis.duplicate(); - - // All should be connected - expect(duplicate1.connected).toBe(true); - expect(duplicate2.connected).toBe(true); - expect(duplicate3.connected).toBe(true); - - // All should work independently - const testKey = `multi-duplicate-test-${randomUUIDv7().substring(0, 8)}`; - await duplicate1.set(`${testKey}-1`, "value-1"); - await duplicate2.set(`${testKey}-2`, "value-2"); - await duplicate3.set(`${testKey}-3`, "value-3"); - - expect(await duplicate1.get(`${testKey}-1`)).toBe("value-1"); - expect(await duplicate2.get(`${testKey}-2`)).toBe("value-2"); - expect(await duplicate3.get(`${testKey}-3`)).toBe("value-3"); - - // Cross-check: each duplicate can read what others wrote - expect(await duplicate1.get(`${testKey}-2`)).toBe("value-2"); - expect(await duplicate2.get(`${testKey}-3`)).toBe("value-3"); - expect(await duplicate3.get(`${testKey}-1`)).toBe("value-1"); - - duplicate1.close(); - duplicate2.close(); - duplicate3.close(); - }); - - test("should duplicate client that failed to connect", async () => { - // Create client with invalid credentials to force connection failure - const url = new URL(DEFAULT_REDIS_URL); - url.username = "invaliduser"; - url.password = "invalidpassword"; - const failedRedis = new RedisClient(url.toString()); - - // Try to connect and expect it to fail - let connectionFailed = false; - try { - await failedRedis.connect(); - } catch { - connectionFailed = true; - } - - expect(connectionFailed).toBe(true); - expect(failedRedis.connected).toBe(false); - - // Duplicate should also remain unconnected - const duplicate = await failedRedis.duplicate(); - expect(duplicate.connected).toBe(false); - }); - - test("should handle duplicate timing with concurrent operations", async () => { - await ctx.redis.connect(); - - // Start some operations on the original client - const testKey = `concurrent-test-${randomUUIDv7().substring(0, 8)}`; - const originalOperation = ctx.redis.set(testKey, "original-value"); - - // Create duplicate while operation is in flight - const duplicate = await ctx.redis.duplicate(); - - // Wait for original operation to complete - await originalOperation; - - // Duplicate should be able to read the value - expect(await duplicate.get(testKey)).toBe("original-value"); - - duplicate.close(); - }); - }); -}); +} From 733e7f6165ea037f950aa58cb825d1e0538905ee Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 26 Sep 2025 19:01:01 -0700 Subject: [PATCH 36/43] Fix fetch-preconnect test failure (#23016) ### What does this PR do? ### How did you verify your code works? 
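The keep-alive reuse path in `HTTPContext.zig` previously re-registered a pooled socket in `pending_sockets` before checking whether it was closed or shut down, so a dead socket could be put back in the pending set before being discarded; the assert now runs only after a socket passes those checks. The test is updated to spawn `bun --fetch-preconnect=...` directly via `Bun.spawn` (using `bunExe()` and `bunEnv` from `harness`) and assert a zero exit code instead of relying on `.toRun()`, and the suite now runs under `describe.concurrent`.

For context, a minimal sketch of the behavior the test exercises (hypothetical server and port; `fetch.preconnect` is the API under test):

```js
// Open the connection ahead of time so the first fetch() can reuse
// the pooled keep-alive socket instead of paying the connection setup cost.
fetch.preconnect("http://localhost:3000");
const res = await fetch("http://localhost:3000/");
console.log(res.status);
```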
--- src/http/HTTPContext.zig | 2 +- test/js/web/fetch/fetch-preconnect.test.ts | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/http/HTTPContext.zig b/src/http/HTTPContext.zig index 8cf733037e..a34020012e 100644 --- a/src/http/HTTPContext.zig +++ b/src/http/HTTPContext.zig @@ -412,7 +412,6 @@ pub fn NewHTTPContext(comptime ssl: bool) type { if (strings.eqlLong(socket.hostname_buf[0..socket.hostname_len], hostname, true)) { const http_socket = socket.http_socket; - assert(context().pending_sockets.put(socket)); if (http_socket.isClosed()) { markSocketAsDead(http_socket); @@ -424,6 +423,7 @@ pub fn NewHTTPContext(comptime ssl: bool) type { continue; } + assert(context().pending_sockets.put(socket)); log("+ Keep-Alive reuse {s}:{d}", .{ hostname, port }); return http_socket; } diff --git a/test/js/web/fetch/fetch-preconnect.test.ts b/test/js/web/fetch/fetch-preconnect.test.ts index 27f8ae9aea..a841460c80 100644 --- a/test/js/web/fetch/fetch-preconnect.test.ts +++ b/test/js/web/fetch/fetch-preconnect.test.ts @@ -1,12 +1,12 @@ import { describe, expect, it } from "bun:test"; import "harness"; -import { isWindows } from "harness"; +import { bunEnv, bunExe, isWindows } from "harness"; // TODO: on Windows, these tests fail. // This feature is mostly meant for serverless JS environments, so we can no-op it on Windows. -describe.todoIf(isWindows)("fetch.preconnect", () => { +describe.concurrent.todoIf(isWindows)("fetch.preconnect", () => { it("fetch.preconnect works", async () => { - const { promise, resolve } = Promise.withResolvers(); + const { promise, resolve } = Promise.withResolvers(); using listener = Bun.listen({ port: 0, hostname: "localhost", @@ -29,12 +29,12 @@ describe.todoIf(isWindows)("fetch.preconnect", () => { expect(response.status).toBe(200); }); - describe("doesn't break the request when", () => { + describe.concurrent("doesn't break the request when", () => { for (let endOrTerminate of ["end", "terminate", "shutdown"]) { describe(endOrTerminate, () => { for (let at of ["before", "middle", "after"]) { it(at, async () => { - let { promise, resolve } = Promise.withResolvers(); + let { promise, resolve } = Promise.withResolvers(); using listener = Bun.listen({ port: 0, hostname: "localhost", @@ -48,7 +48,7 @@ describe.todoIf(isWindows)("fetch.preconnect", () => { }); fetch.preconnect(`http://localhost:${listener.port}`); let socket = await promise; - ({ promise, resolve } = Promise.withResolvers()); + ({ promise, resolve } = Promise.withResolvers()); if (at === "before") { await Bun.sleep(16); socket[endOrTerminate](); @@ -86,7 +86,7 @@ describe.todoIf(isWindows)("fetch.preconnect", () => { }); it("--fetch-preconnect works", async () => { - const { promise, resolve } = Promise.withResolvers(); + const { promise, resolve } = Promise.withResolvers(); using listener = Bun.listen({ port: 0, hostname: "localhost", @@ -102,7 +102,13 @@ describe.todoIf(isWindows)("fetch.preconnect", () => { }); // Do --fetch-preconnect, but don't actually send a request. 
- expect([`--fetch-preconnect=http://localhost:${listener.port}`, "--eval", "Bun.sleep(64)"]).toRun(); + await using proc = Bun.spawn({ + cmd: [bunExe(), `--fetch-preconnect=http://localhost:${listener.port}`, "--eval", "Bun.sleep(64)"], + stdio: ["inherit", "inherit", "inherit"], + env: bunEnv, + }); + + expect(await proc.exited).toBe(0); await promise; }); From da0babebd2988e42df00300c4439250fd41f4b62 Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Fri, 26 Sep 2025 18:20:44 -0800 Subject: [PATCH 37/43] node:http2: fix leak in H2FrameParser (#22997) --- src/bun.js/api/bun/h2_frame_parser.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bun.js/api/bun/h2_frame_parser.zig b/src/bun.js/api/bun/h2_frame_parser.zig index 8f032d36e0..4e7f86db3b 100644 --- a/src/bun.js/api/bun/h2_frame_parser.zig +++ b/src/bun.js/api/bun/h2_frame_parser.zig @@ -4270,8 +4270,8 @@ pub const H2FrameParser = struct { } pub fn detachNativeSocket(this: *H2FrameParser) void { - this.native_socket = .{ .none = {} }; const native_socket = this.native_socket; + this.native_socket = .{ .none = {} }; switch (native_socket) { inline .tcp, .tls => |socket| { From 665ea96076b90680e95feb7dba66ef7dd182ce87 Mon Sep 17 00:00:00 2001 From: Jarred Sumner Date: Fri, 26 Sep 2025 19:22:02 -0700 Subject: [PATCH 38/43] Deflake test/js/bun/http/bun-server.test.ts --- test/js/bun/http/bun-server.test.ts | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/js/bun/http/bun-server.test.ts b/test/js/bun/http/bun-server.test.ts index a5ff27bceb..33c108b00a 100644 --- a/test/js/bun/http/bun-server.test.ts +++ b/test/js/bun/http/bun-server.test.ts @@ -201,7 +201,9 @@ describe.concurrent("Server", () => { }); try { - await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }); + await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }).then(res => + res.text(), + ); } catch (err: any) { expect(err).toBeDefined(); expect(err?.name).toBe("AbortError"); @@ -230,7 +232,9 @@ describe.concurrent("Server", () => { }); try { - await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }); + await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }).then(res => + res.text(), + ); } catch { fetchAborted = true; } @@ -279,7 +283,9 @@ describe.concurrent("Server", () => { }); try { - await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }); + await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }).then(res => + res.text(), + ); } catch {} await Bun.sleep(10); expect(signalOnServer).toBe(true); @@ -375,7 +381,9 @@ describe.concurrent("Server", () => { }); try { - await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }); + await fetch(`http://${server.hostname}:${server.port}`, { signal: abortController.signal }).then(res => + res.text(), + ); } catch {} await Bun.sleep(10); expect(signalOnServer).toBe(true); From ed72eff2a9984f6c37696803889ec6b3d983f4ef Mon Sep 17 00:00:00 2001 From: "taylor.fish" Date: Fri, 26 Sep 2025 19:46:50 -0700 Subject: [PATCH 39/43] `bun.ptr.Shared`: fix dependency loop and add leak/adopt methods (#23024) Fixes STAB-1292 --- src/ptr/external_shared.zig | 7 ++++-- src/ptr/owned.zig | 12 +++++------ src/ptr/shared.zig | 43 +++++++++++++++++++++++++------------ 3 files changed, 39 insertions(+), 23 deletions(-) diff --git 
a/src/ptr/external_shared.zig b/src/ptr/external_shared.zig index de59780b7d..bcc62e07ba 100644 --- a/src/ptr/external_shared.zig +++ b/src/ptr/external_shared.zig @@ -7,11 +7,14 @@ /// pub fn deref(T*) void; /// }; pub fn ExternalShared(comptime T: type) type { - _ = T.external_shared_descriptor.ref; // must define a `ref` function - _ = T.external_shared_descriptor.deref; // must define a `deref` function return struct { const Self = @This(); + comptime { + _ = T.external_shared_descriptor.ref; // must define a `ref` function + _ = T.external_shared_descriptor.deref; // must define a `deref` function + } + #impl: *T, /// `incremented_raw` should have already had its ref count incremented by 1. diff --git a/src/ptr/owned.zig b/src/ptr/owned.zig index 3dd2b36d2c..aa99ae7def 100644 --- a/src/ptr/owned.zig +++ b/src/ptr/owned.zig @@ -34,13 +34,12 @@ pub fn Dynamic(comptime Pointer: type) type { /// If `Allocator` is a zero-sized type, the owned pointer has no overhead compared to a raw /// pointer. pub fn OwnedIn(comptime Pointer: type, comptime Allocator: type) type { - const info = PointerInfo.parse(Pointer, .{}); - const NonOptionalPointer = info.NonOptionalPointer; - const Child = info.Child; - const ConstPointer = AddConst(Pointer); - return struct { const Self = @This(); + const info = PointerInfo.parse(Pointer, .{}); + const NonOptionalPointer = info.NonOptionalPointer; + const Child = info.Child; + const ConstPointer = AddConst(Pointer); #pointer: Pointer, #allocator: Allocator, @@ -381,10 +380,9 @@ pub fn OwnedIn(comptime Pointer: type, comptime Allocator: type) type { /// /// This type is accessible as `OwnedIn(Pointer, Allocator).Unmanaged`. fn Unmanaged(comptime Pointer: type, comptime Allocator: type) type { - const info = PointerInfo.parse(Pointer, .{}); - return struct { const Self = @This(); + const info = PointerInfo.parse(Pointer, .{}); #pointer: Pointer, diff --git a/src/ptr/shared.zig b/src/ptr/shared.zig index 208922a981..f9ff6b6951 100644 --- a/src/ptr/shared.zig +++ b/src/ptr/shared.zig @@ -61,12 +61,6 @@ pub fn AtomicSharedIn(comptime Pointer: type, comptime Allocator: type) type { /// Like `Shared`, but takes explicit options. pub fn WithOptions(comptime Pointer: type, comptime options: Options) type { - const Allocator = options.Allocator; - const info = parsePointer(Pointer); - const Child = info.Child; - const NonOptionalPointer = info.NonOptionalPointer; - const Data = FullData(Child, options); - if (options.allow_weak) { // Weak pointers only make sense if `deinit` will be called, since their only function // is to ensure `deinit` can be called before the memory is freed (weak pointers keep @@ -84,6 +78,11 @@ pub fn WithOptions(comptime Pointer: type, comptime options: Options) type { return struct { const Self = @This(); + const Allocator = options.Allocator; + const info = parsePointer(Pointer); + const Child = info.Child; + const NonOptionalPointer = info.NonOptionalPointer; + const Data = FullData(Child, options); #pointer: Pointer, @@ -239,11 +238,28 @@ pub fn WithOptions(comptime Pointer: type, comptime options: Options) type { return .fromValuePtr(self.#pointer); } + /// Turns a shared pointer into a raw pointer without decrementing the reference count. + /// + /// This method invalidates `self`. To avoid leaks, the raw pointer should be turned back + /// into a shared pointer with `adoptRawUnsafe`. 
+ pub fn leak(self: *Self) Pointer { + defer self.* = undefined; + return self.#pointer; + } + + /// Creates a shared pointer from a raw pointer returned by `leak`. + /// + /// `pointer` must have been previously returned by `leak`. `adoptRawUnsafe` should not be + /// called again on this pointer. + pub fn adoptRawUnsafe(pointer: Pointer) Self { + return .{ .#pointer = pointer }; + } + /// Clones a shared pointer, given a raw pointer that originally came from a shared pointer. /// - /// `pointer` must have come from a shared pointer (e.g., from `get` or `leak`), and the shared - /// pointer from which it came must remain valid (i.e., not be deinitialized) at least until - /// this function returns. + /// `pointer` must have come from a shared pointer, and the shared pointer from which it + /// came must remain valid (i.e., not be deinitialized) at least until this function + /// returns. pub fn cloneFromRawUnsafe(pointer: Pointer) Self { const temp: Self = .{ .#pointer = pointer }; return temp.clone(); @@ -252,11 +268,6 @@ pub fn WithOptions(comptime Pointer: type, comptime options: Options) type { } fn Weak(comptime Pointer: type, comptime options: Options) type { - const info = parsePointer(Pointer); - const Child = info.Child; - const NonOptionalPointer = info.NonOptionalPointer; - const Data = FullData(Child, options); - bun.assertf( options.allow_weak and options.deinit, "options incompatible with shared.Weak", @@ -265,6 +276,10 @@ fn Weak(comptime Pointer: type, comptime options: Options) type { return struct { const Self = @This(); + const info = parsePointer(Pointer); + const Child = info.Child; + const NonOptionalPointer = info.NonOptionalPointer; + const Data = FullData(Child, options); #pointer: Pointer, From 8102e80f8887987e64f26ffa2a726cd00f986bf8 Mon Sep 17 00:00:00 2001 From: Dylan Conway Date: Fri, 26 Sep 2025 22:21:00 -0700 Subject: [PATCH 40/43] fix(build): Promise.all() async module dependencies (#22704) ### What does this PR do? Currently bundling and running projects with cyclic async module dependencies will hang due to module promises never resolving. This PR unblocks these projects by outputting `await Promise.all` with these dependencies. Before (will hang with bun, or error with unsettled top level await with node): ```js var __esm = (fn, res) => () => (fn && (res = fn((fn = 0))), res); var init_mod3 = __esm(async () => { await init_mod1(); }); var init_mod2 = __esm(async () => { await init_mod1(); }); var init_mod1 = __esm(async () => { await init_mod2(); await init_mod3(); }); await init_mod1(); ``` After: ```js var __esm = (fn, res) => () => (fn && (res = fn((fn = 0))), res); var __promiseAll = Promise.all.bind(Promise); var init_mod3 = __esm(async () => { await init_mod1(); }); var init_mod2 = __esm(async () => { await init_mod1(); }); var init_mod1 = __esm(async () => { await __promiseAll([init_mod2(), init_mod3()]); }); await init_mod1(); ``` ### How did you verify your code works? 
Manually and tests --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jarred Sumner --- src/bundler/LinkerContext.zig | 264 +++++++-- src/bundler/LinkerGraph.zig | 8 +- src/bundler/bundle_v2.zig | 5 + src/bundler/linker_context/computeChunks.zig | 5 +- .../computeCrossChunkDependencies.zig | 2 +- .../linker_context/convertStmtsForChunk.zig | 18 +- .../convertStmtsForChunkForDevServer.zig | 10 +- .../generateCodeForFileInChunkJS.zig | 38 +- .../generateCodeForLazyExport.zig | 14 +- .../linker_context/scanImportsAndExports.zig | 54 +- src/defines-table.zig | 4 + src/runtime.js | 2 + src/runtime.zig | 2 + .../bundler_promiseall_deadcode.test.ts | 530 ++++++++++++++++++ .../cyclic-imports-async-bundler.test.js | 91 +++ 15 files changed, 917 insertions(+), 130 deletions(-) create mode 100644 test/bundler/bundler_promiseall_deadcode.test.ts diff --git a/src/bundler/LinkerContext.zig b/src/bundler/LinkerContext.zig index e05da53777..364fbf5149 100644 --- a/src/bundler/LinkerContext.zig +++ b/src/bundler/LinkerContext.zig @@ -19,6 +19,9 @@ pub const LinkerContext = struct { /// We may need to refer to the CommonJS "module" symbol for exports unbound_module_ref: Ref = Ref.None, + /// We may need to refer to the "__promiseAll" runtime symbol + promise_all_runtime_ref: Ref = Ref.None, + options: LinkerOptions = .{}, loop: EventLoop, @@ -217,6 +220,7 @@ pub const LinkerContext = struct { this.esm_runtime_ref = runtime_named_exports.get("__esm").?.ref; this.cjs_runtime_ref = runtime_named_exports.get("__commonJS").?.ref; + this.promise_all_runtime_ref = runtime_named_exports.get("__promiseAll").?.ref; if (this.options.output_format == .cjs) { this.unbound_module_ref = this.graph.generateNewSymbol(Index.runtime.get(), .unbound, "module"); @@ -336,13 +340,18 @@ pub const LinkerContext = struct { } } + const LinkError = OOM || error{ + BuildFailed, + ImportResolutionFailed, + }; + pub noinline fn link( this: *LinkerContext, bundle: *BundleV2, entry_points: []Index, server_component_boundaries: ServerComponentBoundary.List, reachable: []Index, - ) ![]Chunk { + ) LinkError![]Chunk { try this.load( bundle, entry_points, @@ -360,10 +369,8 @@ pub const LinkerContext = struct { this.checkForMemoryCorruption(); } - // Validate top-level await for all files first, like esbuild does. - // This must be done before scanImportsAndExports to ensure async - // status is properly propagated through cyclic imports. - { + // Validate top-level await for all files first. + if (bundle.has_any_top_level_await_modules) { const import_records_list: []ImportRecord.List = this.graph.ast.items(.import_records); const tla_keywords = this.parse_graph.ast.items(.top_level_await_keyword); const tla_checks = this.parse_graph.ast.items(.tla_check); @@ -388,9 +395,11 @@ pub const LinkerContext = struct { _ = try this.validateTLA(source_index, tla_keywords, tla_checks, input_files, import_records, flags, import_records_list); } + + // after validation propagate async through all importers. 
+ try this.graph.propagateAsyncDependencies(); } - try this.graph.propagateAsyncDependencies(); try this.scanImportsAndExports(); // Stop now if there were errors @@ -410,6 +419,10 @@ pub const LinkerContext = struct { const chunks = try this.computeChunks(bundle.unique_key); + if (this.log.hasErrors()) { + return error.BuildFailed; + } + if (comptime FeatureFlags.help_catch_memory_issues) { this.checkForMemoryCorruption(); } @@ -438,32 +451,32 @@ pub const LinkerContext = struct { pub const findImportedFilesInCSSOrder = @import("./linker_context/findImportedFilesInCSSOrder.zig").findImportedFilesInCSSOrder; pub const findImportedCSSFilesInJSOrder = @import("./linker_context/findImportedCSSFilesInJSOrder.zig").findImportedCSSFilesInJSOrder; - pub fn generateNamedExportInFile(this: *LinkerContext, source_index: Index.Int, module_ref: Ref, name: []const u8, alias: []const u8) !struct { Ref, u32 } { + pub fn generateNamedExportInFile(this: *LinkerContext, source_index: Index.Int, module_ref: Ref, name: []const u8, alias: []const u8) bun.OOM!struct { Ref, u32 } { const ref = this.graph.generateNewSymbol(source_index, .other, name); - const part_index = this.graph.addPartToFile(source_index, .{ - .declared_symbols = js_ast.DeclaredSymbol.List.fromSlice( + const part_index = try this.graph.addPartToFile(source_index, .{ + .declared_symbols = try js_ast.DeclaredSymbol.List.fromSlice( this.allocator(), &[_]js_ast.DeclaredSymbol{ .{ .ref = ref, .is_top_level = true }, }, - ) catch unreachable, + ), .can_be_removed_if_unused = true, - }) catch unreachable; + }); try this.graph.generateSymbolImportAndUse(source_index, part_index, module_ref, 1, Index.init(source_index)); var top_level = &this.graph.meta.items(.top_level_symbol_to_parts_overlay)[source_index]; - var parts_list = this.allocator().alloc(u32, 1) catch unreachable; + var parts_list = try this.allocator().alloc(u32, 1); parts_list[0] = part_index; - top_level.put(this.allocator(), ref, BabyList(u32).fromOwnedSlice(parts_list)) catch unreachable; + try top_level.put(this.allocator(), ref, BabyList(u32).fromOwnedSlice(parts_list)); var resolved_exports = &this.graph.meta.items(.resolved_exports)[source_index]; - resolved_exports.put(this.allocator(), alias, ExportData{ + try resolved_exports.put(this.allocator(), alias, ExportData{ .data = ImportTracker{ .source_index = Index.init(source_index), .import_ref = ref, }, - }) catch unreachable; + }); return .{ ref, part_index }; } @@ -545,7 +558,7 @@ pub const LinkerContext = struct { return &c.parse_graph.input_files.items(.source)[index]; } - pub fn treeShakingAndCodeSplitting(c: *LinkerContext) !void { + pub fn treeShakingAndCodeSplitting(c: *LinkerContext) OOM!void { const trace = bun.perf.trace("Bundler.treeShakingAndCodeSplitting"); defer trace.end(); @@ -932,7 +945,7 @@ pub const LinkerContext = struct { import_records: []const ImportRecord, meta_flags: []JSMeta.Flags, ast_import_records: []const bun.BabyList(ImportRecord), - ) bun.OOM!js_ast.TlaCheck { + ) OOM!js_ast.TlaCheck { var result_tla_check: *js_ast.TlaCheck = &tla_checks[source_index]; if (result_tla_check.depth == 0) { @@ -1029,32 +1042,129 @@ pub const LinkerContext = struct { } pub const StmtList = struct { - inside_wrapper_prefix: std.ArrayList(Stmt), - outside_wrapper_prefix: std.ArrayList(Stmt), - inside_wrapper_suffix: std.ArrayList(Stmt), + const InsideWrapperPrefix = struct { + allocator: std.mem.Allocator, + stmts: std.ArrayListUnmanaged(Stmt), - all_stmts: std.ArrayList(Stmt), + sync_dependencies_end: usize, + + // if 
true it will exist at `sync_dependencies_end` + has_async_dependency: bool, + + pub fn init(alloc: std.mem.Allocator) InsideWrapperPrefix { + return .{ .stmts = .{}, .allocator = alloc, .sync_dependencies_end = 0, .has_async_dependency = false }; + } + + pub fn deinit(this: *InsideWrapperPrefix) void { + this.stmts.deinit(this.allocator); + this.sync_dependencies_end = 0; + this.has_async_dependency = false; + } + + pub fn reset(this: *InsideWrapperPrefix) void { + this.stmts.clearRetainingCapacity(); + this.sync_dependencies_end = 0; + this.has_async_dependency = false; + } + + pub fn appendNonDependency(this: *InsideWrapperPrefix, stmt: Stmt) OOM!void { + try this.stmts.append(this.allocator, stmt); + } + + pub fn appendNonDependencySlice(this: *InsideWrapperPrefix, stmts: []const Stmt) OOM!void { + try this.stmts.appendSlice(this.allocator, stmts); + } + + pub fn appendSyncDependency(this: *InsideWrapperPrefix, call_expr: Expr) OOM!void { + try this.stmts.insert(this.allocator, this.sync_dependencies_end, Stmt.alloc(S.SExpr, .{ .value = call_expr }, call_expr.loc)); + this.sync_dependencies_end += 1; + } + + pub fn appendAsyncDependency(this: *InsideWrapperPrefix, call_expr: Expr, promise_all_ref: Ref) OOM!void { + if (!this.has_async_dependency) { + this.has_async_dependency = true; + try this.stmts.insert( + this.allocator, + this.sync_dependencies_end, + Stmt.alloc(S.SExpr, .{ .value = Expr.init(E.Await, .{ .value = call_expr }, .Empty) }, .Empty), + ); + return; + } + + const first_dep_call_expr = this.stmts.items[this.sync_dependencies_end].data.s_expr.value.data.e_await.value; + const call = first_dep_call_expr.data.e_call; + + if (call.target.data.e_identifier.ref.eql(promise_all_ref)) { + // `await __promiseAll` already in place, append to the array argument + try call.args.at(0).data.e_array.items.append(this.allocator, call_expr); + } else { + // convert single `await init_` to `await __promiseAll([init_1(), init_2()])` + + const promise_all = Expr.init(E.Identifier, .{ .ref = promise_all_ref }, .Empty); + + var items: BabyList(Expr) = try .initCapacity(this.allocator, 2); + items.appendSliceAssumeCapacity(&.{ first_dep_call_expr, call_expr }); + + var args: BabyList(Expr) = try .initCapacity(this.allocator, 1); + args.appendAssumeCapacity(Expr.init(E.Array, .{ .items = items }, .Empty)); + + const promise_all_call = Expr.init(E.Call, .{ .target = promise_all, .args = args }, .Empty); + + // replace the `await init_` expr with `await __promiseAll` + this.stmts.items[this.sync_dependencies_end] = Stmt.alloc(S.SExpr, .{ .value = Expr.init(E.Await, .{ .value = promise_all_call }, .Empty) }, .Empty); + } + } + }; + + allocator: std.mem.Allocator, + inside_wrapper_prefix: InsideWrapperPrefix, + outside_wrapper_prefix: std.ArrayListUnmanaged(Stmt), + inside_wrapper_suffix: std.ArrayListUnmanaged(Stmt), + all_stmts: std.ArrayListUnmanaged(Stmt), pub fn reset(this: *StmtList) void { - this.inside_wrapper_prefix.clearRetainingCapacity(); + this.inside_wrapper_prefix.reset(); this.outside_wrapper_prefix.clearRetainingCapacity(); this.inside_wrapper_suffix.clearRetainingCapacity(); this.all_stmts.clearRetainingCapacity(); } pub fn deinit(this: *StmtList) void { - this.inside_wrapper_prefix.deinit(); - this.outside_wrapper_prefix.deinit(); - this.inside_wrapper_suffix.deinit(); - this.all_stmts.deinit(); + this.inside_wrapper_prefix.deinit(this.allocator); + this.outside_wrapper_prefix.deinit(this.allocator); + this.inside_wrapper_suffix.deinit(this.allocator); + 
this.all_stmts.deinit(this.allocator); } pub fn init(alloc: std.mem.Allocator) StmtList { return .{ - .inside_wrapper_prefix = std.ArrayList(Stmt).init(alloc), - .outside_wrapper_prefix = std.ArrayList(Stmt).init(alloc), - .inside_wrapper_suffix = std.ArrayList(Stmt).init(alloc), - .all_stmts = std.ArrayList(Stmt).init(alloc), + .allocator = alloc, + .inside_wrapper_prefix = .init(alloc), + .outside_wrapper_prefix = .{}, + .inside_wrapper_suffix = .{}, + .all_stmts = .{}, + }; + } + + const List = enum { + outside_wrapper_prefix, + inside_wrapper_suffix, + all_stmts, + }; + + pub fn appendSlice(this: *StmtList, list: List, stmts: []const Stmt) OOM!void { + try switch (list) { + .outside_wrapper_prefix => this.outside_wrapper_prefix.appendSlice(this.allocator, stmts), + .inside_wrapper_suffix => this.inside_wrapper_suffix.appendSlice(this.allocator, stmts), + .all_stmts => this.all_stmts.appendSlice(this.allocator, stmts), + }; + } + + pub fn append(this: *StmtList, list: List, stmt: Stmt) OOM!void { + try switch (list) { + .outside_wrapper_prefix => this.outside_wrapper_prefix.append(this.allocator, stmt), + .inside_wrapper_suffix => this.inside_wrapper_suffix.append(this.allocator, stmt), + .all_stmts => this.all_stmts.append(this.allocator, stmt), }; } }; @@ -1077,7 +1187,7 @@ pub const LinkerContext = struct { } // Otherwise, replace this statement with a call to "require()" - stmts.inside_wrapper_prefix.append( + stmts.inside_wrapper_prefix.appendNonDependency( Stmt.alloc( S.Local, S.Local{ @@ -1120,7 +1230,7 @@ pub const LinkerContext = struct { .none => {}, .cjs => { // Replace the statement with a call to "require()" if this module is not wrapped - try stmts.inside_wrapper_prefix.append( + try stmts.inside_wrapper_prefix.appendNonDependency( Stmt.alloc(S.Local, .{ .decls = try G.Decl.List.fromSlice( alloc, @@ -1152,31 +1262,18 @@ pub const LinkerContext = struct { } // Replace the statement with a call to "init()" - const value: Expr = brk: { - const default = Expr.init(E.Call, .{ - .target = Expr.initIdentifier( - wrapper_ref, - loc, - ), - }, loc); + const init_call = Expr.init(E.Call, .{ + .target = Expr.initIdentifier( + wrapper_ref, + loc, + ), + }, loc); - if (other_flags.is_async_or_has_async_dependency) { - // This currently evaluates sibling dependencies in serial instead of in - // parallel, which is incorrect. This should be changed to store a promise - // and await all stored promises after all imports but before any code. - break :brk Expr.init(E.Await, .{ - .value = default, - }, loc); - } - - break :brk default; - }; - - try stmts.inside_wrapper_prefix.append( - Stmt.alloc(S.SExpr, .{ - .value = value, - }, loc), - ); + if (other_flags.is_async_or_has_async_dependency) { + try stmts.inside_wrapper_prefix.appendAsyncDependency(init_call, c.promise_all_runtime_ref); + } else { + try stmts.inside_wrapper_prefix.appendSyncDependency(init_call); + } }, } @@ -2104,18 +2201,53 @@ pub const LinkerContext = struct { // // This depends on the "__esm" symbol and declares the "init_foo" symbol // for similar reasons to the CommonJS closure above. 
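+                    // A sketch of the two shapes this init call lowers to (assumed,
+                    // mirroring the PR description above):
+                    //
+                    //   await init_a();                            // one async dependency
+                    //   await __promiseAll([init_a(), init_b()]);  // two or more
+                    //
+                    // so the "__promiseAll" helper is only pulled in once the count hits 2.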
+ + // Count async dependencies to determine if we need __promiseAll + var async_import_count: usize = 0; + const import_records = c.graph.ast.items(.import_records)[source_index].slice(); + const meta_flags = c.graph.meta.items(.flags); + + for (import_records) |record| { + if (!record.source_index.isValid()) { + continue; + } + const other_flags = meta_flags[record.source_index.get()]; + if (other_flags.is_async_or_has_async_dependency) { + async_import_count += 1; + if (async_import_count >= 2) { + break; + } + } + } + + const needs_promise_all = async_import_count >= 2; + const esm_parts = if (wrapper_ref.isValid() and c.options.output_format != .internal_bake_dev) c.topLevelSymbolsToPartsForRuntime(c.esm_runtime_ref) else &.{}; - // generate a dummy part that depends on the "__esm" symbol - const dependencies = c.allocator().alloc(js_ast.Dependency, esm_parts.len) catch unreachable; - for (esm_parts, dependencies) |part, *esm| { - esm.* = .{ + const promise_all_parts = if (needs_promise_all and wrapper_ref.isValid() and c.options.output_format != .internal_bake_dev) + c.topLevelSymbolsToPartsForRuntime(c.promise_all_runtime_ref) + else + &.{}; + + // generate a dummy part that depends on the "__esm" and optionally "__promiseAll" symbols + const dependencies = c.allocator().alloc(js_ast.Dependency, esm_parts.len + promise_all_parts.len) catch unreachable; + var dep_index: usize = 0; + for (esm_parts) |part| { + dependencies[dep_index] = .{ .part_index = part, .source_index = Index.runtime, }; + dep_index += 1; + } + for (promise_all_parts) |part| { + dependencies[dep_index] = .{ + .part_index = part, + .source_index = Index.runtime, + }; + dep_index += 1; } var symbol_uses: Part.SymbolUseMap = .empty; @@ -2140,6 +2272,17 @@ pub const LinkerContext = struct { 1, Index.runtime, ) catch |err| bun.handleOom(err); + + // Only mark __promiseAll as used if we have multiple async dependencies + if (needs_promise_all) { + c.graph.generateSymbolImportAndUse( + source_index, + part_index, + c.promise_all_runtime_ref, + 1, + Index.runtime, + ) catch |err| bun.handleOom(err); + } } }, else => {}, @@ -2538,6 +2681,7 @@ const FeatureFlags = bun.FeatureFlags; const ImportRecord = bun.ImportRecord; const MultiArrayList = bun.MultiArrayList; const MutableString = bun.MutableString; +const OOM = bun.OOM; const Output = bun.Output; const StringJoiner = bun.StringJoiner; const bake = bun.bake; diff --git a/src/bundler/LinkerGraph.zig b/src/bundler/LinkerGraph.zig index 0e436c0771..c160e96c28 100644 --- a/src/bundler/LinkerGraph.zig +++ b/src/bundler/LinkerGraph.zig @@ -78,7 +78,7 @@ pub fn generateRuntimeSymbolImportAndUse( entry_point_part_index: Index, name: []const u8, count: u32, -) !void { +) bun.OOM!void { if (count == 0) return; debug("generateRuntimeSymbolImportAndUse({s}) for {d}", .{ name, source_index }); @@ -96,7 +96,7 @@ pub fn addPartToFile( graph: *LinkerGraph, id: u32, part: Part, -) !u32 { +) bun.OOM!u32 { var parts: *Part.List = &graph.ast.items(.parts)[id]; const part_id = @as(u32, @truncate(parts.len)); try parts.append(graph.allocator, part); @@ -157,7 +157,7 @@ pub fn generateSymbolImportAndUse( ref: Ref, use_count: u32, source_index_to_import_from: Index, -) !void { +) bun.OOM!void { if (use_count == 0) return; var parts_list = g.ast.items(.parts)[source_index].slice(); @@ -166,7 +166,7 @@ pub fn generateSymbolImportAndUse( // Mark this symbol as used by this part var uses = &part.symbol_uses; - var uses_entry = uses.getOrPut(g.allocator, ref) catch unreachable; + var uses_entry = 
try uses.getOrPut(g.allocator, ref); if (!uses_entry.found_existing) { uses_entry.value_ptr.* = .{ .count_estimate = use_count }; diff --git a/src/bundler/bundle_v2.zig b/src/bundler/bundle_v2.zig index f93b2904d1..430884450b 100644 --- a/src/bundler/bundle_v2.zig +++ b/src/bundler/bundle_v2.zig @@ -145,6 +145,9 @@ pub const BundleV2 = struct { asynchronous: bool = false, thread_lock: bun.safety.ThreadLock, + // if false we can skip TLA validation and propagation + has_any_top_level_await_modules: bool = false, + const BakeOptions = struct { framework: bake.Framework, client_transpiler: *Transpiler, @@ -3654,6 +3657,8 @@ pub const BundleV2 = struct { .success => |*result| { result.log.cloneToWithRecycled(this.transpiler.log, true) catch unreachable; + this.has_any_top_level_await_modules = this.has_any_top_level_await_modules or !result.ast.top_level_await_keyword.isEmpty(); + // Warning: `input_files` and `ast` arrays may resize in this function call // It is not safe to cache slices from them. graph.input_files.items(.source)[result.source.index.get()] = result.source; diff --git a/src/bundler/linker_context/computeChunks.zig b/src/bundler/linker_context/computeChunks.zig index 266d71c35a..fd98457e4c 100644 --- a/src/bundler/linker_context/computeChunks.zig +++ b/src/bundler/linker_context/computeChunks.zig @@ -389,7 +389,10 @@ pub noinline fn computeChunks( }; defer dir.close(); - break :dir try dir.getFdPath(&real_path_buf); + break :dir dir.getFdPath(&real_path_buf) catch |err| { + try this.log.addErrorFmt(null, .Empty, this.allocator(), "{s}: Failed to get full path for directory '{s}'", .{ @errorName(err), dir_path }); + return error.BuildFailed; + }; }; chunk.template.placeholder.dir = try resolve_path.relativeAlloc(this.allocator(), this.resolver.opts.root_dir, dir); diff --git a/src/bundler/linker_context/computeCrossChunkDependencies.zig b/src/bundler/linker_context/computeCrossChunkDependencies.zig index 8e4c3ba3ac..68f3756350 100644 --- a/src/bundler/linker_context/computeCrossChunkDependencies.zig +++ b/src/bundler/linker_context/computeCrossChunkDependencies.zig @@ -1,4 +1,4 @@ -pub fn computeCrossChunkDependencies(c: *LinkerContext, chunks: []Chunk) !void { +pub fn computeCrossChunkDependencies(c: *LinkerContext, chunks: []Chunk) bun.OOM!void { if (!c.graph.code_splitting) { // No need to compute cross-chunk dependencies if there can't be any return; diff --git a/src/bundler/linker_context/convertStmtsForChunk.zig b/src/bundler/linker_context/convertStmtsForChunk.zig index eb9eb2a3f8..e5a7cd6b0c 100644 --- a/src/bundler/linker_context/convertStmtsForChunk.zig +++ b/src/bundler/linker_context/convertStmtsForChunk.zig @@ -87,7 +87,7 @@ pub fn convertStmtsForChunk( // Make sure these don't end up in the wrapper closure if (shouldExtractESMStmtsForWrap) { - try stmts.outside_wrapper_prefix.append(stmt); + try stmts.append(.outside_wrapper_prefix, stmt); continue; } }, @@ -120,7 +120,7 @@ pub fn convertStmtsForChunk( // Make sure these don't end up in the wrapper closure if (shouldExtractESMStmtsForWrap) { - try stmts.outside_wrapper_prefix.append(stmt); + try stmts.append(.outside_wrapper_prefix, stmt); continue; } @@ -172,7 +172,7 @@ pub fn convertStmtsForChunk( args[3] = mod; } - try stmts.inside_wrapper_prefix.append( + try stmts.inside_wrapper_prefix.appendNonDependency( Stmt.alloc( S.SExpr, S.SExpr{ @@ -199,7 +199,7 @@ pub fn convertStmtsForChunk( // Make sure these don't end up in the wrapper closure if (shouldExtractESMStmtsForWrap) { - try 
stmts.outside_wrapper_prefix.append(stmt); + try stmts.append(.outside_wrapper_prefix, stmt); continue; } } @@ -208,7 +208,7 @@ pub fn convertStmtsForChunk( const flag = flags[record.source_index.get()]; const wrapper_ref = c.graph.ast.items(.wrapper_ref)[record.source_index.get()]; if (flag.wrap == .esm and wrapper_ref.isValid()) { - try stmts.inside_wrapper_prefix.append( + try stmts.inside_wrapper_prefix.appendNonDependency( Stmt.alloc(S.SExpr, .{ .value = Expr.init(E.Call, .{ .target = Expr.init( @@ -258,7 +258,7 @@ pub fn convertStmtsForChunk( args[2] = mod; } - try stmts.inside_wrapper_prefix.append( + try stmts.inside_wrapper_prefix.appendNonDependency( Stmt.alloc( S.SExpr, S.SExpr{ @@ -326,7 +326,7 @@ pub fn convertStmtsForChunk( // Make sure these don't end up in the wrapper closure if (shouldExtractESMStmtsForWrap) { - try stmts.outside_wrapper_prefix.append(stmt); + try stmts.append(.outside_wrapper_prefix, stmt); continue; } }, @@ -341,7 +341,7 @@ pub fn convertStmtsForChunk( // Make sure these don't end up in the wrapper closure if (shouldExtractESMStmtsForWrap) { - try stmts.outside_wrapper_prefix.append(stmt); + try stmts.append(.outside_wrapper_prefix, stmt); continue; } }, @@ -519,7 +519,7 @@ pub fn convertStmtsForChunk( } } - try stmts.inside_wrapper_suffix.append(stmt); + try stmts.append(.inside_wrapper_suffix, stmt); } } diff --git a/src/bundler/linker_context/convertStmtsForChunkForDevServer.zig b/src/bundler/linker_context/convertStmtsForChunkForDevServer.zig index 144dc62ea2..69010f0560 100644 --- a/src/bundler/linker_context/convertStmtsForChunkForDevServer.zig +++ b/src/bundler/linker_context/convertStmtsForChunkForDevServer.zig @@ -54,7 +54,7 @@ pub fn convertStmtsForChunkForDevServer( // Modules which do not have side effects for (part_stmts) |stmt| switch (stmt.data) { - else => try stmts.inside_wrapper_suffix.append(stmt), + else => try stmts.append(.inside_wrapper_suffix, stmt), .s_import => |st| { const record = ast.import_records.mut(st.import_record_index); @@ -78,7 +78,7 @@ pub fn convertStmtsForChunkForDevServer( }, stmt.loc); // var namespace = ...; - try stmts.inside_wrapper_prefix.append(Stmt.alloc(S.Local, .{ + try stmts.inside_wrapper_prefix.appendNonDependency(Stmt.alloc(S.Local, .{ .kind = .k_var, // remove a tdz .decls = try G.Decl.List.fromSlice(allocator, &.{.{ .binding = Binding.alloc( @@ -113,14 +113,14 @@ pub fn convertStmtsForChunkForDevServer( }, .Empty)); } - try stmts.outside_wrapper_prefix.append(stmt); + try stmts.append(.outside_wrapper_prefix, stmt); } }, }; if (esm_decls.items.len > 0) { // var ...; - try stmts.inside_wrapper_prefix.append(Stmt.alloc(S.Local, .{ + try stmts.inside_wrapper_prefix.appendNonDependency(Stmt.alloc(S.Local, .{ .kind = .k_var, // remove a tdz .decls = try .fromSlice(allocator, &.{.{ .binding = Binding.alloc(allocator, B.Array{ @@ -135,7 +135,7 @@ pub fn convertStmtsForChunkForDevServer( }}), }, .Empty)); // hmr.onUpdate = [ ... 
]; - try stmts.inside_wrapper_prefix.append(Stmt.alloc(S.SExpr, .{ + try stmts.inside_wrapper_prefix.appendNonDependency(Stmt.alloc(S.SExpr, .{ .value = Expr.init(E.Binary, .{ .op = .bin_assign, .left = Expr.init(E.Dot, .{ diff --git a/src/bundler/linker_context/generateCodeForFileInChunkJS.zig b/src/bundler/linker_context/generateCodeForFileInChunkJS.zig index 514abc7e0b..eb2f7589fc 100644 --- a/src/bundler/linker_context/generateCodeForFileInChunkJS.zig +++ b/src/bundler/linker_context/generateCodeForFileInChunkJS.zig @@ -40,11 +40,11 @@ pub fn generateCodeForFileInChunkJS( return .{ .err = err }; } - const main_stmts_len = stmts.inside_wrapper_prefix.items.len + stmts.inside_wrapper_suffix.items.len; + const main_stmts_len = stmts.inside_wrapper_prefix.stmts.items.len + stmts.inside_wrapper_suffix.items.len; const all_stmts_len = main_stmts_len + stmts.outside_wrapper_prefix.items.len + 1; - bun.handleOom(stmts.all_stmts.ensureUnusedCapacity(all_stmts_len)); - stmts.all_stmts.appendSliceAssumeCapacity(stmts.inside_wrapper_prefix.items); + bun.handleOom(stmts.all_stmts.ensureUnusedCapacity(stmts.allocator, all_stmts_len)); + stmts.all_stmts.appendSliceAssumeCapacity(stmts.inside_wrapper_prefix.stmts.items); stmts.all_stmts.appendSliceAssumeCapacity(stmts.inside_wrapper_suffix.items); const inner = stmts.all_stmts.items[0..main_stmts_len]; @@ -128,7 +128,7 @@ pub fn generateCodeForFileInChunkJS( // The top-level directive must come first (the non-wrapped case is handled // by the chunk generation code, although only for the entry point) if (flags.wrap != .none and ast.flags.has_explicit_use_strict_directive and !chunk.isEntryPoint() and !output_format.isAlwaysStrictMode()) { - stmts.inside_wrapper_prefix.append(Stmt.alloc(S.Directive, .{ + stmts.inside_wrapper_prefix.appendNonDependency(Stmt.alloc(S.Directive, .{ .value = "use strict", }, Logger.Loc.Empty)) catch unreachable; } @@ -153,10 +153,10 @@ pub fn generateCodeForFileInChunkJS( switch (flags.wrap) { .esm => { - stmts.outside_wrapper_prefix.appendSlice(stmts.inside_wrapper_suffix.items) catch unreachable; + stmts.appendSlice(.outside_wrapper_prefix, stmts.inside_wrapper_suffix.items) catch unreachable; }, else => { - stmts.inside_wrapper_prefix.appendSlice(stmts.inside_wrapper_suffix.items) catch unreachable; + stmts.inside_wrapper_prefix.appendNonDependencySlice(stmts.inside_wrapper_suffix.items) catch unreachable; }, } @@ -291,11 +291,11 @@ pub fn generateCodeForFileInChunkJS( // evaluated (well, except for cyclic import scenarios). We need to preserve // these semantics even when modules imported via ES6 import statements end // up being CommonJS modules. 
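    // For example (an assumed illustration, not literal output of this build):
    //
    //   import { x } from "./cjs";          // source
    //   var cjs = __toESM(require_cjs());   // bundled: "./cjs" still runs first
    //   console.log(cjs.x);
    //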
- stmts.all_stmts.ensureUnusedCapacity(stmts.inside_wrapper_prefix.items.len + stmts.inside_wrapper_suffix.items.len) catch unreachable; - stmts.all_stmts.appendSliceAssumeCapacity(stmts.inside_wrapper_prefix.items); + stmts.all_stmts.ensureUnusedCapacity(stmts.allocator, stmts.inside_wrapper_prefix.stmts.items.len + stmts.inside_wrapper_suffix.items.len) catch unreachable; + stmts.all_stmts.appendSliceAssumeCapacity(stmts.inside_wrapper_prefix.stmts.items); stmts.all_stmts.appendSliceAssumeCapacity(stmts.inside_wrapper_suffix.items); - stmts.inside_wrapper_prefix.items.len = 0; - stmts.inside_wrapper_suffix.items.len = 0; + stmts.inside_wrapper_prefix.reset(); + stmts.inside_wrapper_suffix.clearRetainingCapacity(); if (c.options.minify_syntax) { mergeAdjacentLocalStmts(&stmts.all_stmts, temp_allocator); @@ -384,7 +384,8 @@ pub fn generateCodeForFileInChunkJS( .value = commonjs_wrapper_definition, }; - stmts.outside_wrapper_prefix.append( + stmts.append( + .outside_wrapper_prefix, Stmt.alloc( S.Local, S.Local{ @@ -468,12 +469,12 @@ pub fn generateCodeForFileInChunkJS( break :stmt Stmt.allocateExpr(temp_allocator, value); }, .s_function => { - bun.handleOom(stmts.outside_wrapper_prefix.append(stmt)); + bun.handleOom(stmts.append(.outside_wrapper_prefix, stmt)); continue; }, .s_class => |class| stmt: { if (class.class.canBeMoved()) { - bun.handleOom(stmts.outside_wrapper_prefix.append(stmt)); + bun.handleOom(stmts.append(.outside_wrapper_prefix, stmt)); continue; } @@ -498,7 +499,8 @@ pub fn generateCodeForFileInChunkJS( } if (hoist.decls.items.len > 0) { - stmts.outside_wrapper_prefix.append( + stmts.append( + .outside_wrapper_prefix, Stmt.alloc( S.Local, S.Local{ @@ -544,7 +546,8 @@ pub fn generateCodeForFileInChunkJS( .value = value, }; - stmts.outside_wrapper_prefix.append( + stmts.append( + .outside_wrapper_prefix, Stmt.alloc(S.Local, .{ .decls = G.Decl.List.fromOwnedSlice(decls), }, Logger.Loc.Empty), @@ -577,7 +580,8 @@ pub fn generateCodeForFileInChunkJS( }, }, Logger.Loc.Empty); - stmts.outside_wrapper_prefix.append( + stmts.append( + .outside_wrapper_prefix, Stmt.alloc(S.Local, .{ .decls = G.Decl.List.fromSlice(temp_allocator, &.{.{ .binding = Binding.alloc( @@ -624,7 +628,7 @@ pub fn generateCodeForFileInChunkJS( ); } -fn mergeAdjacentLocalStmts(stmts: *std.ArrayList(Stmt), allocator: std.mem.Allocator) void { +fn mergeAdjacentLocalStmts(stmts: *std.ArrayListUnmanaged(Stmt), allocator: std.mem.Allocator) void { if (stmts.items.len == 0) return; diff --git a/src/bundler/linker_context/generateCodeForLazyExport.zig b/src/bundler/linker_context/generateCodeForLazyExport.zig index 8e8d3746bf..c9b484ad73 100644 --- a/src/bundler/linker_context/generateCodeForLazyExport.zig +++ b/src/bundler/linker_context/generateCodeForLazyExport.zig @@ -1,4 +1,4 @@ -pub fn generateCodeForLazyExport(this: *LinkerContext, source_index: Index.Int) !void { +pub fn generateCodeForLazyExport(this: *LinkerContext, source_index: Index.Int) bun.OOM!void { const exports_kind = this.graph.ast.items(.exports_kind)[source_index]; const all_sources = this.parse_graph.input_files.items(.source); const all_css_asts = this.graph.ast.items(.css); @@ -333,12 +333,12 @@ pub fn generateCodeForLazyExport(this: *LinkerContext, source_index: Index.Int) // end up actually being used at this point (since import binding hasn't // happened yet). So we need to wait until after tree shaking happens. 
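        // e.g. (assumed): for a lazy object export like { "a": 1 }, each used
        // property lowers to roughly
        //
        //   export var a = 1;
        //
        // and the generated part is dropped if tree shaking finds no uses.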
const generated = try this.generateNamedExportInFile(source_index, module_ref, name, name); - parts.ptr[generated[1]].stmts = this.allocator().alloc(Stmt, 1) catch unreachable; + parts.ptr[generated[1]].stmts = try this.allocator().alloc(Stmt, 1); parts.ptr[generated[1]].stmts[0] = Stmt.alloc( S.Local, S.Local{ .is_export = true, - .decls = js_ast.G.Decl.List.fromSlice( + .decls = try js_ast.G.Decl.List.fromSlice( this.allocator(), &.{ .{ @@ -352,7 +352,7 @@ pub fn generateCodeForLazyExport(this: *LinkerContext, source_index: Index.Int) .value = property.value.?, }, }, - ) catch unreachable, + ), }, property.key.?.loc, ); @@ -363,14 +363,14 @@ pub fn generateCodeForLazyExport(this: *LinkerContext, source_index: Index.Int) const generated = try this.generateNamedExportInFile( source_index, module_ref, - std.fmt.allocPrint( + try std.fmt.allocPrint( this.allocator(), "{}_default", .{this.parse_graph.input_files.items(.source)[source_index].fmtIdentifier()}, - ) catch unreachable, + ), "default", ); - parts.ptr[generated[1]].stmts = this.allocator().alloc(Stmt, 1) catch unreachable; + parts.ptr[generated[1]].stmts = try this.allocator().alloc(Stmt, 1); parts.ptr[generated[1]].stmts[0] = Stmt.alloc( S.ExportDefault, S.ExportDefault{ diff --git a/src/bundler/linker_context/scanImportsAndExports.zig b/src/bundler/linker_context/scanImportsAndExports.zig index 921f4fd74c..724ece8ebc 100644 --- a/src/bundler/linker_context/scanImportsAndExports.zig +++ b/src/bundler/linker_context/scanImportsAndExports.zig @@ -1,4 +1,6 @@ -pub fn scanImportsAndExports(this: *LinkerContext) !void { +pub const ScanImportsAndExportsError = bun.OOM || error{ImportResolutionFailed}; + +pub fn scanImportsAndExports(this: *LinkerContext) ScanImportsAndExportsError!void { const outer_trace = bun.perf.trace("Bundler.scanImportsAndExports"); defer outer_trace.end(); const reachable = this.graph.reachable_files; @@ -278,7 +280,7 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { .imports_to_bind = this.graph.meta.items(.imports_to_bind), - .source_index_stack = std.ArrayList(u32).initCapacity(this.allocator(), 32) catch unreachable, + .source_index_stack = try std.ArrayList(u32).initCapacity(this.allocator(), 32), .exports_kind = exports_kind, .named_exports = this.graph.ast.items(.named_exports), }; @@ -443,7 +445,7 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { break :brk count; }; - const string_buffer = this.allocator().alloc(u8, string_buffer_len) catch unreachable; + const string_buffer = try this.allocator().alloc(u8, string_buffer_len); var builder = bun.StringBuilder{ .len = 0, .cap = string_buffer.len, @@ -456,7 +458,7 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { // are necessary later. This is done now because the symbols map cannot be // mutated later due to parallelism. 
if (is_entry_point and output_format == .esm) { - const copies = this.allocator().alloc(Ref, aliases.len) catch unreachable; + const copies = try this.allocator().alloc(Ref, aliases.len); for (aliases, copies) |alias, *copy| { const original_name = builder.fmt("export_{}", .{bun.fmt.fmtIdentifier(alias)}); @@ -514,13 +516,13 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { bun.assert(runtime_export_symbol_ref.isValid()); - this.graph.generateSymbolImportAndUse( + try this.graph.generateSymbolImportAndUse( id, js_ast.namespace_export_part_index, runtime_export_symbol_ref, 1, Index.runtime, - ) catch unreachable; + ); } var imports_to_bind_list: []RefImportData = this.graph.meta.items(.imports_to_bind); var parts_list: []Part.List = ast_fields.items(.parts); @@ -629,12 +631,12 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { // Pull in the "__toCommonJS" symbol if we need it due to being an entry point if (force_include_exports and output_format != .internal_bake_dev) { - this.graph.generateRuntimeSymbolImportAndUse( + try this.graph.generateRuntimeSymbolImportAndUse( source_index, Index.part(entry_point_part_index), "__toCommonJS", 1, - ) catch unreachable; + ); } } @@ -713,13 +715,13 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { // Depend on the automatically-generated require wrapper symbol const wrapper_ref = wrapper_refs[other_id]; if (wrapper_ref.isValid()) { - this.graph.generateSymbolImportAndUse( + try this.graph.generateSymbolImportAndUse( source_index, @as(u32, @intCast(part_index)), wrapper_ref, 1, Index.source(other_source_index), - ) catch unreachable; + ); } // This is an ES6 import of a CommonJS module, so it needs the @@ -735,13 +737,13 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { // but does not need to be done for "import" statements since // those just cause us to reference the exports directly. if (other_flags.wrap == .esm and kind != .stmt) { - this.graph.generateSymbolImportAndUse( + try this.graph.generateSymbolImportAndUse( source_index, @as(u32, @intCast(part_index)), this.graph.ast.items(.exports_ref)[other_id], 1, Index.source(other_source_index), - ) catch unreachable; + ); // If this is a "require()" call, then we should add the // "__esModule" marker to behave as if the module was converted @@ -763,13 +765,13 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { // something ends up needing to use it later. This could potentially // be omitted in some cases with more advanced analysis if this // dynamic export fallback object doesn't end up being needed. - this.graph.generateSymbolImportAndUse( + try this.graph.generateSymbolImportAndUse( source_index, @as(u32, @intCast(part_index)), this.graph.ast.items(.exports_ref)[other_id], 1, Index.source(other_source_index), - ) catch unreachable; + ); } } @@ -795,25 +797,25 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { // pull in the "exports_b" symbol into this export star. This matters // in code splitting situations where the "export_b" symbol might live // in a different chunk than this export star. 
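                    // For instance (assumed): given `export * from "./b"` with "./b"
                    // emitted into another chunk, the generated
                    //
                    //   __reExport(exports_a, exports_b);
                    //
                    // must still be able to reference `exports_b` across the chunk boundary.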
- this.graph.generateSymbolImportAndUse( + try this.graph.generateSymbolImportAndUse( source_index, @as(u32, @intCast(part_index)), this.graph.ast.items(.exports_ref)[other_id], 1, Index.source(other_source_index), - ) catch unreachable; + ); } } if (happens_at_runtime) { // Depend on this file's "exports" object for the first argument to "__reExport" - this.graph.generateSymbolImportAndUse( + try this.graph.generateSymbolImportAndUse( source_index, @as(u32, @intCast(part_index)), this.graph.ast.items(.exports_ref)[id], 1, Index.source(source_index), - ) catch unreachable; + ); this.graph.ast.items(.flags)[id].uses_exports_ref = true; record.calls_runtime_re_export_fn = true; re_export_uses += 1; @@ -823,37 +825,37 @@ pub fn scanImportsAndExports(this: *LinkerContext) !void { if (output_format != .internal_bake_dev) { // If there's an ES6 import of a CommonJS module, then we're going to need the // "__toESM" symbol from the runtime to wrap the result of "require()" - this.graph.generateRuntimeSymbolImportAndUse( + try this.graph.generateRuntimeSymbolImportAndUse( source_index, Index.part(part_index), "__toESM", to_esm_uses, - ) catch unreachable; + ); // If there's a CommonJS require of an ES6 module, then we're going to need the // "__toCommonJS" symbol from the runtime to wrap the exports object - this.graph.generateRuntimeSymbolImportAndUse( + try this.graph.generateRuntimeSymbolImportAndUse( source_index, Index.part(part_index), "__toCommonJS", to_common_js_uses, - ) catch unreachable; + ); // If there are unbundled calls to "require()" and we're not generating // code for node, then substitute a "__require" wrapper for "require". - this.graph.generateRuntimeSymbolImportAndUse( + try this.graph.generateRuntimeSymbolImportAndUse( source_index, Index.part(part_index), "__require", runtime_require_uses, - ) catch unreachable; + ); - this.graph.generateRuntimeSymbolImportAndUse( + try this.graph.generateRuntimeSymbolImportAndUse( source_index, Index.part(part_index), "__reExport", re_export_uses, - ) catch unreachable; + ); } } } diff --git a/src/defines-table.zig b/src/defines-table.zig index d000012a43..bcb67659da 100644 --- a/src/defines-table.zig +++ b/src/defines-table.zig @@ -181,6 +181,10 @@ pub const global_no_side_effect_property_accesses = &[_][]const string{ &[_]string{ "console", "trace" }, &[_]string{ "console", "warn" }, + &[_]string{ "Promise", "resolve" }, + &[_]string{ "Promise", "reject" }, + &[_]string{ "Promise", "all" }, + // Crypto: Static methods &[_]string{ "crypto", "randomUUID" }, }; diff --git a/src/runtime.js b/src/runtime.js index ea410c7a32..76abb726f0 100644 --- a/src/runtime.js +++ b/src/runtime.js @@ -175,3 +175,5 @@ export var __esm = (fn, res) => () => (fn && (res = fn((fn = 0))), res); export var $$typeof = /* @__PURE__ */ Symbol.for("react.element"); export var __jsonParse = /* @__PURE__ */ a => JSON.parse(a); + +export var __promiseAll = args => Promise.all(args); diff --git a/src/runtime.zig b/src/runtime.zig index 354978a15a..85f1cc6373 100644 --- a/src/runtime.zig +++ b/src/runtime.zig @@ -317,6 +317,7 @@ pub const Runtime = struct { __using: ?Ref = null, __callDispose: ?Ref = null, __jsonParse: ?Ref = null, + __promiseAll: ?Ref = null, pub const all = [_][]const u8{ "__name", @@ -333,6 +334,7 @@ pub const Runtime = struct { "__using", "__callDispose", "__jsonParse", + "__promiseAll", }; const all_sorted: [all.len]string = brk: { @setEvalBranchQuota(1000000); diff --git a/test/bundler/bundler_promiseall_deadcode.test.ts 
b/test/bundler/bundler_promiseall_deadcode.test.ts new file mode 100644 index 0000000000..76550aafe6 --- /dev/null +++ b/test/bundler/bundler_promiseall_deadcode.test.ts @@ -0,0 +1,530 @@ +import { expect, test } from "bun:test"; +import { bunEnv, bunExe, tempDirWithFiles } from "harness"; +import { join } from "path"; + +test("__promiseAll is tree-shaken when only one async import exists but __esm remains", async () => { + const dir = tempDirWithFiles("promise-all-single-async", { + "build.ts": ` + import { build } from "bun"; + build({ + entrypoints: ["src/entry.ts"], + outdir: "dist", + format: "esm", + target: "browser", + sourcemap: "linked", + minify: false, + }).then(() => { + console.log("Build completed successfully."); + }).catch((error) => { + console.error("Build failed:", error); + }) + `, + "src/entry.ts": ` + const { AsyncEntryPoint } = await import("./AsyncEntryPoint"); + AsyncEntryPoint(); + export {}; + `, + "src/AsyncEntryPoint.ts": ` + export async function AsyncEntryPoint() { + const { BaseElement } = await import("./BaseElement"); + console.log("Launching AsyncEntryPoint", BaseElement()); + } + `, + "src/BaseElement.ts": ` + import { StoreDependency } from "./StoreDependency"; + import { BaseElementImport } from "./BaseElementImport"; + + const depValue = StoreDependency(); + + export const formValue = { + key: depValue, + }; + + export const listValue = { + key: depValue + "value", + }; + + export function BaseElement() { + console.log("BaseElement called", BaseElementImport()); + return BaseElementImport(); + } + `, + "src/BaseElementImport.ts": ` + import { SecondElementImport } from "./SecondElementImport"; + export function BaseElementImport() { + console.log("BaseElementImport called", SecondElementImport()); + return SecondElementImport(); + } + `, + "src/SecondElementImport.ts": ` + import { formValue } from "./BaseElement"; + export function SecondElementImport() { + console.log("SecondElementImport called", formValue.key); + return formValue.key; + } + `, + "src/StoreDependency.ts": ` + import { somePromise } from "./StoreDependencyAsync"; + + export function StoreDependency() { + return "A string from StoreFunc" + somePromise; + } + `, + "src/StoreDependencyAsync.ts": ` + export const somePromise = await Promise.resolve("Hello World"); + `, + }); + + // Build the project + const buildResult = await Bun.spawn({ + cmd: [bunExe(), "build.ts"], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const exitCode1 = await buildResult.exited; + expect(exitCode1).toBe(0); + + // Read the bundled output + const bundledPath = join(dir, "dist", "entry.js"); + const bundled = await Bun.file(bundledPath).text(); + + expect(bundled).toMatchInlineSnapshot(` + "var __defProp = Object.defineProperty; + var __export = (target, all) => { + for (var name in all) + __defProp(target, name, { + get: all[name], + enumerable: true, + configurable: true, + set: (newValue) => all[name] = () => newValue + }); + }; + var __esm = (fn, res) => () => (fn && (res = fn(fn = 0)), res); + var __promiseAll = (args) => Promise.all(args); + + // src/StoreDependencyAsync.ts + var somePromise; + var init_StoreDependencyAsync = __esm(async () => { + somePromise = await Promise.resolve("Hello World"); + }); + + // src/StoreDependency.ts + function StoreDependency() { + return "A string from StoreFunc" + somePromise; + } + var init_StoreDependency = __esm(async () => { + await init_StoreDependencyAsync(); + }); + + // src/SecondElementImport.ts + function SecondElementImport() { + 
console.log("SecondElementImport called", formValue.key); + return formValue.key; + } + var init_SecondElementImport = __esm(async () => { + await init_BaseElement(); + }); + + // src/BaseElementImport.ts + function BaseElementImport() { + console.log("BaseElementImport called", SecondElementImport()); + return SecondElementImport(); + } + var init_BaseElementImport = __esm(async () => { + await init_SecondElementImport(); + }); + + // src/BaseElement.ts + var exports_BaseElement = {}; + __export(exports_BaseElement, { + listValue: () => listValue, + formValue: () => formValue, + BaseElement: () => BaseElement + }); + function BaseElement() { + console.log("BaseElement called", BaseElementImport()); + return BaseElementImport(); + } + var depValue, formValue, listValue; + var init_BaseElement = __esm(async () => { + await __promiseAll([ + init_StoreDependency(), + init_BaseElementImport() + ]); + depValue = StoreDependency(); + formValue = { + key: depValue + }; + listValue = { + key: depValue + "value" + }; + }); + + // src/AsyncEntryPoint.ts + var exports_AsyncEntryPoint = {}; + __export(exports_AsyncEntryPoint, { + AsyncEntryPoint: () => AsyncEntryPoint + }); + async function AsyncEntryPoint() { + const { BaseElement: BaseElement2 } = await init_BaseElement().then(() => exports_BaseElement); + console.log("Launching AsyncEntryPoint", BaseElement2()); + } + + // src/entry.ts + var { AsyncEntryPoint: AsyncEntryPoint2 } = await Promise.resolve().then(() => exports_AsyncEntryPoint); + AsyncEntryPoint2(); + + //# debugId=BFB0A84A2F1B802064756E2164756E21 + //# sourceMappingURL=entry.js.map + " + `); + + // MUST have __esm because of circular dependency requiring wrapping + expect(bundled).toContain("__esm"); + expect(bundled).toContain("var init_"); + + // Should have __promiseAll because BaseElement has multiple dependencies + // even though only one is async (due to circular deps both need to be awaited) + expect(bundled).toContain("__promiseAll"); + expect(bundled).toContain("var __promiseAll = "); + + // Verify it's used with both dependencies + expect(bundled).toMatch(/await\s+__promiseAll\s*\(\s*\[/); + + // Also verify the bundled code can execute without syntax errors + const runResult = await Bun.spawn({ + cmd: [bunExe(), bundledPath], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([ + new Response(runResult.stdout).text(), + new Response(runResult.stderr).text(), + runResult.exited, + ]); + + // Should not have syntax errors + expect(stderr).not.toContain('await" can only be used inside an "async" function'); + expect(exitCode).toBe(0); + expect(stdout).toContain("Launching AsyncEntryPoint"); +}); + +test("__promiseAll is included when multiple async imports exist with __esm", async () => { + const dir = tempDirWithFiles("promise-all-multiple-async", { + "build.ts": ` + import { build } from "bun"; + build({ + entrypoints: ["src/entry.ts"], + outdir: "dist", + format: "esm", + target: "browser", + sourcemap: "linked", + minify: false, + }).then(() => { + console.log("Build completed successfully."); + }).catch((error) => { + console.error("Build failed:", error); + }) + `, + "src/entry.ts": ` + const { AsyncEntryPoint } = await import("./AsyncEntryPoint"); + AsyncEntryPoint(); + export {}; + `, + "src/AsyncEntryPoint.ts": ` + export async function AsyncEntryPoint() { + const { BaseElement } = await import("./BaseElement"); + console.log("Launching AsyncEntryPoint", BaseElement()); + } + `, + 
"src/BaseElement.ts": ` + import { StoreDependency } from "./StoreDependency"; + import { StoreDependency2 } from "./StoreDependency2"; + import { BaseElementImport } from "./BaseElementImport"; + + const depValue = StoreDependency(); + const depValue2 = StoreDependency2(); + + export const formValue = { + key: depValue + depValue2, + }; + + export const listValue = { + key: depValue + "value", + }; + + export function BaseElement() { + console.log("BaseElement called", BaseElementImport()); + return BaseElementImport(); + } + `, + "src/BaseElementImport.ts": ` + import { SecondElementImport } from "./SecondElementImport"; + export function BaseElementImport() { + console.log("BaseElementImport called", SecondElementImport()); + return SecondElementImport(); + } + `, + "src/SecondElementImport.ts": ` + import { formValue } from "./BaseElement"; + export function SecondElementImport() { + console.log("SecondElementImport called", formValue.key); + return formValue.key; + } + `, + "src/StoreDependency.ts": ` + import { somePromise } from "./StoreDependencyAsync"; + + export function StoreDependency() { + return "A string from StoreFunc" + somePromise; + } + `, + "src/StoreDependencyAsync.ts": ` + export const somePromise = await Promise.resolve("Hello World"); + `, + "src/StoreDependency2.ts": ` + import { somePromise2 } from "./StoreDependencyAsync2"; + + export function StoreDependency2() { + return "Another string" + somePromise2; + } + `, + "src/StoreDependencyAsync2.ts": ` + export const somePromise2 = await Promise.resolve(" World2"); + `, + }); + + // Build the project + const buildResult = await Bun.spawn({ + cmd: [bunExe(), "build.ts"], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const exitCode1 = await buildResult.exited; + expect(exitCode1).toBe(0); + + // Read the bundled output + const bundledPath = join(dir, "dist", "entry.js"); + const bundled = await Bun.file(bundledPath).text(); + + // MUST have __esm because of circular dependency requiring wrapping + expect(bundled).toContain("__esm"); + expect(bundled).toContain("var init_"); + + // MUST have __promiseAll since there are TWO async dependencies + expect(bundled).toContain("__promiseAll"); + expect(bundled).toContain("var __promiseAll = "); + + // Verify it's actually used in the code with multiple async deps + expect(bundled).toMatch(/await\s+__promiseAll\s*\(\s*\[/); + + // Also verify the bundled code can execute without syntax errors + const runResult = await Bun.spawn({ + cmd: [bunExe(), bundledPath], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([ + new Response(runResult.stdout).text(), + new Response(runResult.stderr).text(), + runResult.exited, + ]); + + // Should not have syntax errors + expect(stderr).not.toContain('await" can only be used inside an "async" function'); + expect(exitCode).toBe(0); + expect(stdout).toContain("Launching AsyncEntryPoint"); +}); + +test("__promiseAll is tree-shaken when no async imports despite circular deps with __esm", async () => { + const dir = tempDirWithFiles("promise-all-no-async", { + "build.ts": ` + import { build } from "bun"; + build({ + entrypoints: ["src/entry.ts"], + outdir: "dist", + format: "esm", + target: "browser", + sourcemap: "linked", + minify: false, + }).then(() => { + console.log("Build completed successfully."); + }).catch((error) => { + console.error("Build failed:", error); + }) + `, + "src/entry.ts": ` + const { AsyncEntryPoint } = await 
import("./AsyncEntryPoint"); + AsyncEntryPoint(); + export {}; + `, + "src/AsyncEntryPoint.ts": ` + export async function AsyncEntryPoint() { + const { BaseElement } = await import("./BaseElement"); + console.log("Launching AsyncEntryPoint", BaseElement()); + } + `, + "src/BaseElement.ts": ` + import { BaseElementImport } from "./BaseElementImport"; + + export const formValue = { + key: "static value", + }; + + export const listValue = { + key: "static list value", + }; + + export function BaseElement() { + console.log("BaseElement called", BaseElementImport()); + return BaseElementImport(); + } + `, + "src/BaseElementImport.ts": ` + import { SecondElementImport } from "./SecondElementImport"; + export function BaseElementImport() { + console.log("BaseElementImport called", SecondElementImport()); + return SecondElementImport(); + } + `, + "src/SecondElementImport.ts": ` + import { formValue } from "./BaseElement"; + export function SecondElementImport() { + console.log("SecondElementImport called", formValue.key); + return formValue.key; + } + `, + }); + + // Build the project + const buildResult = await Bun.spawn({ + cmd: [bunExe(), "build.ts"], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const exitCode1 = await buildResult.exited; + expect(exitCode1).toBe(0); + + // Read the bundled output + const bundledPath = join(dir, "dist", "entry.js"); + const bundled = await Bun.file(bundledPath).text(); + + expect(bundled).toMatchInlineSnapshot(` + "var __defProp = Object.defineProperty; + var __export = (target, all) => { + for (var name in all) + __defProp(target, name, { + get: all[name], + enumerable: true, + configurable: true, + set: (newValue) => all[name] = () => newValue + }); + }; + var __esm = (fn, res) => () => (fn && (res = fn(fn = 0)), res); + + // src/SecondElementImport.ts + function SecondElementImport() { + console.log("SecondElementImport called", formValue.key); + return formValue.key; + } + var init_SecondElementImport = __esm(() => { + init_BaseElement(); + }); + + // src/BaseElementImport.ts + function BaseElementImport() { + console.log("BaseElementImport called", SecondElementImport()); + return SecondElementImport(); + } + var init_BaseElementImport = __esm(() => { + init_SecondElementImport(); + }); + + // src/BaseElement.ts + var exports_BaseElement = {}; + __export(exports_BaseElement, { + listValue: () => listValue, + formValue: () => formValue, + BaseElement: () => BaseElement + }); + function BaseElement() { + console.log("BaseElement called", BaseElementImport()); + return BaseElementImport(); + } + var formValue, listValue; + var init_BaseElement = __esm(() => { + init_BaseElementImport(); + formValue = { + key: "static value" + }; + listValue = { + key: "static list value" + }; + }); + + // src/AsyncEntryPoint.ts + var exports_AsyncEntryPoint = {}; + __export(exports_AsyncEntryPoint, { + AsyncEntryPoint: () => AsyncEntryPoint + }); + async function AsyncEntryPoint() { + const { BaseElement: BaseElement2 } = await Promise.resolve().then(() => (init_BaseElement(), exports_BaseElement)); + console.log("Launching AsyncEntryPoint", BaseElement2()); + } + + // src/entry.ts + var { AsyncEntryPoint: AsyncEntryPoint2 } = await Promise.resolve().then(() => exports_AsyncEntryPoint); + AsyncEntryPoint2(); + + //# debugId=9153B63E16185E4364756E2164756E21 + //# sourceMappingURL=entry.js.map + " + `); + + // MUST have __esm because of circular dependency requiring wrapping + expect(bundled).toContain("__esm"); + expect(bundled).toContain("var 
init_"); + + // Currently __promiseAll is always included with ESM wrappers (not tree-shaken) + // but it shouldn't be used since there are no async dependencies + expect(bundled).not.toContain("__promiseAll"); + expect(bundled).not.toContain("var __promiseAll = "); + + // Verify it's NOT actually used in any init functions + expect(bundled).not.toMatch(/await\s+__promiseAll\s*\(/); + + // Also verify the bundled code can execute without syntax errors + const runResult = await Bun.spawn({ + cmd: [bunExe(), bundledPath], + env: bunEnv, + cwd: dir, + stdout: "pipe", + stderr: "pipe", + }); + + const [stdout, stderr, exitCode] = await Promise.all([ + new Response(runResult.stdout).text(), + new Response(runResult.stderr).text(), + runResult.exited, + ]); + + // Should not have syntax errors + expect(stderr).not.toContain('await" can only be used inside an "async" function'); + expect(exitCode).toBe(0); + expect(stdout).toContain("Launching AsyncEntryPoint"); +}); diff --git a/test/regression/issue/cyclic-imports-async-bundler.test.js b/test/regression/issue/cyclic-imports-async-bundler.test.js index 4a924627ab..16fbe4ed27 100644 --- a/test/regression/issue/cyclic-imports-async-bundler.test.js +++ b/test/regression/issue/cyclic-imports-async-bundler.test.js @@ -90,6 +90,97 @@ test("cyclic imports with async dependencies should generate async wrappers", as const bundledPath = join(dir, "dist", "entryBuild.js"); const bundled = await Bun.file(bundledPath).text(); + expect(bundled).toMatchInlineSnapshot(` + "var __defProp = Object.defineProperty; + var __export = (target, all) => { + for (var name in all) + __defProp(target, name, { + get: all[name], + enumerable: true, + configurable: true, + set: (newValue) => all[name] = () => newValue + }); + }; + var __esm = (fn, res) => () => (fn && (res = fn(fn = 0)), res); + var __promiseAll = (args) => Promise.all(args); + + // src/RecursiveDependencies/StoreDependencyAsync.ts + var somePromise; + var init_StoreDependencyAsync = __esm(async () => { + somePromise = await Promise.resolve("Hello World"); + }); + + // src/RecursiveDependencies/StoreDependency.ts + function StoreDependency() { + return "A string from StoreFunc" + somePromise; + } + var init_StoreDependency = __esm(async () => { + await init_StoreDependencyAsync(); + }); + + // src/RecursiveDependencies/SecondElementImport.ts + function SecondElementImport() { + console.log("SecondElementImport called", formValue.key); + return formValue.key; + } + var init_SecondElementImport = __esm(async () => { + await init_BaseElement(); + }); + + // src/RecursiveDependencies/BaseElementImport.ts + function BaseElementImport() { + console.log("BaseElementImport called", SecondElementImport()); + return SecondElementImport(); + } + var init_BaseElementImport = __esm(async () => { + await init_SecondElementImport(); + }); + + // src/RecursiveDependencies/BaseElement.ts + var exports_BaseElement = {}; + __export(exports_BaseElement, { + listValue: () => listValue, + formValue: () => formValue, + BaseElement: () => BaseElement + }); + function BaseElement() { + console.log("BaseElement called", BaseElementImport()); + return BaseElementImport(); + } + var depValue, formValue, listValue; + var init_BaseElement = __esm(async () => { + await __promiseAll([ + init_StoreDependency(), + init_BaseElementImport() + ]); + depValue = StoreDependency(); + formValue = { + key: depValue + }; + listValue = { + key: depValue + "value" + }; + }); + + // src/RecursiveDependencies/AsyncEntryPoint.ts + var 
exports_AsyncEntryPoint = {}; + __export(exports_AsyncEntryPoint, { + AsyncEntryPoint: () => AsyncEntryPoint + }); + async function AsyncEntryPoint() { + const { BaseElement: BaseElement2 } = await init_BaseElement().then(() => exports_BaseElement); + console.log("Launching AsyncEntryPoint", BaseElement2()); + } + + // src/entryBuild.ts + var { AsyncEntryPoint: AsyncEntryPoint2 } = await Promise.resolve().then(() => exports_AsyncEntryPoint); + AsyncEntryPoint2(); + + //# debugId=68A023AE1F6BCD1164756E2164756E21 + //# sourceMappingURL=entryBuild.js.map + " + `); + // Check that there are no syntax errors like "await" in non-async functions // The bug would manifest as something like: // var init_BaseElement = __esm(() => { From 97f6adf767d1c7d4cfbbe925aea87efabc028769 Mon Sep 17 00:00:00 2001 From: Meghan Denny Date: Fri, 26 Sep 2025 21:23:06 -0800 Subject: [PATCH 41/43] revert this change made to ShellRmTask.DirTask (#23028) it caused deploying the site to hang and `ref.ref(` isnt threadsafe. the proper fix should latch onto the shell command instead of the generic task queue --- src/shell/builtin/rm.zig | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/shell/builtin/rm.zig b/src/shell/builtin/rm.zig index 9f21dbcaa1..8d2efccc54 100644 --- a/src/shell/builtin/rm.zig +++ b/src/shell/builtin/rm.zig @@ -507,8 +507,6 @@ pub const ShellRmTask = struct { task: jsc.WorkPoolTask = .{ .callback = runFromThreadPool }, deleted_entries: std.ArrayList(u8), concurrent_task: jsc.EventLoopTask, - ref: bun.Async.KeepAlive = .{}, - event_loop: bun.jsc.EventLoopHandle, const EntryKindHint = enum { idk, dir, file }; @@ -521,7 +519,6 @@ pub const ShellRmTask = struct { pub fn runFromMainThread(this: *DirTask) void { debug("DirTask(0x{x}, path={s}) runFromMainThread", .{ @intFromPtr(this), this.path }); - this.ref.unref(this.event_loop); this.task_manager.rm.writeVerbose(this).run(); } @@ -695,7 +692,6 @@ pub const ShellRmTask = struct { .kind_hint = .idk, .deleted_entries = std.ArrayList(u8).init(bun.default_allocator), .concurrent_task = jsc.EventLoopTask.fromEventLoop(rm.bltn().eventLoop()), - .event_loop = rm.bltn().eventLoop(), }, .event_loop = rm.bltn().eventLoop(), .concurrent_task = jsc.EventLoopTask.fromEventLoop(rm.bltn().eventLoop()), @@ -741,7 +737,6 @@ pub const ShellRmTask = struct { .kind_hint = kind_hint, .deleted_entries = std.ArrayList(u8).init(bun.default_allocator), .concurrent_task = jsc.EventLoopTask.fromEventLoop(this.event_loop), - .event_loop = this.event_loop, }; const count = parent_task.subtask_count.fetchAdd(1, .monotonic); @@ -749,7 +744,6 @@ pub const ShellRmTask = struct { assert(count > 0); } - subtask.ref.ref(subtask.event_loop); jsc.WorkPool.schedule(&subtask.task); } From 5179dad48137dbc5c265cf9804b0ec570519b0fa Mon Sep 17 00:00:00 2001 From: Marko Vejnovic Date: Fri, 26 Sep 2025 22:43:23 -0700 Subject: [PATCH 42/43] hotfix(redis): Automatically connect on .subscribe() (#23018) ### What does this PR do? Previously `redis.subscribe` did not automatically connect, which was in contrast to other Redis functions. This PR changes the necessary `PUB/SUB` things so that `.subscribe` automatically connects. ### How did you verify your code works? 
Didn't --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- src/bun.js/api/BunObject.zig | 11 +- src/bun.js/bindings/JSMap.zig | 3 + src/bun.js/bindings/bindings.cpp | 5 + src/valkey/ValkeyCommand.zig | 6 +- src/valkey/js_valkey.zig | 183 +++++++++++++++-------------- src/valkey/js_valkey_functions.zig | 38 +++--- src/valkey/valkey.zig | 153 ++++++++++++++---------- 7 files changed, 231 insertions(+), 168 deletions(-) diff --git a/src/bun.js/api/BunObject.zig b/src/bun.js/api/BunObject.zig index edf05d293a..5c13bc36af 100644 --- a/src/bun.js/api/BunObject.zig +++ b/src/bun.js/api/BunObject.zig @@ -1294,7 +1294,9 @@ pub fn setTLSDefaultCiphers(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject, c } pub fn getValkeyDefaultClient(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) jsc.JSValue { - var valkey = jsc.API.Valkey.createNoJs(globalThis, &.{.js_undefined}) catch |err| { + const SubscriptionCtx = @import("../../valkey/js_valkey.zig").SubscriptionCtx; + + var valkey = jsc.API.Valkey.createNoJsNoPubsub(globalThis, &.{.js_undefined}) catch |err| { if (err != error.JSError) { _ = globalThis.throwError(err, "Failed to create Redis client") catch {}; return .zero; @@ -1305,6 +1307,13 @@ pub fn getValkeyDefaultClient(globalThis: *jsc.JSGlobalObject, _: *jsc.JSObject) const as_js = valkey.toJS(globalThis); valkey.this_value = jsc.JSRef.initWeak(as_js); + valkey._subscription_ctx = SubscriptionCtx.init(valkey) catch |err| { + if (err != error.JSError) { + _ = globalThis.throwError(err, "Failed to create Redis client") catch {}; + return .zero; + } + return .zero; + }; return as_js; } diff --git a/src/bun.js/bindings/JSMap.zig b/src/bun.js/bindings/JSMap.zig index 0abbd06173..97ba325c96 100644 --- a/src/bun.js/bindings/JSMap.zig +++ b/src/bun.js/bindings/JSMap.zig @@ -15,6 +15,9 @@ pub const JSMap = opaque { /// Attempt to remove a key from this JS Map object. pub const remove = bun.cpp.JSC__JSMap__remove; + /// Clear all entries from this JS Map object. + pub const clear = bun.cpp.JSC__JSMap__clear; + /// Retrieve the number of entries in this JS Map object. pub const size = bun.cpp.JSC__JSMap__size; diff --git a/src/bun.js/bindings/bindings.cpp b/src/bun.js/bindings/bindings.cpp index 6111976187..add44a48f1 100644 --- a/src/bun.js/bindings/bindings.cpp +++ b/src/bun.js/bindings/bindings.cpp @@ -6506,6 +6506,11 @@ CPP_DECL [[ZIG_EXPORT(check_slow)]] bool JSC__JSMap__remove(JSC::JSMap* map, JSC return map->remove(arg1, value); } +CPP_DECL [[ZIG_EXPORT(check_slow)]] void JSC__JSMap__clear(JSC::JSMap* map, JSC::JSGlobalObject* arg1) +{ + map->clear(arg1); +} + CPP_DECL [[ZIG_EXPORT(check_slow)]] void JSC__JSMap__set(JSC::JSMap* map, JSC::JSGlobalObject* arg1, JSC::EncodedJSValue JSValue2, JSC::EncodedJSValue JSValue3) { map->set(arg1, JSC::JSValue::decode(JSValue2), JSC::JSValue::decode(JSValue3)); diff --git a/src/valkey/ValkeyCommand.zig b/src/valkey/ValkeyCommand.zig index 2349ecad4a..2c1fa6d3b9 100644 --- a/src/valkey/ValkeyCommand.zig +++ b/src/valkey/ValkeyCommand.zig @@ -70,7 +70,8 @@ pub const Entry = struct { ) !Entry { return Entry{ .serialized_data = try command.serialize(allocator), - .meta = command.meta.check(command), + .meta = command.meta, // TODO(markovejnovic): We should be calling .check against command here but due + // to a hack introduced to let SUBSCRIBE work, we are not doing that for now. 
.promise = promise, }; } @@ -84,7 +85,8 @@ pub const Meta = packed struct(u8) { return_as_bool: bool = false, supports_auto_pipelining: bool = true, return_as_buffer: bool = false, - _padding: u5 = 0, + subscription_request: bool = false, + _padding: u4 = 0, const not_allowed_autopipeline_commands = bun.ComptimeStringMap(void, .{ .{"AUTH"}, diff --git a/src/valkey/js_valkey.zig b/src/valkey/js_valkey.zig index 7594dbdb22..31342eaa1b 100644 --- a/src/valkey/js_valkey.zig +++ b/src/valkey/js_valkey.zig @@ -1,37 +1,32 @@ pub const SubscriptionCtx = struct { const Self = @This(); - // TODO(markovejnovic): Consider using refactoring this to use - // @fieldParentPtr. The reason this was not implemented is because there is - // no support for optional fields yet. - // - // See: https://github.com/ziglang/zig/issues/25241 - // - // An alternative is to hold a flag within the context itself, indicating - // whether it is active or not, but that feels less clean. - _parent: *JSValkeyClient, - + is_subscriber: bool, original_enable_offline_queue: bool, original_enable_auto_pipelining: bool, const ParentJS = JSValkeyClient.js; - pub fn init(parent: *JSValkeyClient, enable_offline_queue: bool, enable_auto_pipelining: bool) bun.JSError!Self { - const callback_map = jsc.JSMap.create(parent.globalObject); - const parent_this = parent.this_value.tryGet() orelse unreachable; + pub fn init(valkey_parent: *JSValkeyClient) bun.JSError!Self { + const callback_map = jsc.JSMap.create(valkey_parent.globalObject); + const parent_this = valkey_parent.this_value.tryGet() orelse unreachable; - ParentJS.gc.set(.subscriptionCallbackMap, parent_this, parent.globalObject, callback_map); + ParentJS.gc.set(.subscriptionCallbackMap, parent_this, valkey_parent.globalObject, callback_map); const self = Self{ - ._parent = parent, - .original_enable_offline_queue = enable_offline_queue, - .original_enable_auto_pipelining = enable_auto_pipelining, + .original_enable_offline_queue = valkey_parent.client.flags.enable_offline_queue, + .original_enable_auto_pipelining = valkey_parent.client.flags.enable_auto_pipelining, + .is_subscriber = false, }; return self; } + fn parent(this: *SubscriptionCtx) *JSValkeyClient { + return @alignCast(@fieldParentPtr("_subscription_ctx", this)); + } + fn subscriptionCallbackMap(this: *Self) *jsc.JSMap { - const parent_this = this._parent.this_value.tryGet() orelse unreachable; + const parent_this = this.parent().this_value.tryGet() orelse unreachable; const value_js = ParentJS.gc.get(.subscriptionCallbackMap, parent_this).?; return jsc.JSMap.fromJS(value_js).?; @@ -39,7 +34,9 @@ pub const SubscriptionCtx = struct { /// Get the total number of channels that this subscription context is subscribed to. pub fn channelsSubscribedToCount(this: *Self, globalObject: *jsc.JSGlobalObject) bun.JSError!u32 { - return this.subscriptionCallbackMap().size(globalObject); + const count = try this.subscriptionCallbackMap().size(globalObject); + + return count; } /// Test whether this context has any subscriptions. It is mandatory to @@ -57,6 +54,10 @@ pub const SubscriptionCtx = struct { _ = try map.remove(globalObject, channelName); } + pub fn clearAllReceiveHandlers(this: *Self, globalObject: *jsc.JSGlobalObject) bun.JSError!void { + try this.subscriptionCallbackMap().clear(globalObject); + } + /// Remove a specific receive handler. 
///
/// Returns: The total number of remaining handlers for this channel, or null if there were no listeners originally
@@ -121,7 +122,7 @@ pub const SubscriptionCtx = struct {
 channelName: JSValue,
 callback: JSValue,
 ) bun.JSError!void {
- defer this._parent.onNewSubscriptionCallbackInsert();
+ defer this.parent().onNewSubscriptionCallbackInsert();

 const map = this.subscriptionCallbackMap();
 var handlers_array: JSValue = undefined;
@@ -185,7 +186,7 @@ pub const SubscriptionCtx = struct {

 // After we go through every single callback, we will have to update the poll ref.
 // The user may, for example, unsubscribe in the callbacks, or even stop the client.
- defer this._parent.updatePollRef();
+ defer this.parent().updatePollRef();

 // If callbacks is an array, iterate and call each one
 var iter = try callbacks.arrayIterator(globalObject);
@@ -203,17 +204,15 @@ pub const SubscriptionCtx = struct {
 // The user may request .close(), in which case we can dispose of the subscription object. If that is the case,
 // finalized will be true. Otherwise, we should treat the object as disposable if there are no active
 // subscriptions.
- return this._parent.client.flags.finalized or !(try this.hasSubscriptions(global_object));
+ return this.parent().client.flags.finalized or !(try this.hasSubscriptions(global_object));
 }

 pub fn deinit(this: *Self, global_object: *jsc.JSGlobalObject) void {
- // This check is necessary because crossing between Zig and C++ is necessary because Zig doesn't know that C++
- // is side-effect-free.
 if (comptime bun.Environment.isDebug) {
- bun.debugAssert(this.isDeletable(this._parent.globalObject) catch unreachable);
+ bun.debugAssert(this.isDeletable(this.parent().globalObject) catch unreachable);
 }

- if (this._parent.this_value.tryGet()) |parent_this| {
+ if (this.parent().this_value.tryGet()) |parent_this| {
 ParentJS.gc.set(.subscriptionCallbackMap, parent_this, global_object, .js_undefined);
 }
 }
@@ -226,8 +225,9 @@ pub const JSValkeyClient = struct {
 this_value: jsc.JSRef = jsc.JSRef.empty(),
 poll_ref: bun.Async.KeepAlive = .{},

- _subscription_ctx: ?SubscriptionCtx,
+ _subscription_ctx: SubscriptionCtx,
 _socket_ctx: ?*uws.SocketContext = null,
+
 timer: Timer.EventLoopTimer = .{
 .tag = .ValkeyConnectionTimeout,
 .next = .{
@@ -259,7 +259,10 @@ pub const JSValkeyClient = struct {
 return try create(globalObject, callframe.arguments(), js_this);
 }

- pub fn createNoJs(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue) bun.JSError!*JSValkeyClient {
+ /// Create a Valkey client that has neither an associated JS object nor a SubscriptionCtx.
+ ///
+ /// This whole client needs a refactor.
+ pub fn createNoJsNoPubsub(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue) bun.JSError!*JSValkeyClient { const this_allocator = bun.default_allocator; const vm = globalObject.bunVM(); @@ -334,9 +337,9 @@ pub const JSValkeyClient = struct { bun.analytics.Features.valkey += 1; - const client = JSValkeyClient.new(.{ + return JSValkeyClient.new(.{ .ref_count = .init(), - ._subscription_ctx = null, + ._subscription_ctx = undefined, .client = .{ .vm = vm, .address = switch (uri) { @@ -368,7 +371,7 @@ pub const JSValkeyClient = struct { .flags = .{ .enable_auto_reconnect = options.enable_auto_reconnect, .enable_offline_queue = options.enable_offline_queue, - .auto_pipelining = options.enable_auto_pipelining, + .enable_auto_pipelining = options.enable_auto_pipelining, }, .max_retries = options.max_retries, .connection_timeout_ms = options.connection_timeout_ms, @@ -376,15 +379,17 @@ pub const JSValkeyClient = struct { }, .globalObject = globalObject, }); - - return client; } pub fn create(globalObject: *jsc.JSGlobalObject, arguments: []const JSValue, js_this: JSValue) bun.JSError!*JSValkeyClient { - var new_client = try JSValkeyClient.createNoJs(globalObject, arguments); + var new_client = try JSValkeyClient.createNoJsNoPubsub(globalObject, arguments); // Initially, we only need to hold a weak reference to the JS object. new_client.this_value = jsc.JSRef.initWeak(js_this); + + // Need to associate the subscription context, after the JS ref has been populated. + new_client._subscription_ctx = try SubscriptionCtx.init(new_client); + return new_client; } @@ -415,7 +420,7 @@ pub const JSValkeyClient = struct { return JSValkeyClient.new(.{ .ref_count = .init(), - ._subscription_ctx = null, + ._subscription_ctx = undefined, .client = .{ .vm = vm, .address = switch (this.client.protocol) { @@ -450,11 +455,17 @@ pub const JSValkeyClient = struct { // If the user manually closed the connection, then duplicating a closed client // means the new client remains finalized. .is_manually_closed = this.client.flags.is_manually_closed, - .enable_offline_queue = if (this._subscription_ctx) |*ctx| ctx.original_enable_offline_queue else this.client.flags.enable_offline_queue, + .enable_offline_queue = if (this._subscription_ctx.is_subscriber) + this._subscription_ctx.original_enable_offline_queue + else + this.client.flags.enable_offline_queue, .needs_to_open_socket = true, .enable_auto_reconnect = this.client.flags.enable_auto_reconnect, .is_reconnecting = false, - .auto_pipelining = if (this._subscription_ctx) |*ctx| ctx.original_enable_auto_pipelining else this.client.flags.auto_pipelining, + .enable_auto_pipelining = if (this._subscription_ctx.is_subscriber) + this._subscription_ctx.original_enable_auto_pipelining + else + this.client.flags.enable_auto_pipelining, // Duplicating a finalized client means it stays finalized. 
.finalized = this.client.flags.finalized, }, @@ -466,7 +477,40 @@ pub const JSValkeyClient = struct { }); } - pub fn getOrCreateSubscriptionCtxEnteringSubscriptionMode( + pub fn addSubscription(this: *JSValkeyClient) void { + debug("addSubscription: entering, current subscriber state: {}", .{this._subscription_ctx.is_subscriber}); + bun.debugAssert(this.client.status == .connected); + this.ref(); + defer this.deref(); + + if (!this._subscription_ctx.is_subscriber) { + this._subscription_ctx.original_enable_offline_queue = this.client.flags.enable_offline_queue; + this._subscription_ctx.original_enable_auto_pipelining = this.client.flags.enable_auto_pipelining; + debug("addSubscription: calling updatePollRef", .{}); + this.updatePollRef(); + } + + this._subscription_ctx.is_subscriber = true; + debug("addSubscription: exiting, new subscriber state: {}", .{this._subscription_ctx.is_subscriber}); + } + + pub fn removeSubscription(this: *JSValkeyClient) void { + debug("removeSubscription: entering, has subscriptions: {}", .{this._subscription_ctx.hasSubscriptions(this.globalObject) catch false}); + this.ref(); + defer this.deref(); + + // This is the last subscription, restore original flags + if (!(this._subscription_ctx.hasSubscriptions(this.globalObject) catch false)) { + this.client.flags.enable_offline_queue = this._subscription_ctx.original_enable_offline_queue; + this.client.flags.enable_auto_pipelining = this._subscription_ctx.original_enable_auto_pipelining; + this._subscription_ctx.is_subscriber = false; + debug("removeSubscription: calling updatePollRef", .{}); + this.updatePollRef(); + } + debug("removeSubscription: exiting", .{}); + } + + pub fn getOrCreateSubscriptionCtx( this: *JSValkeyClient, ) bun.JSError!*SubscriptionCtx { if (this._subscription_ctx) |*ctx| { @@ -478,31 +522,22 @@ pub const JSValkeyClient = struct { this._subscription_ctx = try SubscriptionCtx.init( this, this.client.flags.enable_offline_queue, - this.client.flags.auto_pipelining, + this.client.flags.enable_auto_pipelining, ); - // We need to make sure we disable the offline queue. - this.client.flags.enable_offline_queue = false; - this.client.flags.auto_pipelining = false; + // We need to make sure we disable the offline queue, but we actually want to make sure that our HELLO message + // goes through first. Consequently, we only disable the offline queue if we're already connected. 
+ if (this.client.status == .connected) { + this.client.flags.enable_offline_queue = false; + } + + this.client.flags.enable_auto_pipelining = false; return &(this._subscription_ctx.?); } - pub fn deleteSubscriptionCtx(this: *JSValkeyClient) void { - if (this._subscription_ctx) |*ctx| { - // Restore the original flag values when leaving subscription mode - this.client.flags.enable_offline_queue = ctx.original_enable_offline_queue; - this.client.flags.auto_pipelining = ctx.original_enable_auto_pipelining; - - ctx.deinit(this.globalObject); - this._subscription_ctx = null; - } - - this._subscription_ctx = null; - } - pub fn isSubscriber(this: *const JSValkeyClient) bool { - return this._subscription_ctx != null; + return this._subscription_ctx.is_subscriber; } pub fn getConnected(this: *JSValkeyClient, _: *jsc.JSGlobalObject) JSValue { @@ -526,12 +561,10 @@ pub const JSValkeyClient = struct { // If already connected, resolve immediately if (this.client.status == .connected) { - debug("Connecting client is already connected.", .{}); return jsc.JSPromise.resolvedPromiseValue(globalObject, js.helloGetCached(this_value) orelse .js_undefined); } if (js.connectionPromiseGetCached(this_value)) |promise| { - debug("Connecting client is already connected.", .{}); return promise; } @@ -544,7 +577,6 @@ pub const JSValkeyClient = struct { defer this.updatePollRef(); if (this.client.flags.needs_to_open_socket) { - debug("Need to open socket, starting connection process.", .{}); this.poll_ref.ref(this.client.vm); this.connect() catch |err| { @@ -829,22 +861,6 @@ pub const JSValkeyClient = struct { bun.debugAssert(this.isSubscriber()); bun.debugAssert(this.this_value.isStrong()); - this.ref(); - defer this.deref(); - - var subscription_ctx = this._subscription_ctx.?; - - // Check if we have any remaining subscriptions - // If the callback map is empty, we can exit subscription mode - - // If fetching the subscription count fails, the best we can do is - // bubble the error up. - const has_subs = try subscription_ctx.hasSubscriptions(this.globalObject); - if (!has_subs) { - // No more subscriptions, exit subscription mode - this.deleteSubscriptionCtx(); - } - this.client.onWritable(); this.updatePollRef(); } @@ -876,11 +892,8 @@ pub const JSValkeyClient = struct { return; }; - // Get the subscription context - const subs_ctx = &this._subscription_ctx.?; - // Invoke callbacks for this channel with message and channel as arguments - subs_ctx.invokeCallbacks( + this._subscription_ctx.invokeCallbacks( globalObject, channel_value, &[_]JSValue{ message_value, channel_value }, @@ -1004,7 +1017,7 @@ pub const JSValkeyClient = struct { // garbage collected now and the subscription context holds a reference // to us. If we still had a subscription context, we would never be // garbage collected. - bun.debugAssert(this._subscription_ctx == null); + bun.debugAssert(!this._subscription_ctx.is_subscriber); } pub fn stopTimers(this: *JSValkeyClient) void { @@ -1025,7 +1038,6 @@ pub const JSValkeyClient = struct { } fn connect(this: *JSValkeyClient) !void { - debug("Connecting to Redis.", .{}); this.client.flags.needs_to_open_socket = false; const vm = this.client.vm; @@ -1103,12 +1115,12 @@ pub const JSValkeyClient = struct { } defer this.updatePollRef(); - return try this.client.send(globalThis, command); } // Getter for memory cost - useful for diagnostics pub fn memoryCost(this: *JSValkeyClient) usize { + // TODO(markovejnovic): This is most-likely wrong because I didn't know better. 
var memory_cost: usize = @sizeOf(JSValkeyClient); // Add size of all internal buffers @@ -1167,16 +1179,12 @@ pub const JSValkeyClient = struct { // should be treating valkey as a state machine, with well-defined // state and modes in which it tracks and manages its own lifecycle. // This is a mess beyond belief and it is incredibly fragile. - const has_pending_commands = this.client.hasAnyPendingCommands(); // isDeletable may throw an exception, and if it does, we have to assume // that the object still has references. Best we can do is hope nothing // catastrophic happens. - const subs_deletable: bool = if (this._subscription_ctx) |*ctx| - ctx.isDeletable(this.globalObject) catch false - else - true; + const subs_deletable: bool = !(this._subscription_ctx.hasSubscriptions(this.globalObject) catch false); const has_activity = has_pending_commands or !subs_deletable or this.client.flags.is_reconnecting; @@ -1192,7 +1200,6 @@ pub const JSValkeyClient = struct { } if (this.this_value.isEmpty()) { - debug("this_value is empty, skipping updatePollRef", .{}); return; } diff --git a/src/valkey/js_valkey_functions.zig b/src/valkey/js_valkey_functions.zig index 7afeb4a19d..06f6befd51 100644 --- a/src/valkey/js_valkey_functions.zig +++ b/src/valkey/js_valkey_functions.zig @@ -745,9 +745,6 @@ pub fn subscribe( return globalObject.throwInvalidArgumentType("subscribe", "listener", "function"); } - // We now need to register the callback with our subscription context, which may or may not exist. - var subscription_ctx = try this.getOrCreateSubscriptionCtxEnteringSubscriptionMode(); - // The first argument given is the channel or may be an array of channels. if (channel_or_many.isArray()) { if ((try channel_or_many.getLength(globalObject)) == 0) { @@ -762,7 +759,13 @@ pub fn subscribe( }; redis_channels.appendAssumeCapacity(channel); - try subscription_ctx.upsertReceiveHandler(globalObject, channel_arg, handler_callback); + // What we do here is add our receive handler. Notice that this doesn't really do anything until the + // "SUBSCRIBE" command is sent to redis and we get a response. + // + // TODO(markovejnovic): This is less-than-ideal, still, because this assumes a happy path. What happens if + // the SUBSCRIBE command fails? We have no way to roll back the addition of the + // handler. + try this._subscription_ctx.upsertReceiveHandler(globalObject, channel_arg, handler_callback); } } else if (channel_or_many.isString()) { // It is a single string channel @@ -771,7 +774,7 @@ pub fn subscribe( }; redis_channels.appendAssumeCapacity(channel); - try subscription_ctx.upsertReceiveHandler(globalObject, channel_or_many, handler_callback); + try this._subscription_ctx.upsertReceiveHandler(globalObject, channel_or_many, handler_callback); } else { return globalObject.throwInvalidArgumentType("subscribe", "channel", "string or array"); } @@ -779,14 +782,17 @@ pub fn subscribe( const command: valkey.Command = .{ .command = "SUBSCRIBE", .args = .{ .args = redis_channels.items }, + .meta = .{ + .subscription_request = true, + }, }; const promise = this.send( globalObject, callframe.this(), &command, ) catch |err| { - // If we find an error, we need to clean up the subscription context. 
- this.deleteSubscriptionCtx(); + // If we catch an error, we need to clean up any handlers we may have added and fall out of subscription mode + try this._subscription_ctx.clearAllReceiveHandlers(globalObject); return protocol.valkeyErrorToJS(globalObject, "Failed to send SUBSCRIBE command", err); }; @@ -815,9 +821,6 @@ fn sendUnsubscribeRequestAndCleanup( return protocol.valkeyErrorToJS(globalObject, "Failed to send UNSUBSCRIBE command", err); }; - // We do not delete the subscription context here, but rather when the - // onValkeyUnsubscribe callback is invoked. - return promise.toJS(); } @@ -842,6 +845,7 @@ pub fn unsubscribe( // If no arguments, unsubscribe from all channels if (args_view.len == 0) { + try this._subscription_ctx.clearAllReceiveHandlers(globalObject); return try sendUnsubscribeRequestAndCleanup(this, callframe.this(), globalObject, redis_channels.items); } @@ -849,9 +853,9 @@ pub fn unsubscribe( const channel_or_many = callframe.argument(0); // Get the subscription context - var subscription_ctx = this._subscription_ctx orelse { + if (!this._subscription_ctx.is_subscriber) { return jsc.JSPromise.resolvedPromiseValue(globalObject, .js_undefined); - }; + } // Two arguments means .unsubscribe(channel, listener) is invoked. if (callframe.arguments().len == 2) { @@ -884,7 +888,7 @@ pub fn unsubscribe( return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string"); }); - const remaining_listeners = subscription_ctx.removeReceiveHandler( + const remaining_listeners = this._subscription_ctx.removeReceiveHandler( globalObject, channel, listener_cb, @@ -926,7 +930,7 @@ pub fn unsubscribe( }; redis_channels.appendAssumeCapacity(channel); // Clear the handlers for this channel - try subscription_ctx.clearReceiveHandlers(globalObject, channel_arg); + try this._subscription_ctx.clearReceiveHandlers(globalObject, channel_arg); } } else if (channel_or_many.isString()) { // It is a single string channel @@ -935,7 +939,7 @@ pub fn unsubscribe( }; redis_channels.appendAssumeCapacity(channel); // Clear the handlers for this channel - try subscription_ctx.clearReceiveHandlers(globalObject, channel_or_many); + try this._subscription_ctx.clearReceiveHandlers(globalObject, channel_or_many); } else { return globalObject.throwInvalidArgumentType("unsubscribe", "channel", "string or array"); } @@ -954,8 +958,8 @@ pub fn duplicate( var new_client: *JSValkeyClient = try this.cloneWithoutConnecting(globalObject); const new_client_js = new_client.toJS(globalObject); - new_client.this_value = jsc.JSRef.initWeak(new_client_js); + new_client._subscription_ctx = try SubscriptionCtx.init(new_client); // If the original client is already connected and not manually closed, start connecting the new client. 
if (this.client.status == .connected and !this.client.flags.is_manually_closed) { // Use strong reference during connection to prevent premature GC @@ -1212,7 +1216,9 @@ fn fromJS(globalObject: *jsc.JSGlobalObject, value: JSValue) !?JSArgument { const bun = @import("bun"); const std = @import("std"); + const JSValkeyClient = @import("./js_valkey.zig").JSValkeyClient; +const SubscriptionCtx = @import("./js_valkey.zig").SubscriptionCtx; const jsc = bun.jsc; const JSValue = jsc.JSValue; diff --git a/src/valkey/valkey.zig b/src/valkey/valkey.zig index ad23100e2c..0b31decdc2 100644 --- a/src/valkey/valkey.zig +++ b/src/valkey/valkey.zig @@ -16,7 +16,7 @@ pub const ConnectionFlags = struct { needs_to_open_socket: bool = true, enable_auto_reconnect: bool = true, is_reconnecting: bool = false, - auto_pipelining: bool = true, + enable_auto_pipelining: bool = true, finalized: bool = false, // This flag is a slight hack to allow returning the client instance in the // promise which resolves when the connection is established. There are two @@ -388,7 +388,8 @@ pub const ValkeyClient = struct { if (wrote > 0) { this.write_buffer.consume(@intCast(wrote)); } - return this.write_buffer.len() > 0; + const has_remaining = this.write_buffer.len() > 0; + return has_remaining; } const DeferredFailure = struct { @@ -543,6 +544,7 @@ pub const ValkeyClient = struct { /// /// Caller refs / derefs. pub fn onData(this: *ValkeyClient, data: []const u8) void { + debug("Low-level onData called with {d} bytes: {s}", .{ data.len, data }); // Path 1: Buffer already has data, append and process from buffer if (this.read_buffer.remaining().len > 0) { this.read_buffer.write(this.allocator, data) catch @panic("failed to write to read buffer"); @@ -651,9 +653,12 @@ pub const ValkeyClient = struct { /// Try handling this response as a subscriber-state response. /// Returns `handled` if we handled it, `fallthrough` if we did not. 
- fn handleSubscribeResponse(this: *ValkeyClient, value: *protocol.RESPValue, pair: *ValkeyCommand.PromisePair) bun.JSError!enum { handled, fallthrough } { + fn handleSubscribeResponse( + this: *ValkeyClient, + value: *protocol.RESPValue, + pair: ?*ValkeyCommand.PromisePair, + ) bun.JSError!enum { handled, fallthrough } { // Resolve the promise with the potentially transformed value - var promise_ptr = &pair.promise; const globalThis = this.globalObject(); const loop = this.vm.eventLoop(); @@ -663,35 +668,43 @@ pub const ValkeyClient = struct { return switch (value.*) { .Error => { - promise_ptr.reject(globalThis, value.toJS(globalThis)); + if (pair) |p| { + p.promise.reject(globalThis, value.toJS(globalThis)); + } return .handled; }, .Push => |push| { const p = this.parent(); - const subs_ctx = try p.getOrCreateSubscriptionCtxEnteringSubscriptionMode(); - const sub_count = try subs_ctx.channelsSubscribedToCount(globalThis); + const sub_count = try p._subscription_ctx.channelsSubscribedToCount(globalThis); if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| { switch (msg_type) { + .message => { + this.onValkeyMessage(push.data); + return .handled; + }, .subscribe => { + p.addSubscription(); this.onValkeySubscribe(value); - promise_ptr.promise.resolve(globalThis, .jsNumber(sub_count)); + + // For SUBSCRIBE responses, only resolve the promise for the first channel confirmation + // Additional channel confirmations from multi-channel SUBSCRIBE commands don't need promise pairs + if (pair) |req_pair| { + req_pair.promise.promise.resolve(globalThis, .jsNumber(sub_count)); + } return .handled; }, .unsubscribe => { try this.onValkeyUnsubscribe(); - promise_ptr.promise.resolve(globalThis, .js_undefined); + p.removeSubscription(); + + // For UNSUBSCRIBE responses, only resolve the promise if we have one + // Additional channel confirmations from multi-channel UNSUBSCRIBE commands don't need promise pairs + if (pair) |req_pair| { + req_pair.promise.promise.resolve(globalThis, .js_undefined); + } return .handled; }, - else => { - // Other push messages (message, pmessage, etc) are not handled here - @branchHint(.cold); - this.fail( - "Push message is not a subscription message.", - protocol.RedisError.InvalidResponseType, - ); - return .fallthrough; - }, } } else { // We should rarely reach this point. If we're guaranteed to be handling a subscribe/unsubscribe, @@ -769,7 +782,6 @@ pub const ValkeyClient = struct { /// Handle Valkey protocol response fn handleResponse(this: *ValkeyClient, value: *protocol.RESPValue) !void { - debug("onData() {any}", .{value.*}); // Special handling for the initial HELLO response if (!this.flags.is_authenticated) { this.handleHelloResponse(value); @@ -804,26 +816,47 @@ pub const ValkeyClient = struct { }, }; } - // Let's load the promise pair. - var pair_maybe = this.in_flight.readItem(); + // Check if this is a subscription push message that might not need a promise pair + var should_consume_promise_pair = true; + var pair_maybe: ?ValkeyCommand.PromisePair = null; - // We handle subscriptions specially because they are not regular - // commands and their failure will potentially cause the client to drop - // out of subscriber mode. + // For subscription clients, check if this is a push message that doesn't need a promise pair if (this.parent().isSubscriber()) { - debug("This client is a subscriber. Handling as subscriber...", .{}); - - // There are multiple different commands we may receive in - // subscriber mode. 
One is from a client.subscribe() call which
- // requires that a promise is in-flight, but otherwise, we may also
- // receive push messages from the server that do not have an
- // associated promise.
- if (pair_maybe) |*pair| {
- debug("There is a request in flight. Handling as a subscribe request...", .{});
- if ((try this.handleSubscribeResponse(value, pair)) == .handled) {
- return;
- }
+ switch (value.*) {
+ .Push => |push| {
+ if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| {
+ switch (msg_type) {
+ .message => {
+ // Message pushes never need promise pairs
+ should_consume_promise_pair = false;
+ },
+ .subscribe, .unsubscribe => {
+ // Subscribe/unsubscribe pushes only need promise pairs if we have pending commands
+ if (this.in_flight.readableLength() == 0) {
+ should_consume_promise_pair = false;
+ }
+ },
+ }
+ }
+ },
+ else => {},
 }
+ }
+
+ // Only consume a promise pair if we determined we need one.
+ // The reason we consume pairs is that a SUBSCRIBE message may actually be followed by a number of SUBSCRIBE
+ // responses which indicate all the channels we have connected to. As a stop-gap, we currently ignore the
+ // actual content of the SUBSCRIBE responses and just resolve the first one with the count of channels.
+ // TODO(markovejnovic): Do better.
+ if (should_consume_promise_pair) {
+ pair_maybe = this.in_flight.readItem();
+ }
+
+ // We handle subscriptions specially because they are not regular commands and their failure will potentially
+ // cause the client to drop out of subscriber mode.
+ const request_is_subscribe = if (pair_maybe) |p| p.meta.subscription_request else false;
+ if (this.parent().isSubscriber() or request_is_subscribe) {
+ debug("This client is a subscriber. Handling as subscriber...", .{});

 switch (value.*) {
 .Error => |err| {
@@ -831,19 +864,9 @@
 return;
 },
 .Push => |push| {
- if (protocol.SubscriptionPushMessage.map.get(push.kind)) |msg_type| {
- switch (msg_type) {
- .message => {
- @branchHint(.likely);
- debug("Received a message.", .{});
- this.onValkeyMessage(push.data);
- return;
- },
- else => {
- @branchHint(.cold);
- debug("Received non-message push without promise: {any}", .{push.data});
- return;
- },
+ if (protocol.SubscriptionPushMessage.map.get(push.kind)) |_| {
+ if ((try this.handleSubscribeResponse(value, if (pair_maybe) |*pm| pm else null)) == .handled) {
+ return;
+ }
 }
 } else {
 @branchHint(.cold);
@@ -866,7 +889,6 @@

 // For regular commands, get the next command+promise pair from the queue
 var pair = pair_maybe orelse {
- debug("Received response but no promise in queue", .{});
 return;
 };

@@ -984,7 +1006,9 @@
 }
 }

- const offline_cmd = this.queue.readItem() orelse return false;
+ const offline_cmd = this.queue.readItem() orelse {
+ return false;
+ };

 // Add the promise to the command queue first
 this.in_flight.writeItem(.{
@@ -1023,7 +1047,7 @@
 }

 fn enqueue(this: *ValkeyClient, command: *const Command, promise: *Command.Promise) !void {
- const can_pipeline = command.meta.supports_auto_pipelining and this.flags.auto_pipelining;
+ const can_pipeline = command.meta.supports_auto_pipelining and this.flags.enable_auto_pipelining;

 // For commands that don't support pipelining, we need to wait for the queue to drain completely
 // before sending the command. This ensures proper order of execution for state-changing commands.
@@ -1042,7 +1066,8 @@ pub const ValkeyClient = struct { can_pipeline) { // We serialize the bytes in here, so we don't need to worry about the lifetime of the Command itself. - try this.queue.writeItem(try Command.Entry.create(this.allocator, command, promise.*)); + const entry = try Command.Entry.create(this.allocator, command, promise.*); + try this.queue.writeItem(entry); // If we're connected and using auto pipelining, schedule a flush if (this.status == .connected and can_pipeline) { @@ -1053,9 +1078,11 @@ pub const ValkeyClient = struct { } switch (this.status) { - .connecting, .connected => command.write(this.writer()) catch { - promise.reject(this.globalObject(), this.globalObject().createOutOfMemoryError()); - return; + .connecting, .connected => { + command.write(this.writer()) catch { + promise.reject(this.globalObject(), this.globalObject().createOutOfMemoryError()); + return; + }; }, else => unreachable, } @@ -1072,17 +1099,21 @@ pub const ValkeyClient = struct { } pub fn send(this: *ValkeyClient, globalThis: *jsc.JSGlobalObject, command: *const Command) !*jsc.JSPromise { - var promise = Command.Promise.create(globalThis, command.meta); + // FIX: Check meta before using it for routing decisions + var checked_command = command.*; + checked_command.meta = command.meta.check(command); + + var promise = Command.Promise.create(globalThis, checked_command.meta); const js_promise = promise.promise.get(); // Handle disconnected state with offline queue switch (this.status) { .connected => { - try this.enqueue(command, &promise); + try this.enqueue(&checked_command, &promise); // Schedule auto-flushing to process this command if pipelining is enabled - if (this.flags.auto_pipelining and - command.meta.supports_auto_pipelining and + if (this.flags.enable_auto_pipelining and + checked_command.meta.supports_auto_pipelining and this.status == .connected and this.queue.readableLength() > 0) { @@ -1092,7 +1123,7 @@ pub const ValkeyClient = struct { .connecting, .disconnected => { // Only queue if offline queue is enabled if (this.flags.enable_offline_queue) { - try this.enqueue(command, &promise); + try this.enqueue(&checked_command, &promise); } else { promise.reject( globalThis, From 73feb108d9fdcacaaf048e7d55400ded72c65f24 Mon Sep 17 00:00:00 2001 From: "taylor.fish" Date: Fri, 26 Sep 2025 23:02:44 -0700 Subject: [PATCH 43/43] Handle optionals in `bun.memory.deinit` (#23027) Currently, if you try to deinit an optional, `bun.memory.deinit` will silently do nothing, even if the optional's payload is a struct with a `deinit` method. This commit makes sure the payload is deinitialized. (For internal tracking: fixes STAB-1293) --- src/memory.zig | 53 +++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/src/memory.zig b/src/memory.zig index 6f41f9fed9..cd3d076bd9 100644 --- a/src/memory.zig +++ b/src/memory.zig @@ -65,15 +65,17 @@ fn deinitIsVoid(comptime T: type) bool { }; } -/// Calls `deinit` on `ptr_or_slice`, or on every element of `ptr_or_slice`, if such a `deinit` -/// method exists. +/// Calls `deinit` on `ptr_or_slice`, or on every element of `ptr_or_slice`, if the pointer points +/// to a struct or tagged union. /// /// This function first does the following: /// -/// * If `ptr_or_slice` is a single-item pointer, calls `ptr_or_slice.deinit()`, if that method -/// exists. -/// * If `ptr_or_slice` is a slice, calls `deinit` on every element of the slice, if the slice -/// elements have a `deinit` method. 
+/// * If `ptr_or_slice` is a single-item pointer of type `*T`: +/// - If `T` is a struct or tagged union, calls `ptr_or_slice.deinit()` +/// - If `T` is an optional, checks if `ptr_or_slice` points to a non-null value, and if so, +/// calls `bun.memory.deinit` with a pointer to the payload. +/// * If `ptr_or_slice` is a slice, for each element of the slice, calls `bun.memory.deinit` with +/// a pointer to the element. /// /// Then, if `ptr_or_slice` is non-const, this function also sets all memory referenced by the /// pointer to `undefined`. @@ -81,32 +83,43 @@ fn deinitIsVoid(comptime T: type) bool { /// This method does not free `ptr_or_slice` itself. pub fn deinit(ptr_or_slice: anytype) void { const ptr_info = @typeInfo(@TypeOf(ptr_or_slice)); + switch (comptime ptr_info.pointer.size) { + .slice => { + for (ptr_or_slice) |*elem| { + deinit(elem); + } + return; + }, + .one => {}, + else => @compileError("unsupported pointer type"), + } + const Child = ptr_info.pointer.child; const mutable = !ptr_info.pointer.is_const; + defer { + if (comptime mutable) { + ptr_or_slice.* = undefined; + } + } const needs_deinit = comptime switch (@typeInfo(Child)) { .@"struct" => true, .@"union" => |u| u.tag_type != null, + .optional => { + if (ptr_or_slice.*) |*payload| { + deinit(payload); + } + return; + }, else => false, }; + const should_call_deinit = comptime needs_deinit and !exemptedFromDeinit(Child) and !deinitIsVoid(Child); - switch (comptime ptr_info.pointer.size) { - .one => { - if (comptime should_call_deinit) { - ptr_or_slice.deinit(); - } - if (comptime mutable) ptr_or_slice.* = undefined; - }, - .slice => for (ptr_or_slice) |*elem| { - if (comptime should_call_deinit) { - elem.deinit(); - } - if (comptime mutable) elem.* = undefined; - }, - else => @compileError("unsupported pointer type"), + if (comptime should_call_deinit) { + ptr_or_slice.deinit(); } }
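
To make the optional-handling rule in this last patch concrete, here is a minimal, self-contained Zig sketch of the described behavior. It is not Bun's actual `bun.memory.deinit` (which also handles slices, deinit-exempt types, and `void`-typed `deinit` declarations); the names `deinitRecursive` and `Buffer` are illustrative only.

```zig
const std = @import("std");

/// Simplified sketch: recurse through optionals so a non-null payload's
/// `deinit` runs instead of being silently skipped.
fn deinitRecursive(ptr: anytype) void {
    const Child = @typeInfo(@TypeOf(ptr)).pointer.child;
    switch (@typeInfo(Child)) {
        // The fix described above: unwrap the optional and deinit the payload.
        .optional => if (ptr.*) |*payload| deinitRecursive(payload),
        // Call `deinit` only when the type actually declares one.
        .@"struct", .@"union" => if (comptime @hasDecl(Child, "deinit")) ptr.deinit(),
        else => {},
    }
    // Poison the memory afterwards, mirroring the `undefined` write in the patch.
    ptr.* = undefined;
}

/// Illustrative resource type owning an allocation.
const Buffer = struct {
    data: []u8,
    allocator: std.mem.Allocator,

    fn deinit(self: *Buffer) void {
        self.allocator.free(self.data);
    }
};

test "an optional's payload is deinitialized" {
    var maybe: ?Buffer = .{
        .data = try std.testing.allocator.alloc(u8, 16),
        .allocator = std.testing.allocator,
    };
    // Before the fix, a helper that only handled structs would leak here;
    // with optional handling, `Buffer.deinit` runs and the allocation is freed.
    deinitRecursive(&maybe);
}
```

Writing `undefined` after the payload is deinitialized follows the same design choice as the patch: it costs nothing in release builds but helps debug builds catch use-after-deinit.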