diff --git a/packages/bun-types/bun-test.d.ts b/packages/bun-types/bun-test.d.ts index 57814d2aae..2a49c6b77e 100644 --- a/packages/bun-types/bun-test.d.ts +++ b/packages/bun-types/bun-test.d.ts @@ -27,6 +27,38 @@ declare module "bun:test" { export const mock: { (Function: T): Mock; + + /** + * Replace the module `id` with the return value of `factory`. + * + * This is useful for mocking modules. + * + * @param id module ID to mock + * @param factory a function returning an object that will be used as the exports of the mocked module + * + * @example + * ## Example + * ```ts + * import { mock } from "bun:test"; + * + * mock.module("fs/promises", () => { + * return { + * readFile: () => Promise.resolve("hello world"), + * }; + * }); + * + * import { readFile } from "fs/promises"; + * + * console.log(await readFile("hello.txt", "utf8")); // hello world + * ``` + * + * ## More notes + * + * If the module is already loaded, exports are overwritten with the return + * value of `factory`. If the export didn't exist before, it will not be + * added to existing import statements. This is due to how ESM works. + */ + module(id: string, factory: () => any): void | Promise; }; /** diff --git a/src/bun.js/bindings/BunPlugin.cpp b/src/bun.js/bindings/BunPlugin.cpp index e0527c7241..0cd47b3c93 100644 --- a/src/bun.js/bindings/BunPlugin.cpp +++ b/src/bun.js/bindings/BunPlugin.cpp @@ -22,6 +22,11 @@ #include "JavaScriptCore/RegularExpression.h" #include "JavaScriptCore/JSMap.h" #include "JavaScriptCore/JSMapInlines.h" +#include "JavaScriptCore/JSModuleRecord.h" +#include "JavaScriptCore/JSModuleNamespaceObject.h" +#include "JavaScriptCore/SourceOrigin.h" +#include "JavaScriptCore/JSModuleLoader.h" +#include "CommonJSModuleRecord.h" namespace Zig { @@ -403,6 +408,277 @@ JSFunction* BunPlugin::Group::find(JSC::JSGlobalObject* globalObject, String& pa return nullptr; } +void BunPlugin::OnLoad::addModuleMock(JSC::VM& vm, const String& path, JSC::JSObject* mockObject) +{ + Zig::GlobalObject* globalObject = Zig::jsCast(mockObject->globalObject()); + + if (globalObject->onLoadPlugins.virtualModules == nullptr) { + globalObject->onLoadPlugins.virtualModules = new BunPlugin::VirtualModuleMap; + } + auto* virtualModules = globalObject->onLoadPlugins.virtualModules; + + virtualModules->set(path, JSC::Strong { vm, mockObject }); +} + +class JSModuleMock final : public JSC::JSNonFinalObject { +public: + using Base = JSC::JSNonFinalObject; + + mutable WriteBarrier callbackFunctionOrCachedResult; + bool hasCalledModuleMock = false; + + static JSModuleMock* create(JSC::VM& vm, JSC::Structure* structure, JSC::JSObject* callback); + static Structure* createStructure(JSC::VM& vm, JSC::JSGlobalObject* globalObject, JSC::JSValue prototype); + + DECLARE_INFO; + DECLARE_VISIT_CHILDREN; + + JSObject* executeOnce(JSC::JSGlobalObject* lexicalGlobalObject); + + template static JSC::GCClient::IsoSubspace* subspaceFor(JSC::VM& vm) + { + if constexpr (mode == JSC::SubspaceAccess::Concurrently) + return nullptr; + return WebCore::subspaceForImpl( + vm, + [](auto& spaces) { return spaces.m_clientSubspaceForJSModuleMock.get(); }, + [](auto& spaces, auto&& space) { spaces.m_clientSubspaceForJSModuleMock = std::forward(space); }, + [](auto& spaces) { return spaces.m_subspaceForJSModuleMock.get(); }, + [](auto& spaces, auto&& space) { spaces.m_subspaceForJSModuleMock = std::forward(space); }); + } + + void finishCreation(JSC::VM&, JSC::JSObject* callback); + +private: + JSModuleMock(JSC::VM&, JSC::Structure*); +}; + +const JSC::ClassInfo JSModuleMock::s_info = { "ModuleMock"_s, &Base::s_info, nullptr, nullptr, CREATE_METHOD_TABLE(JSModuleMock) }; + +JSModuleMock* JSModuleMock::create(JSC::VM& vm, JSC::Structure* structure, JSC::JSObject* callback) +{ + JSModuleMock* ptr = new (NotNull, JSC::allocateCell(vm)) JSModuleMock(vm, structure); + ptr->finishCreation(vm, callback); + return ptr; +} + +void JSModuleMock::finishCreation(JSC::VM& vm, JSObject* callback) +{ + Base::finishCreation(vm); + callbackFunctionOrCachedResult.set(vm, this, callback); +} + +JSModuleMock::JSModuleMock(JSC::VM& vm, JSC::Structure* structure) + : Base(vm, structure) +{ +} + +Structure* JSModuleMock::createStructure(JSC::VM& vm, JSC::JSGlobalObject* globalObject, JSC::JSValue prototype) +{ + return Structure::create(vm, globalObject, prototype, JSC::TypeInfo(JSC::ObjectType, StructureFlags), info()); +} + +JSObject* JSModuleMock::executeOnce(JSC::JSGlobalObject* lexicalGlobalObject) +{ + auto& vm = lexicalGlobalObject->vm(); + auto scope = DECLARE_THROW_SCOPE(vm); + + if (hasCalledModuleMock) { + return callbackFunctionOrCachedResult.get(); + } + + hasCalledModuleMock = true; + + if (!callbackFunctionOrCachedResult) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "Cannot call mock without a callback"_s)); + return nullptr; + } + + JSC::JSValue callbackValue = callbackFunctionOrCachedResult.get(); + if (!callbackValue.isCell() || !callbackValue.isCallable()) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "mock(module, fn) requires a function"_s)); + return nullptr; + } + + JSObject* callback = callbackValue.getObject(); + JSC::JSValue result = JSC::call(lexicalGlobalObject, callback, JSC::getCallData(callback), JSC::jsUndefined(), ArgList()); + RETURN_IF_EXCEPTION(scope, {}); + + if (!result.isObject()) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "mock(module, fn) requires a function that returns an object"_s)); + return nullptr; + } + + auto* object = result.getObject(); + this->callbackFunctionOrCachedResult.set(vm, this, object); + + return object; +} + +extern "C" EncodedJSValue JSMock__jsModuleMock(JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callframe) +{ + JSC::VM& vm = lexicalGlobalObject->vm(); + Zig::GlobalObject* globalObject = jsDynamicCast(lexicalGlobalObject); + auto scope = DECLARE_THROW_SCOPE(vm); + if (UNLIKELY(!globalObject)) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "Cannot run mock from a different global context"_s)); + return {}; + } + + if (callframe->argumentCount() < 1) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "mock(module, fn) requires a module and function"_s)); + return {}; + } + + JSC::JSString* specifierString = callframe->argument(0).toString(globalObject); + RETURN_IF_EXCEPTION(scope, {}); + + WTF::String specifier = specifierString->value(globalObject); + + if (specifier.isEmpty()) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "mock(module, fn) requires a module and function"_s)); + return {}; + } + + if (specifier.startsWith("./"_s) || specifier.startsWith("../"_s) || specifier == "."_s) { + JSC::SourceOrigin sourceOrigin = callframe->callerSourceOrigin(vm); + const URL& url = sourceOrigin.url(); + if (url.protocolIsFile()) { + URL joinedURL = URL(url, specifier); + specifier = joinedURL.fileSystemPath(); + specifierString = jsString(vm, specifier); + } else { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "mock(module, fn) cannot mock relative paths in non-files"_s)); + return {}; + } + } + + JSC::JSValue callbackValue = callframe->argument(1); + if (!callbackValue.isCell() || !callbackValue.isCallable()) { + scope.throwException(lexicalGlobalObject, JSC::createTypeError(lexicalGlobalObject, "mock(module, fn) requires a function"_s)); + return {}; + } + + JSC::JSObject* callback = callbackValue.getObject(); + + JSModuleMock* mock = JSModuleMock::create(vm, globalObject->mockModule.mockModuleStructure.getInitializedOnMainThread(globalObject), callback); + + auto* esm = globalObject->esmRegistryMap(); + + auto getJSValue = [&]() -> JSValue { + auto scope = DECLARE_THROW_SCOPE(vm); + JSValue result = mock->executeOnce(globalObject); + RETURN_IF_EXCEPTION(scope, JSValue()); + + if (result && result.isObject()) { + while (JSC::JSPromise* promise = jsDynamicCast(result)) { + switch (promise->status(vm)) { + case JSC::JSPromise::Status::Rejected: { + result = promise->result(vm); + scope.throwException(globalObject, result); + return {}; + break; + } + case JSC::JSPromise::Status::Fulfilled: { + result = promise->result(vm); + break; + } + // TODO: blocking wait for promise + default: { + break; + } + } + } + } + + return result; + }; + + bool removeFromESM = false; + bool removeFromCJS = false; + + if (JSValue entryValue = esm->get(globalObject, specifierString)) { + removeFromESM = true; + + if (entryValue.isObject()) { + JSObject* entry = entryValue.getObject(); + if (JSValue moduleValue = entry->getIfPropertyExists(globalObject, Identifier::fromString(vm, String("module"_s)))) { + if (auto* mod = jsDynamicCast(moduleValue)) { + JSC::JSModuleNamespaceObject* moduleNamespaceObject = mod->getModuleNamespace(globalObject); + JSValue exportsValue = getJSValue(); + RETURN_IF_EXCEPTION(scope, {}); + removeFromESM = false; + + if (exportsValue.isObject()) { + // TODO: use fast path for property iteration + auto* object = exportsValue.getObject(); + JSC::PropertyNameArray names(vm, PropertyNameMode::Strings, PrivateSymbolMode::Exclude); + JSObject::getOwnPropertyNames(object, globalObject, names, DontEnumPropertiesMode::Exclude); + RETURN_IF_EXCEPTION(scope, {}); + + for (auto& name : names) { + // consistent with regular esm handling code + auto catchScope = DECLARE_CATCH_SCOPE(vm); + JSValue value = object->get(globalObject, name); + if (scope.exception()) { + scope.clearException(); + value = jsUndefined(); + } + moduleNamespaceObject->overrideExportValue(globalObject, name, value); + } + + } else { + // if it's not an object, I guess we just set the default export? + moduleNamespaceObject->overrideExportValue(globalObject, vm.propertyNames->defaultKeyword, exportsValue); + } + + RETURN_IF_EXCEPTION(scope, {}); + + // TODO: do we need to handle intermediate loading state here? + // entry->putDirect(vm, Identifier::fromString(vm, String("evaluated"_s)), jsBoolean(true), 0); + // entry->putDirect(vm, Identifier::fromString(vm, String("state"_s)), jsNumber(JSC::JSModuleLoader::Status::Ready), 0); + } + } + } + } + + if (auto entryValue = globalObject->requireMap()->get(globalObject, specifierString)) { + removeFromCJS = true; + if (auto* moduleObject = jsDynamicCast(entryValue)) { + JSValue exportsValue = getJSValue(); + RETURN_IF_EXCEPTION(scope, {}); + + moduleObject->putDirect(vm, Bun::builtinNames(vm).exportsPublicName(), exportsValue, 0); + moduleObject->hasEvaluated = true; + removeFromCJS = false; + } + } + + if (removeFromESM) { + esm->remove(globalObject, specifierString); + } + + if (removeFromCJS) { + globalObject->requireMap()->remove(globalObject, specifierString); + } + + globalObject->onLoadPlugins.addModuleMock(vm, specifier, mock); + + return JSValue::encode(jsUndefined()); +} + +template +void JSModuleMock::visitChildrenImpl(JSCell* cell, Visitor& visitor) +{ + JSModuleMock* mock = jsCast(cell); + ASSERT_GC_OBJECT_INHERITS(mock, info()); + Base::visitChildren(mock, visitor); + + visitor.append(mock->callbackFunctionOrCachedResult); +} + +DEFINE_VISIT_CHILDREN(JSModuleMock); + EncodedJSValue BunPlugin::OnLoad::run(JSC::JSGlobalObject* globalObject, BunString* namespaceString, BunString* path) { Group* groupPtr = this->group(namespaceString ? Bun::toWTFString(*namespaceString) : String()); @@ -559,7 +835,13 @@ extern "C" JSC::EncodedJSValue Bun__runOnLoadPlugins(Zig::GlobalObject* globalOb } namespace Bun { -JSC::JSValue runVirtualModule(Zig::GlobalObject* globalObject, BunString* specifier) + +Structure* createModuleMockStructure(JSC::VM& vm, JSC::JSGlobalObject* globalObject, JSC::JSValue prototype) +{ + return Zig::JSModuleMock::createStructure(vm, globalObject, prototype); +} + +JSC::JSValue runVirtualModule(Zig::GlobalObject* globalObject, BunString* specifier, bool& wasModuleMock) { auto fallback = [&]() -> JSC::JSValue { return JSValue::decode(Bun__runVirtualModule(globalObject, specifier)); @@ -570,16 +852,27 @@ JSC::JSValue runVirtualModule(Zig::GlobalObject* globalObject, BunString* specif } auto& virtualModules = *globalObject->onLoadPlugins.virtualModules; WTF::String specifierString = Bun::toWTFString(*specifier); + if (auto virtualModuleFn = virtualModules.get(specifierString)) { auto& vm = globalObject->vm(); JSC::JSObject* function = virtualModuleFn.get(); auto throwScope = DECLARE_THROW_SCOPE(vm); - JSC::MarkedArgumentBuffer arguments; - JSC::CallData callData = JSC::getCallData(function); - RELEASE_ASSERT(callData.type != JSC::CallData::Type::None); + JSValue result; + + if (Zig::JSModuleMock* moduleMock = jsDynamicCast(function)) { + wasModuleMock = true; + // module mock + result = moduleMock->executeOnce(globalObject); + } else { + // regular function + JSC::MarkedArgumentBuffer arguments; + JSC::CallData callData = JSC::getCallData(function); + RELEASE_ASSERT(callData.type != JSC::CallData::Type::None); + + result = call(globalObject, function, callData, JSC::jsUndefined(), arguments); + } - auto result = call(globalObject, function, callData, JSC::jsUndefined(), arguments); RETURN_IF_EXCEPTION(throwScope, JSC::jsUndefined()); if (auto* promise = JSC::jsDynamicCast(result)) { diff --git a/src/bun.js/bindings/BunPlugin.h b/src/bun.js/bindings/BunPlugin.h index f4d09883d1..162e70dc7a 100644 --- a/src/bun.js/bindings/BunPlugin.h +++ b/src/bun.js/bindings/BunPlugin.h @@ -72,6 +72,8 @@ public: VirtualModuleMap* virtualModules = nullptr; JSC::EncodedJSValue run(JSC::JSGlobalObject* globalObject, BunString* namespaceString, BunString* path); + void addModuleMock(JSC::VM& vm, const String& path, JSC::JSObject* mock); + ~OnLoad() { if (virtualModules) { @@ -97,5 +99,6 @@ class GlobalObject; } // namespace Zig namespace Bun { -JSC::JSValue runVirtualModule(Zig::GlobalObject*, BunString* specifier); +JSC::JSValue runVirtualModule(Zig::GlobalObject*, BunString* specifier, bool& wasModuleMock); +JSC::Structure* createModuleMockStructure(JSC::VM& vm, JSC::JSGlobalObject* globalObject, JSC::JSValue prototype); } \ No newline at end of file diff --git a/src/bun.js/bindings/CommonJSModuleRecord.cpp b/src/bun.js/bindings/CommonJSModuleRecord.cpp index f47c575590..8e1da59c42 100644 --- a/src/bun.js/bindings/CommonJSModuleRecord.cpp +++ b/src/bun.js/bindings/CommonJSModuleRecord.cpp @@ -645,14 +645,15 @@ bool JSCommonJSModule::evaluate( RELEASE_AND_RETURN(throwScope, true); } -void JSCommonJSModule::toSyntheticSource(JSC::JSGlobalObject* globalObject, - JSC::Identifier moduleKey, +void populateESMExports( + JSC::JSGlobalObject* globalObject, + JSValue result, Vector& exportNames, - JSC::MarkedArgumentBuffer& exportValues) + JSC::MarkedArgumentBuffer& exportValues, + bool ignoreESModuleAnnotation) { - auto result = this->exportsObject(); - auto& vm = globalObject->vm(); + Identifier esModuleMarker = builtinNames(vm).__esModulePublicName(); // Bun's intepretation of the "__esModule" annotation: // @@ -686,6 +687,7 @@ void JSCommonJSModule::toSyntheticSource(JSC::JSGlobalObject* globalObject, if (result.isObject()) { auto* exports = result.getObject(); + bool hasESModuleMarker = !ignoreESModuleAnnotation && exports->hasProperty(globalObject, esModuleMarker); auto* structure = exports->structure(); uint32_t size = structure->inlineSize() + structure->outOfLineSize(); @@ -694,8 +696,6 @@ void JSCommonJSModule::toSyntheticSource(JSC::JSGlobalObject* globalObject, auto catchScope = DECLARE_CATCH_SCOPE(vm); - Identifier esModuleMarker = builtinNames(vm).__esModulePublicName(); - bool hasESModuleMarker = !this->ignoreESModuleAnnotation && exports->hasProperty(globalObject, esModuleMarker); if (catchScope.exception()) { catchScope.clearException(); } @@ -805,6 +805,18 @@ void JSCommonJSModule::toSyntheticSource(JSC::JSGlobalObject* globalObject, } } +void JSCommonJSModule::toSyntheticSource(JSC::JSGlobalObject* globalObject, + JSC::Identifier moduleKey, + Vector& exportNames, + JSC::MarkedArgumentBuffer& exportValues) +{ + auto result = this->exportsObject(); + + auto& vm = globalObject->vm(); + Identifier esModuleMarker = builtinNames(vm).__esModulePublicName(); + populateESMExports(globalObject, result, exportNames, exportValues, this->ignoreESModuleAnnotation); +} + JSValue JSCommonJSModule::exportsObject() { return this->get(globalObject(), JSC::PropertyName(clientData(vm())->builtinNames().exportsPublicName())); diff --git a/src/bun.js/bindings/CommonJSModuleRecord.h b/src/bun.js/bindings/CommonJSModuleRecord.h index 37353978e3..219174a14f 100644 --- a/src/bun.js/bindings/CommonJSModuleRecord.h +++ b/src/bun.js/bindings/CommonJSModuleRecord.h @@ -17,6 +17,13 @@ namespace Bun { JSC_DECLARE_HOST_FUNCTION(jsFunctionCreateCommonJSModule); JSC_DECLARE_HOST_FUNCTION(jsFunctionLoadModule); +void populateESMExports( + JSC::JSGlobalObject* globalObject, + JSC::JSValue result, + WTF::Vector& exportNames, + JSC::MarkedArgumentBuffer& exportValues, + bool ignoreESModuleAnnotation); + class JSCommonJSModule final : public JSC::JSDestructibleObject { public: using Base = JSC::JSDestructibleObject; diff --git a/src/bun.js/bindings/JSMockFunction.cpp b/src/bun.js/bindings/JSMockFunction.cpp index 498864dfc2..3115229cbd 100644 --- a/src/bun.js/bindings/JSMockFunction.cpp +++ b/src/bun.js/bindings/JSMockFunction.cpp @@ -20,7 +20,9 @@ #include #include #include - +#include +#include +#include "BunPlugin.h" namespace Bun { /** @@ -606,7 +608,13 @@ extern "C" EncodedJSValue JSMock__jsSpyOn(JSC::JSGlobalObject* lexicalGlobalObje if (!hasValue || slot.isValue()) { JSValue value = jsUndefined(); if (hasValue) { - value = slot.getValue(globalObject, propertyKey); + if (UNLIKELY(slot.isTaintedByOpaqueObject())) { + // if it's a Proxy or JSModuleNamespaceObject + value = object->get(globalObject, propertyKey); + } else { + value = slot.getValue(globalObject, propertyKey); + } + if (jsDynamicCast(value)) { return JSValue::encode(value); } @@ -624,7 +632,12 @@ extern "C" EncodedJSValue JSMock__jsSpyOn(JSC::JSGlobalObject* lexicalGlobalObje mock->copyNameAndLength(vm, globalObject, value); - object->putDirect(vm, propertyKey, mock, attributes); + if (JSModuleNamespaceObject* moduleNamespaceObject = jsDynamicCast(object)) { + moduleNamespaceObject->overrideExportValue(globalObject, propertyKey, mock); + } else { + object->putDirect(vm, propertyKey, mock, attributes); + } + RETURN_IF_EXCEPTION(scope, {}); pushImpl(mock, globalObject, JSMockImplementation::Kind::Call, value); @@ -633,7 +646,13 @@ extern "C" EncodedJSValue JSMock__jsSpyOn(JSC::JSGlobalObject* lexicalGlobalObje attributes = slot.attributes(); attributes |= PropertyAttribute::Accessor; - object->putDirect(vm, propertyKey, JSC::GetterSetter::create(vm, globalObject, mock, mock), attributes); + + if (JSModuleNamespaceObject* moduleNamespaceObject = jsDynamicCast(object)) { + moduleNamespaceObject->overrideExportValue(globalObject, propertyKey, mock); + } else { + object->putDirect(vm, propertyKey, JSC::GetterSetter::create(vm, globalObject, mock, mock), attributes); + } + // mock->setName(propertyKey.publicName()); RETURN_IF_EXCEPTION(scope, {}); @@ -696,6 +715,13 @@ JSMockModule JSMockModule::create(JSC::JSGlobalObject* globalObject) Structure* implementation = ActiveSpySet::createStructure(init.vm, init.owner, jsNull()); init.set(implementation); }); + + mock.mockModuleStructure.initLater( + [](const JSC::LazyProperty::Initializer& init) { + Structure* implementation = createModuleMockStructure(init.vm, init.owner, jsNull()); + init.set(implementation); + }); + mock.mockImplementationStructure.initLater( [](const JSC::LazyProperty::Initializer& init) { Structure* implementation = JSMockImplementation::createStructure(init.vm, init.owner, jsNull()); diff --git a/src/bun.js/bindings/JSMockFunction.h b/src/bun.js/bindings/JSMockFunction.h index 93c8bb0150..592481ef80 100644 --- a/src/bun.js/bindings/JSMockFunction.h +++ b/src/bun.js/bindings/JSMockFunction.h @@ -23,6 +23,7 @@ public: LazyProperty mockResultStructure; LazyProperty mockImplementationStructure; LazyProperty mockObjectStructure; + LazyProperty mockModuleStructure; LazyProperty activeSpySetStructure; LazyProperty withImplementationCleanupFunction; LazyProperty mockWithImplementationCleanupDataStructure; diff --git a/src/bun.js/bindings/ModuleLoader.cpp b/src/bun.js/bindings/ModuleLoader.cpp index ac980d0623..55e0bbf821 100644 --- a/src/bun.js/bindings/ModuleLoader.cpp +++ b/src/bun.js/bindings/ModuleLoader.cpp @@ -43,6 +43,8 @@ using namespace JSC; using namespace Zig; using namespace WebCore; +static OnLoadResult handleOnLoadResultNotPromise(Zig::GlobalObject* globalObject, JSC::JSValue objectValue, bool wasModuleMock = false); + extern "C" BunLoaderType Bun__getDefaultLoader(JSC::JSGlobalObject*, BunString* specifier); static JSC::JSInternalPromise* rejectedInternalPromise(JSC::JSGlobalObject* globalObject, JSC::JSValue value) @@ -178,7 +180,7 @@ PendingVirtualModuleResult* PendingVirtualModuleResult::create(JSC::JSGlobalObje return virtualModule; } -OnLoadResult handleOnLoadResultNotPromise(Zig::GlobalObject* globalObject, JSC::JSValue objectValue, BunString* specifier) +OnLoadResult handleOnLoadResultNotPromise(Zig::GlobalObject* globalObject, JSC::JSValue objectValue, BunString* specifier, bool wasModuleMock) { OnLoadResult result = {}; result.type = OnLoadResultTypeError; @@ -193,9 +195,16 @@ OnLoadResult handleOnLoadResultNotPromise(Zig::GlobalObject* globalObject, JSC:: return result; } + if (wasModuleMock) { + result.type = OnLoadResultTypeObject; + result.value.object = objectValue; + return result; + } + JSC::JSObject* object = objectValue.getObject(); if (UNLIKELY(!object)) { - scope.throwException(globalObject, JSC::createError(globalObject, "Expected onLoad callback to return an object"_s)); + scope.throwException(globalObject, JSC::createError(globalObject, "Expected module mock to return an object"_s)); + result.value.error = scope.exception(); result.type = OnLoadResultTypeError; return result; @@ -259,16 +268,17 @@ OnLoadResult handleOnLoadResultNotPromise(Zig::GlobalObject* globalObject, JSC:: return result; } -static OnLoadResult handleOnLoadResult(Zig::GlobalObject* globalObject, JSC::JSValue objectValue, BunString* specifier) +static OnLoadResult handleOnLoadResult(Zig::GlobalObject* globalObject, JSC::JSValue objectValue, BunString* specifier, bool wasModuleMock = false) { if (JSC::JSPromise* promise = JSC::jsDynamicCast(objectValue)) { OnLoadResult result = {}; result.type = OnLoadResultTypePromise; result.value.promise = objectValue; + result.wasMock = wasModuleMock; return result; } - return handleOnLoadResultNotPromise(globalObject, objectValue, specifier); + return handleOnLoadResultNotPromise(globalObject, objectValue, specifier, wasModuleMock); } template @@ -277,9 +287,10 @@ static JSValue handleVirtualModuleResult( JSValue virtualModuleResult, ErrorableResolvedSource* res, BunString* specifier, - BunString* referrer) + BunString* referrer, + bool wasModuleMock = false) { - auto onLoadResult = handleOnLoadResult(globalObject, virtualModuleResult, specifier); + auto onLoadResult = handleOnLoadResult(globalObject, virtualModuleResult, specifier, wasModuleMock); JSC::VM& vm = globalObject->vm(); auto scope = DECLARE_THROW_SCOPE(vm); @@ -409,6 +420,8 @@ extern "C" void Bun__onFulfillAsyncModule( promise->resolve(promise->globalObject(), JSC::JSSourceCode::create(vm, JSC::SourceCode(provider))); } +extern "C" bool isBunTest; + JSValue fetchCommonJSModule( Zig::GlobalObject* globalObject, JSCommonJSModule* target, @@ -424,6 +437,40 @@ JSValue fetchCommonJSModule( auto& builtinNames = WebCore::clientData(vm)->builtinNames(); + bool wasModuleMock = false; + + // When "bun test" is enabled, allow users to override builtin modules + // This is important for being able to trivially mock things like the filesystem. + if (isBunTest) { + if (JSC::JSValue virtualModuleResult = Bun::runVirtualModule(globalObject, specifier, wasModuleMock)) { + JSPromise* promise = jsCast(handleVirtualModuleResult(globalObject, virtualModuleResult, res, specifier, referrer, wasModuleMock)); + switch (promise->status(vm)) { + case JSPromise::Status::Rejected: { + uint32_t promiseFlags = promise->internalField(JSPromise::Field::Flags).get().asUInt32AsAnyInt(); + promise->internalField(JSPromise::Field::Flags).set(vm, promise, jsNumber(promiseFlags | JSPromise::isHandledFlag)); + JSC::throwException(globalObject, scope, promise->result(vm)); + RELEASE_AND_RETURN(scope, JSValue {}); + } + case JSPromise::Status::Pending: { + JSC::throwTypeError(globalObject, scope, makeString("require() async module \""_s, Bun::toWTFString(*specifier), "\" is unsupported. use \"await import()\" instead."_s)); + RELEASE_AND_RETURN(scope, JSValue {}); + } + case JSPromise::Status::Fulfilled: { + if (!res->success) { + throwException(scope, res->result.err, globalObject); + RELEASE_AND_RETURN(scope, {}); + } + if (!wasModuleMock) { + auto* jsSourceCode = jsCast(promise->result(vm)); + globalObject->moduleLoader()->provideFetch(globalObject, specifierValue, jsSourceCode->sourceCode()); + RETURN_IF_EXCEPTION(scope, {}); + } + RELEASE_AND_RETURN(scope, jsNumber(-1)); + } + } + } + } + if (Bun__fetchBuiltinModule(bunVM, globalObject, specifier, referrer, res)) { if (!res->success) { throwException(scope, res->result.err, globalObject); @@ -465,29 +512,34 @@ JSValue fetchCommonJSModule( } } - if (JSC::JSValue virtualModuleResult = Bun::runVirtualModule(globalObject, specifier)) { - JSPromise* promise = jsCast(handleVirtualModuleResult(globalObject, virtualModuleResult, res, specifier, referrer)); - switch (promise->status(vm)) { - case JSPromise::Status::Rejected: { - uint32_t promiseFlags = promise->internalField(JSPromise::Field::Flags).get().asUInt32AsAnyInt(); - promise->internalField(JSPromise::Field::Flags).set(vm, promise, jsNumber(promiseFlags | JSPromise::isHandledFlag)); - JSC::throwException(globalObject, scope, promise->result(vm)); - RELEASE_AND_RETURN(scope, JSValue {}); - } - case JSPromise::Status::Pending: { - JSC::throwTypeError(globalObject, scope, makeString("require() async module \""_s, Bun::toWTFString(*specifier), "\" is unsupported. use \"await import()\" instead."_s)); - RELEASE_AND_RETURN(scope, JSValue {}); - } - case JSPromise::Status::Fulfilled: { - if (!res->success) { - throwException(scope, res->result.err, globalObject); - RELEASE_AND_RETURN(scope, {}); + // When "bun test" is NOT enabled, disable users from overriding builtin modules + if (!isBunTest) { + if (JSC::JSValue virtualModuleResult = Bun::runVirtualModule(globalObject, specifier, wasModuleMock)) { + JSPromise* promise = jsCast(handleVirtualModuleResult(globalObject, virtualModuleResult, res, specifier, referrer, wasModuleMock)); + switch (promise->status(vm)) { + case JSPromise::Status::Rejected: { + uint32_t promiseFlags = promise->internalField(JSPromise::Field::Flags).get().asUInt32AsAnyInt(); + promise->internalField(JSPromise::Field::Flags).set(vm, promise, jsNumber(promiseFlags | JSPromise::isHandledFlag)); + JSC::throwException(globalObject, scope, promise->result(vm)); + RELEASE_AND_RETURN(scope, JSValue {}); + } + case JSPromise::Status::Pending: { + JSC::throwTypeError(globalObject, scope, makeString("require() async module \""_s, Bun::toWTFString(*specifier), "\" is unsupported. use \"await import()\" instead."_s)); + RELEASE_AND_RETURN(scope, JSValue {}); + } + case JSPromise::Status::Fulfilled: { + if (!res->success) { + throwException(scope, res->result.err, globalObject); + RELEASE_AND_RETURN(scope, {}); + } + if (!wasModuleMock) { + auto* jsSourceCode = jsCast(promise->result(vm)); + globalObject->moduleLoader()->provideFetch(globalObject, specifierValue, jsSourceCode->sourceCode()); + RETURN_IF_EXCEPTION(scope, {}); + } + RELEASE_AND_RETURN(scope, jsNumber(-1)); + } } - auto* jsSourceCode = jsCast(promise->result(vm)); - globalObject->moduleLoader()->provideFetch(globalObject, specifierValue, jsSourceCode->sourceCode()); - RETURN_IF_EXCEPTION(scope, {}); - RELEASE_AND_RETURN(scope, jsNumber(-1)); - } } } @@ -551,6 +603,8 @@ JSValue fetchCommonJSModule( RELEASE_AND_RETURN(scope, jsNumber(-1)); } +extern "C" bool isBunTest; + template static JSValue fetchESMSourceCode( Zig::GlobalObject* globalObject, @@ -601,6 +655,16 @@ static JSValue fetchESMSourceCode( } }; + bool wasModuleMock = false; + + // When "bun test" is enabled, allow users to override builtin modules + // This is important for being able to trivially mock things like the filesystem. + if (isBunTest) { + if (JSC::JSValue virtualModuleResult = Bun::runVirtualModule(globalObject, specifier, wasModuleMock)) { + return handleVirtualModuleResult(globalObject, virtualModuleResult, res, specifier, referrer, wasModuleMock); + } + } + if (Bun__fetchBuiltinModule(bunVM, globalObject, specifier, referrer, res)) { if (!res->success) { throwException(scope, res->result.err, globalObject); @@ -640,8 +704,11 @@ static JSValue fetchESMSourceCode( } } - if (JSC::JSValue virtualModuleResult = Bun::runVirtualModule(globalObject, specifier)) { - return handleVirtualModuleResult(globalObject, virtualModuleResult, res, specifier, referrer); + // When "bun test" is NOT enabled, disable users from overriding builtin modules + if (!isBunTest) { + if (JSC::JSValue virtualModuleResult = Bun::runVirtualModule(globalObject, specifier, wasModuleMock)) { + return handleVirtualModuleResult(globalObject, virtualModuleResult, res, specifier, referrer, wasModuleMock); + } } if constexpr (allowPromise) { @@ -724,7 +791,10 @@ extern "C" JSC::EncodedJSValue jsFunctionOnLoadObjectResultResolve(JSC::JSGlobal BunString specifier = Bun::toString(globalObject, specifierString); BunString referrer = Bun::toString(globalObject, referrerString); auto scope = DECLARE_THROW_SCOPE(vm); - JSC::JSValue result = handleVirtualModuleResult(reinterpret_cast(globalObject), objectResult, &res, &specifier, &referrer); + + bool wasModuleMock = pendingModule->wasModuleMock; + + JSC::JSValue result = handleVirtualModuleResult(reinterpret_cast(globalObject), objectResult, &res, &specifier, &referrer, wasModuleMock); if (res.success) { if (scope.exception()) { auto retValue = JSValue::encode(promise->rejectWithCaughtException(globalObject, scope)); diff --git a/src/bun.js/bindings/ModuleLoader.h b/src/bun.js/bindings/ModuleLoader.h index 72dd8b49ae..e6070d5385 100644 --- a/src/bun.js/bindings/ModuleLoader.h +++ b/src/bun.js/bindings/ModuleLoader.h @@ -42,6 +42,7 @@ union OnLoadResultValue { struct OnLoadResult { OnLoadResultValue value; OnLoadResultType type; + bool wasMock; }; class PendingVirtualModuleResult : public JSC::JSInternalFieldObjectImpl<3> { @@ -81,9 +82,10 @@ public: PendingVirtualModuleResult(JSC::VM&, JSC::Structure*); void finishCreation(JSC::VM&, const WTF::String& specifier, const WTF::String& referrer); + + bool wasModuleMock = false; }; -OnLoadResult handleOnLoadResultNotPromise(Zig::GlobalObject* globalObject, JSC::JSValue objectValue); JSValue fetchESMSourceCodeSync( Zig::GlobalObject* globalObject, ErrorableResolvedSource* res, diff --git a/src/bun.js/bindings/ZigGlobalObject.cpp b/src/bun.js/bindings/ZigGlobalObject.cpp index 15647cd240..4177e1352a 100644 --- a/src/bun.js/bindings/ZigGlobalObject.cpp +++ b/src/bun.js/bindings/ZigGlobalObject.cpp @@ -2206,7 +2206,8 @@ static inline EncodedJSValue functionPerformanceNowBody(JSGlobalObject* globalOb return JSValue::encode(jsDoubleNumber(result)); } -static inline EncodedJSValue functionPerformanceGetEntriesByNameBody(JSGlobalObject* globalObject) { +static inline EncodedJSValue functionPerformanceGetEntriesByNameBody(JSGlobalObject* globalObject) +{ auto& vm = globalObject->vm(); auto* global = reinterpret_cast(globalObject); auto* array = JSC::constructEmptyArray(globalObject, nullptr); @@ -2296,7 +2297,6 @@ JSC_DEFINE_HOST_FUNCTION(functionPerformanceNow, (JSGlobalObject * globalObject, return functionPerformanceNowBody(globalObject); } - JSC_DEFINE_HOST_FUNCTION(functionPerformanceGetEntriesByName, (JSGlobalObject * globalObject, JSC::CallFrame* callFrame)) { return functionPerformanceGetEntriesByNameBody(globalObject); @@ -3149,6 +3149,24 @@ void GlobalObject::finishCreation(VM& vm) init.set(map); }); + m_esmRegistryMap.initLater( + [](const JSC::LazyProperty::Initializer& init) { + auto* global = init.owner; + auto& vm = init.vm; + JSMap* registry = nullptr; + if (auto loaderValue = global->getIfPropertyExists(global, JSC::Identifier::fromString(vm, "Loader"_s))) { + if (auto registryValue = loaderValue.getObject()->getIfPropertyExists(global, JSC::Identifier::fromString(vm, "registry"_s))) { + registry = jsCast(registryValue); + } + } + + if (!registry) { + registry = JSC::JSMap::create(init.vm, init.owner->mapStructure()); + } + + init.set(registry); + }); + m_encodeIntoObjectStructure.initLater( [](const JSC::LazyProperty::Initializer& init) { auto& vm = init.vm; @@ -3886,6 +3904,7 @@ void GlobalObject::visitChildrenImpl(JSCell* cell, Visitor& visitor) thisObject->m_utilInspectStylizeNoColorFunction.visit(visitor); thisObject->m_lazyReadableStreamPrototypeMap.visit(visitor); thisObject->m_requireMap.visit(visitor); + thisObject->m_esmRegistryMap.visit(visitor); thisObject->m_encodeIntoObjectStructure.visit(visitor); thisObject->m_JSArrayBufferControllerPrototype.visit(visitor); thisObject->m_JSFileSinkControllerPrototype.visit(visitor); @@ -3928,6 +3947,7 @@ void GlobalObject::visitChildrenImpl(JSCell* cell, Visitor& visitor) thisObject->mockModule.mockResultStructure.visit(visitor); thisObject->mockModule.mockImplementationStructure.visit(visitor); thisObject->mockModule.mockObjectStructure.visit(visitor); + thisObject->mockModule.mockModuleStructure.visit(visitor); thisObject->mockModule.activeSpySetStructure.visit(visitor); thisObject->mockModule.mockWithImplementationCleanupDataStructure.visit(visitor); thisObject->mockModule.withImplementationCleanupFunction.visit(visitor); diff --git a/src/bun.js/bindings/ZigGlobalObject.h b/src/bun.js/bindings/ZigGlobalObject.h index 19630d4b7f..86cfb04e42 100644 --- a/src/bun.js/bindings/ZigGlobalObject.h +++ b/src/bun.js/bindings/ZigGlobalObject.h @@ -213,6 +213,7 @@ public: JSC::JSMap* readableStreamNativeMap() { return m_lazyReadableStreamPrototypeMap.getInitializedOnMainThread(this); } JSC::JSMap* requireMap() { return m_requireMap.getInitializedOnMainThread(this); } + JSC::JSMap* esmRegistryMap() { return m_esmRegistryMap.getInitializedOnMainThread(this); } JSC::Structure* encodeIntoObjectStructure() { return m_encodeIntoObjectStructure.getInitializedOnMainThread(this); } JSC::Structure* callSiteStructure() const { return m_callSiteStructure.getInitializedOnMainThread(this); } @@ -481,6 +482,7 @@ public: LazyProperty m_emitReadableNextTickFunction; LazyProperty m_lazyReadableStreamPrototypeMap; LazyProperty m_requireMap; + LazyProperty m_esmRegistryMap; LazyProperty m_encodeIntoObjectStructure; LazyProperty m_JSArrayBufferControllerPrototype; LazyProperty m_JSFileSinkControllerPrototype; diff --git a/src/bun.js/bindings/webcore/DOMClientIsoSubspaces.h b/src/bun.js/bindings/webcore/DOMClientIsoSubspaces.h index e21a62bf8c..c6860580ba 100644 --- a/src/bun.js/bindings/webcore/DOMClientIsoSubspaces.h +++ b/src/bun.js/bindings/webcore/DOMClientIsoSubspaces.h @@ -36,6 +36,7 @@ public: std::unique_ptr m_clientSubspaceForNodeVMScript; std::unique_ptr m_clientSubspaceForCommonJSModuleRecord; std::unique_ptr m_clientSubspaceForJSMockImplementation; + std::unique_ptr m_clientSubspaceForJSModuleMock; std::unique_ptr m_clientSubspaceForJSMockFunction; std::unique_ptr m_clientSubspaceForAsyncContextFrame; std::unique_ptr m_clientSubspaceForMockWithImplementationCleanupData; diff --git a/src/bun.js/bindings/webcore/DOMIsoSubspaces.h b/src/bun.js/bindings/webcore/DOMIsoSubspaces.h index 806aa4454d..6d7da432ee 100644 --- a/src/bun.js/bindings/webcore/DOMIsoSubspaces.h +++ b/src/bun.js/bindings/webcore/DOMIsoSubspaces.h @@ -36,6 +36,7 @@ public: std::unique_ptr m_subspaceForNodeVMScript; std::unique_ptr m_subspaceForCommonJSModuleRecord; std::unique_ptr m_subspaceForJSMockImplementation; + std::unique_ptr m_subspaceForJSModuleMock; std::unique_ptr m_subspaceForJSMockFunction; std::unique_ptr m_subspaceForAsyncContextFrame; std::unique_ptr m_subspaceForMockWithImplementationCleanupData; diff --git a/src/bun.js/javascript.zig b/src/bun.js/javascript.zig index 1c996039d8..e8884a8700 100644 --- a/src/bun.js/javascript.zig +++ b/src/bun.js/javascript.zig @@ -3258,3 +3258,5 @@ pub fn NewHotReloader(comptime Ctx: type, comptime EventLoopType: type, comptime } }; } + +pub export var isBunTest: bool = false; \ No newline at end of file diff --git a/src/bun.js/test/jest.zig b/src/bun.js/test/jest.zig index 5490b3472d..f3e25c41f9 100644 --- a/src/bun.js/test/jest.zig +++ b/src/bun.js/test/jest.zig @@ -459,7 +459,9 @@ pub const Jest = struct { const mockFn = JSC.NewFunction(globalObject, ZigString.static("fn"), 1, JSMock__jsMockFn, false); const spyOn = JSC.NewFunction(globalObject, ZigString.static("spyOn"), 2, JSMock__jsSpyOn, false); const restoreAllMocks = JSC.NewFunction(globalObject, ZigString.static("restoreAllMocks"), 2, JSMock__jsRestoreAllMocks, false); + const mockModuleFn = JSC.NewFunction(globalObject, ZigString.static("module"), 2, JSMock__jsModuleMock, false); module.put(globalObject, ZigString.static("mock"), mockFn); + mockFn.put(globalObject, ZigString.static("module"), mockModuleFn); const jest = JSValue.createEmptyObject(globalObject, 7); jest.put(globalObject, ZigString.static("fn"), mockFn); @@ -488,6 +490,7 @@ pub const Jest = struct { const vi = JSValue.createEmptyObject(globalObject, 3); vi.put(globalObject, ZigString.static("fn"), mockFn); vi.put(globalObject, ZigString.static("spyOn"), spyOn); + vi.put(globalObject, ZigString.static("module"), mockModuleFn); vi.put(globalObject, ZigString.static("restoreAllMocks"), restoreAllMocks); module.put(globalObject, ZigString.static("vi"), vi); @@ -497,6 +500,7 @@ pub const Jest = struct { extern fn Bun__Jest__testPreloadObject(*JSC.JSGlobalObject) JSC.JSValue; extern fn Bun__Jest__testModuleObject(*JSC.JSGlobalObject) JSC.JSValue; extern fn JSMock__jsMockFn(*JSC.JSGlobalObject, *JSC.CallFrame) JSC.JSValue; + extern fn JSMock__jsModuleMock(*JSC.JSGlobalObject, *JSC.CallFrame) JSC.JSValue; extern fn JSMock__jsNow(*JSC.JSGlobalObject, *JSC.CallFrame) JSC.JSValue; extern fn JSMock__jsSetSystemTime(*JSC.JSGlobalObject, *JSC.CallFrame) JSC.JSValue; extern fn JSMock__jsRestoreAllMocks(*JSC.JSGlobalObject, *JSC.CallFrame) JSC.JSValue; diff --git a/src/cli/test_command.zig b/src/cli/test_command.zig index 141d7ff253..5d4313998d 100644 --- a/src/cli/test_command.zig +++ b/src/cli/test_command.zig @@ -580,6 +580,7 @@ pub const TestCommand = struct { var snapshot_file_buf = std.ArrayList(u8).init(ctx.allocator); var snapshot_values = Snapshots.ValuesHashMap.init(ctx.allocator); var snapshot_counts = bun.StringHashMap(usize).init(ctx.allocator); + JSC.isBunTest = true; var reporter = try ctx.allocator.create(CommandLineReporter); reporter.* = CommandLineReporter{ diff --git a/test/js/bun/test/mock/mock-module-fixture.ts b/test/js/bun/test/mock/mock-module-fixture.ts new file mode 100644 index 0000000000..1bc820d5c7 --- /dev/null +++ b/test/js/bun/test/mock/mock-module-fixture.ts @@ -0,0 +1,9 @@ +export function fn() { + return 42; +} + +export function iCallFn() { + return fn(); +} + +export const variable = 7; diff --git a/test/js/bun/test/mock/mock-module.test.ts b/test/js/bun/test/mock/mock-module.test.ts new file mode 100644 index 0000000000..04e5779ddb --- /dev/null +++ b/test/js/bun/test/mock/mock-module.test.ts @@ -0,0 +1,74 @@ +// TODO: +// - Write tests for errors +// - Write tests for Promise +// - Write tests for Promise rejection +// - Write tests for pending promise when a module already exists +// - Write test for export * from +// - Write test for export {foo} from "./foo" +// - Write test for import {foo} from "./foo"; export {foo} + +import { mock, test, expect } from "bun:test"; +import { fn, iCallFn, variable } from "./mock-module-fixture"; + +test("mocking a local file", async () => { + expect(fn()).toEqual(42); + expect(variable).toEqual(7); + + mock.module("./mock-module-fixture.ts", () => { + return { + fn: () => 1, + variable: 8, + }; + }); + expect(fn()).toEqual(1); + expect(variable).toEqual(8); + mock.module("./mock-module-fixture.ts", () => { + return { + fn: () => 2, + variable: 9, + }; + }); + expect(fn()).toEqual(2); + expect(variable).toEqual(9); + mock.module("./mock-module-fixture.ts", () => { + return { + fn: () => 3, + variable: 10, + }; + }); + expect(fn()).toEqual(3); + expect(variable).toEqual(10); + expect(require("./mock-module-fixture").fn()).toBe(3); + expect(require("./mock-module-fixture").variable).toBe(10); + expect(iCallFn()).toBe(3); +}); + +test("mocking a package", async () => { + mock.module("ha-ha-ha", () => { + return { + wow: () => 42, + }; + }); + const hahaha = await import("ha-ha-ha"); + expect(hahaha.wow()).toBe(42); + expect(require("ha-ha-ha").wow()).toBe(42); + mock.module("ha-ha-ha", () => { + return { + wow: () => 43, + }; + }); + + expect(hahaha.wow()).toBe(43); + expect(require("ha-ha-ha").wow()).toBe(43); +}); + +test("mocking a builtin", async () => { + mock.module("fs/promises", () => { + return { + readFile: () => Promise.resolve("hello world"), + }; + }); + + const { readFile } = await import("node:fs/promises"); + expect(await readFile("hello.txt", "utf8")).toBe("hello world"); +});