diff --git a/src/bun.js/bindings/NodeVMSourceTextModule.cpp b/src/bun.js/bindings/NodeVMSourceTextModule.cpp index 00c8763995..5927c4b35e 100644 --- a/src/bun.js/bindings/NodeVMSourceTextModule.cpp +++ b/src/bun.js/bindings/NodeVMSourceTextModule.cpp @@ -127,13 +127,32 @@ JSValue NodeVMSourceTextModule::createModuleRecord(JSGlobalObject* globalObject) JSArray* requestsArray = JSC::constructEmptyArray(globalObject, nullptr, requests.size()); - // MarkedArgumentBuffer buffer; - const auto& builtinNames = WebCore::clientData(vm)->builtinNames(); const JSC::Identifier& specifierIdentifier = builtinNames.specifierPublicName(); const JSC::Identifier& attributesIdentifier = builtinNames.attributesPublicName(); const JSC::Identifier& hostDefinedImportTypeIdentifier = builtinNames.hostDefinedImportTypePublicName(); + WTF::Vector attributesNodes; + attributesNodes.reserveInitialCapacity(requests.size()); + + for (StatementNode* statement = node->statements()->firstStatement(); statement; statement = statement->next()) { + // Assumption: module declarations occur here in the same order they occur in `requestedModules`. + if (statement->isModuleDeclarationNode()) { + ModuleDeclarationNode* moduleDeclaration = static_cast(statement); + if (moduleDeclaration->isImportDeclarationNode()) { + ImportDeclarationNode* importDeclaration = static_cast(moduleDeclaration); + ASSERT_WITH_MESSAGE(attributesNodes.size() < requests.size(), "More attributes nodes than requests"); + ASSERT_WITH_MESSAGE(importDeclaration->moduleName()->moduleName().string().string() == WTF::String(*requests.at(attributesNodes.size()).m_specifier), "Module name mismatch"); + attributesNodes.append(importDeclaration->attributesList()); + } else if (moduleDeclaration->hasAttributesList()) { + // Necessary to make the indices of `attributesNodes` and `requests` match up + attributesNodes.append(nullptr); + } + } + } + + ASSERT_WITH_MESSAGE(attributesNodes.size() == requests.size(), "Attributes node count doesn't match request count (%zu != %zu)", attributesNodes.size(), requests.size()); + for (unsigned i = 0; i < requests.size(); ++i) { const auto& request = requests[i]; @@ -144,6 +163,9 @@ JSValue NodeVMSourceTextModule::createModuleRecord(JSGlobalObject* globalObject) WTF::String attributesTypeString = "unknown"_str; + WTF::HashMap attributeMap; + JSObject* attributesObject = constructEmptyObject(globalObject); + if (request.m_attributes) { JSValue attributesType {}; switch (request.m_attributes->type()) { @@ -170,23 +192,24 @@ JSValue NodeVMSourceTextModule::createModuleRecord(JSGlobalObject* globalObject) break; } - WTF::HashMap attributeMap { - { "type"_s, attributesTypeString }, - }; - - JSObject* attributesObject = constructEmptyObject(globalObject, globalObject->objectPrototype(), 1); + attributeMap.set("type"_s, WTFMove(attributesTypeString)); attributesObject->putDirect(vm, JSC::Identifier::fromString(vm, "type"_s), attributesType); + if (const String& hostDefinedImportType = request.m_attributes->hostDefinedImportType(); !hostDefinedImportType.isEmpty()) { attributesObject->putDirect(vm, hostDefinedImportTypeIdentifier, JSC::jsString(vm, hostDefinedImportType)); attributeMap.set("hostDefinedImportType"_s, hostDefinedImportType); } - requestObject->putDirect(vm, attributesIdentifier, attributesObject); - addModuleRequest({ WTF::String(*request.m_specifier), WTFMove(attributeMap) }); - } else { - addModuleRequest({ WTF::String(*request.m_specifier), {} }); - requestObject->putDirect(vm, attributesIdentifier, JSC::jsNull()); } + if (ImportAttributesListNode* attributesNode = attributesNodes.at(i)) { + for (auto [key, value] : attributesNode->attributes()) { + attributeMap.set(key->string(), value->string()); + attributesObject->putDirect(vm, *key, JSC::jsString(vm, value->string())); + } + } + + requestObject->putDirect(vm, attributesIdentifier, attributesObject); + addModuleRequest({ WTF::String(*request.m_specifier), WTFMove(attributeMap) }); requestsArray->putDirectIndex(globalObject, i, requestObject); } diff --git a/test/js/node/test/parallel/test-vm-module-link.js b/test/js/node/test/parallel/test-vm-module-link.js new file mode 100644 index 0000000000..26dcb69885 --- /dev/null +++ b/test/js/node/test/parallel/test-vm-module-link.js @@ -0,0 +1,168 @@ +'use strict'; + +// Flags: --experimental-vm-modules --harmony-import-attributes + +const common = require('../common'); + +const assert = require('assert'); + +const { SourceTextModule } = require('vm'); + +async function simple() { + const foo = new SourceTextModule('export default 5;'); + await foo.link(common.mustNotCall()); + + globalThis.fiveResult = undefined; + const bar = new SourceTextModule('import five from "foo"; fiveResult = five'); + + assert.deepStrictEqual(bar.dependencySpecifiers, ['foo']); + + await bar.link(common.mustCall((specifier, module) => { + assert.strictEqual(module, bar); + assert.strictEqual(specifier, 'foo'); + return foo; + })); + + await bar.evaluate(); + assert.strictEqual(globalThis.fiveResult, 5); + delete globalThis.fiveResult; +} + +async function invalidLinkValue() { + const invalidValues = [ + undefined, + null, + {}, + SourceTextModule.prototype, + ]; + + for (const value of invalidValues) { + const module = new SourceTextModule('import "foo"'); + await assert.rejects(module.link(() => value), { + code: 'ERR_VM_MODULE_NOT_MODULE', + }); + } +} + +async function depth() { + const foo = new SourceTextModule('export default 5'); + await foo.link(common.mustNotCall()); + + async function getProxy(parentName, parentModule) { + const mod = new SourceTextModule(` + import ${parentName} from '${parentName}'; + export default ${parentName}; + `); + await mod.link(common.mustCall((specifier, module) => { + assert.strictEqual(module, mod); + assert.strictEqual(specifier, parentName); + return parentModule; + })); + return mod; + } + + const bar = await getProxy('foo', foo); + const baz = await getProxy('bar', bar); + const barz = await getProxy('baz', baz); + + await barz.evaluate(); + + assert.strictEqual(barz.namespace.default, 5); +} + +async function circular() { + const foo = new SourceTextModule(` + import getFoo from 'bar'; + export let foo = 42; + export default getFoo(); + `); + const bar = new SourceTextModule(` + import { foo } from 'foo'; + export default function getFoo() { + return foo; + } + `); + await foo.link(common.mustCall(async (specifier, module) => { + if (specifier === 'bar') { + assert.strictEqual(module, foo); + return bar; + } + assert.strictEqual(specifier, 'foo'); + assert.strictEqual(module, bar); + assert.strictEqual(foo.status, 'linking'); + return foo; + }, 2)); + + assert.strictEqual(bar.status, 'linked'); + + await foo.evaluate(); + assert.strictEqual(foo.namespace.default, 42); +} + +async function circular2() { + const sourceMap = { + 'root': ` + import * as a from './a.mjs'; + import * as b from './b.mjs'; + if (!('fromA' in a)) + throw new Error(); + if (!('fromB' in a)) + throw new Error(); + if (!('fromA' in b)) + throw new Error(); + if (!('fromB' in b)) + throw new Error(); + `, + './a.mjs': ` + export * from './b.mjs'; + export let fromA; + `, + './b.mjs': ` + export * from './a.mjs'; + export let fromB; + ` + }; + const moduleMap = new Map(); + const rootModule = new SourceTextModule(sourceMap.root, { + identifier: 'vm:root', + }); + async function link(specifier, referencingModule) { + if (moduleMap.has(specifier)) { + return moduleMap.get(specifier); + } + const mod = new SourceTextModule(sourceMap[specifier], { + identifier: new URL(specifier, 'file:///').href, + }); + moduleMap.set(specifier, mod); + return mod; + } + await rootModule.link(link); + await rootModule.evaluate(); +} + +async function asserts() { + const m = new SourceTextModule(` + import "foo" with { n1: 'v1', n2: 'v2' }; + `, { identifier: 'm' }); + await m.link((s, r, p) => { + assert.strictEqual(s, 'foo'); + assert.strictEqual(r.identifier, 'm'); + assert.strictEqual(p.attributes.n1, 'v1'); + assert.strictEqual(p.assert.n1, 'v1'); + assert.strictEqual(p.attributes.n2, 'v2'); + assert.strictEqual(p.assert.n2, 'v2'); + return new SourceTextModule(''); + }); +} + +const finished = common.mustCall(); + +(async function main() { + await simple(); + await invalidLinkValue(); + await depth(); + await circular(); + await circular2(); + await asserts(); + finished(); +})().then(common.mustCall());