diff --git a/src/bun.js/bindings/webcore/MessagePort.cpp b/src/bun.js/bindings/webcore/MessagePort.cpp index a684e1c832..89c7873300 100644 --- a/src/bun.js/bindings/webcore/MessagePort.cpp +++ b/src/bun.js/bindings/webcore/MessagePort.cpp @@ -322,8 +322,11 @@ void MessagePort::dispatchMessages() if (!context || context->activeDOMObjectsAreSuspended() || !isEntangled()) return; - auto messagesTakenHandler = [this, protectedThis = Ref { *this }](Vector&& messages, CompletionHandler&& completionCallback) mutable { - RefPtr context = scriptExecutionContext(); + auto executionContextIdentifier = scriptExecutionContext()->identifier(); + + auto messagesTakenHandler = [this, protectedThis = Ref { *this }, executionContextIdentifier](Vector&& messages, CompletionHandler&& completionCallback) mutable { + RefPtr context = ScriptExecutionContext::getScriptExecutionContext(executionContextIdentifier); + if (!context || !context->globalObject()) { completionCallback(); return; @@ -333,7 +336,7 @@ void MessagePort::dispatchMessages() processMessages(*context, WTFMove(messages), WTFMove(completionCallback)); }; - MessagePortChannelProvider::fromContext(*context).takeAllMessagesForPort(m_identifier, WTFMove(messagesTakenHandler)); + MessagePortChannelProvider::fromContext(*context).takeAllMessagesForPort(executionContextIdentifier, m_identifier, WTFMove(messagesTakenHandler)); } // synchronous for node:worker_threads.receiveMessageOnPort diff --git a/src/bun.js/bindings/webcore/MessagePortChannelProvider.h b/src/bun.js/bindings/webcore/MessagePortChannelProvider.h index 416a2279e0..280866117d 100644 --- a/src/bun.js/bindings/webcore/MessagePortChannelProvider.h +++ b/src/bun.js/bindings/webcore/MessagePortChannelProvider.h @@ -26,6 +26,7 @@ #pragma once #include "ProcessIdentifier.h" +#include "ScriptExecutionContext.h" #include #include @@ -58,7 +59,7 @@ public: virtual void messagePortDisentangled(const MessagePortIdentifier& local) = 0; virtual void messagePortClosed(const MessagePortIdentifier& local) = 0; - virtual void takeAllMessagesForPort(const MessagePortIdentifier&, CompletionHandler&&, CompletionHandler&&)>&&) = 0; + virtual void takeAllMessagesForPort(const ScriptExecutionContextIdentifier, const MessagePortIdentifier&, CompletionHandler&&, CompletionHandler&&)>&&) = 0; virtual void tryTakeMessageForPort(const MessagePortIdentifier&, CompletionHandler&&)>&&) = 0; virtual void postMessageToRemote(MessageWithMessagePorts&&, const MessagePortIdentifier& remoteTarget) = 0; }; diff --git a/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.cpp b/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.cpp index 45f4e7efc1..ae784b1632 100644 --- a/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.cpp +++ b/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.cpp @@ -27,9 +27,13 @@ #include "MessagePortChannelProviderImpl.h" #include "MessagePort.h" +#include "ScriptExecutionContext.h" +#include "BunClientData.h" #include #include +extern "C" void* Bun__getVM(); + namespace WebCore { MessagePortChannelProviderImpl::MessagePortChannelProviderImpl() = default; @@ -82,17 +86,38 @@ void MessagePortChannelProviderImpl::postMessageToRemote(MessageWithMessagePorts }); } -void MessagePortChannelProviderImpl::takeAllMessagesForPort(const MessagePortIdentifier& port, CompletionHandler&&, CompletionHandler&&)>&& outerCallback) +void MessagePortChannelProviderImpl::takeAllMessagesForPort(const ScriptExecutionContextIdentifier identifier, const MessagePortIdentifier& port, CompletionHandler&&, CompletionHandler&&)>&& outerCallback) { - // It is the responsibility of outerCallback to get itself to the appropriate thread (e.g. WebWorker thread) - auto callback = [outerCallback = WTFMove(outerCallback)](Vector&& messages, CompletionHandler&& messageDeliveryCallback) mutable { - ASSERT(isMainThread()); - outerCallback(WTFMove(messages), WTFMove(messageDeliveryCallback)); - }; + if (WTF::isMainThread()) { + m_registry.takeAllMessagesForPort(port, WTFMove(outerCallback)); + return; + } - ScriptExecutionContext::ensureOnMainThread([weakRegistry = WeakPtr { m_registry }, port, callback = WTFMove(callback)](ScriptExecutionContext& context) mutable { - if (CheckedPtr registry = weakRegistry.get()) - registry->takeAllMessagesForPort(port, WTFMove(callback)); + auto currentVM = Bun__getVM(); + if (!currentVM) { + outerCallback({}, [](){}); // already destroyed + return; + } + + ScriptExecutionContext::ensureOnMainThread([weakRegistry = WeakPtr { m_registry }, port, outerCallback = WTFMove(outerCallback), identifier](ScriptExecutionContext& mainContext) mutable { + CheckedPtr registry = weakRegistry.get(); + if (!registry) { + ScriptExecutionContext::ensureOnContextThread(identifier, [outerCallback = WTFMove(outerCallback)](ScriptExecutionContext&) mutable { + outerCallback({}, [](){}); + }); + return; + } + + registry->takeAllMessagesForPort(port, [outerCallback = WTFMove(outerCallback), identifier](Vector&& messages, CompletionHandler&& completionHandler) mutable { + ScriptExecutionContext::ensureOnContextThread(identifier, [outerCallback = WTFMove(outerCallback), messages = WTFMove(messages), completionHandler = WTFMove(completionHandler)](ScriptExecutionContext&) mutable { + auto wrappedCompletionHandler = [completionHandler = WTFMove(completionHandler)]() mutable { + ScriptExecutionContext::ensureOnMainThread([completionHandler = WTFMove(completionHandler)](ScriptExecutionContext&) mutable { + completionHandler(); + }); + }; + outerCallback(WTFMove(messages), WTFMove(wrappedCompletionHandler)); + }); + }); }); } diff --git a/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.h b/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.h index 17f756960e..c389ceec80 100644 --- a/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.h +++ b/src/bun.js/bindings/webcore/MessagePortChannelProviderImpl.h @@ -27,6 +27,7 @@ #include "MessagePortChannelProvider.h" #include "MessagePortChannelRegistry.h" +#include "ScriptExecutionContext.h" namespace WebCore { @@ -41,7 +42,7 @@ private: void messagePortDisentangled(const MessagePortIdentifier& local) final; void messagePortClosed(const MessagePortIdentifier& local) final; void postMessageToRemote(MessageWithMessagePorts&&, const MessagePortIdentifier& remoteTarget) final; - void takeAllMessagesForPort(const MessagePortIdentifier&, CompletionHandler&&, CompletionHandler&&)>&&) final; + void takeAllMessagesForPort(const ScriptExecutionContextIdentifier identifier, const MessagePortIdentifier&, CompletionHandler&&, CompletionHandler&&)>&&) final; void tryTakeMessageForPort(const MessagePortIdentifier&, CompletionHandler&&)>&&) final; MessagePortChannelRegistry m_registry;