Skip to content

Commit 6094b3b

Browse files
committed
[ORC] Unify task dispatch across ExecutionSession and ExecutorProcessControl.
Updates ExecutionSession to use the ExecutorProcessControl object's TaskDispatcher rather than having a separate dispatch function. This gives the TaskDispatcher a global view of all tasks to be executed, and provides a single point to wait on for tasks to complete when shutting down the JIT.
1 parent 5fef5e6 commit 6094b3b

File tree

11 files changed

+157
-61
lines changed

11 files changed

+157
-61
lines changed

llvm/include/llvm/ExecutionEngine/Orc/Core.h

+1-13
Original file line numberDiff line numberDiff line change
@@ -1443,9 +1443,6 @@ class ExecutionSession {
14431443
/// Send a result to the remote.
14441444
using SendResultFunction = unique_function<void(shared::WrapperFunctionResult)>;
14451445

1446-
/// For dispatching ORC tasks (typically materialization tasks).
1447-
using DispatchTaskFunction = unique_function<void(std::unique_ptr<Task> T)>;
1448-
14491446
/// An asynchronous wrapper-function callable from the executor via
14501447
/// jit-dispatch.
14511448
using JITDispatchHandlerFunction = unique_function<void(
@@ -1568,12 +1565,6 @@ class ExecutionSession {
15681565
/// Unhandled errors can be sent here to log them.
15691566
void reportError(Error Err) { ReportError(std::move(Err)); }
15701567

1571-
/// Set the task dispatch function.
1572-
ExecutionSession &setDispatchTask(DispatchTaskFunction DispatchTask) {
1573-
this->DispatchTask = std::move(DispatchTask);
1574-
return *this;
1575-
}
1576-
15771568
/// Search the given JITDylibs to find the flags associated with each of the
15781569
/// given symbols.
15791570
void lookupFlags(LookupKind K, JITDylibSearchOrder SearchOrder,
@@ -1648,7 +1639,7 @@ class ExecutionSession {
16481639
void dispatchTask(std::unique_ptr<Task> T) {
16491640
assert(T && "T must be non-null");
16501641
DEBUG_WITH_TYPE("orc", dumpDispatchInfo(*T));
1651-
DispatchTask(std::move(T));
1642+
EPC->getDispatcher().dispatch(std::move(T));
16521643
}
16531644

16541645
/// Run a wrapper function in the executor.
@@ -1762,8 +1753,6 @@ class ExecutionSession {
17621753
logAllUnhandledErrors(std::move(Err), errs(), "JIT session error: ");
17631754
}
17641755

1765-
static void runOnCurrentThread(std::unique_ptr<Task> T) { T->run(); }
1766-
17671756
void dispatchOutstandingMUs();
17681757

17691758
static std::unique_ptr<MaterializationResponsibility>
@@ -1869,7 +1858,6 @@ class ExecutionSession {
18691858
std::unique_ptr<ExecutorProcessControl> EPC;
18701859
std::unique_ptr<Platform> P;
18711860
ErrorReporter ReportError = logErrorsToStdErr;
1872-
DispatchTaskFunction DispatchTask = runOnCurrentThread;
18731861

18741862
std::vector<ResourceManager *> ResourceManagers;
18751863

llvm/include/llvm/ExecutionEngine/Orc/LLJIT.h

+16-9
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ class LLJIT {
254254

255255
DataLayout DL;
256256
Triple TT;
257-
std::unique_ptr<DefaultThreadPool> CompileThreads;
258257

259258
std::unique_ptr<ObjectLayer> ObjLinkingLayer;
260259
std::unique_ptr<ObjectTransformLayer> ObjTransformLayer;
@@ -325,6 +324,7 @@ class LLJITBuilderState {
325324
PlatformSetupFunction SetUpPlatform;
326325
NotifyCreatedFunction NotifyCreated;
327326
unsigned NumCompileThreads = 0;
327+
std::optional<bool> SupportConcurrentCompilation;
328328

329329
/// Called prior to JIT class construcion to fix up defaults.
330330
Error prepareForConstruction();
@@ -333,7 +333,7 @@ class LLJITBuilderState {
333333
template <typename JITType, typename SetterImpl, typename State>
334334
class LLJITBuilderSetters {
335335
public:
336-
/// Set a ExecutorProcessControl for this instance.
336+
/// Set an ExecutorProcessControl for this instance.
337337
/// This should not be called if ExecutionSession has already been set.
338338
SetterImpl &
339339
setExecutorProcessControl(std::unique_ptr<ExecutorProcessControl> EPC) {
@@ -462,19 +462,26 @@ class LLJITBuilderSetters {
462462
///
463463
/// If this method is not called, behavior will be as if it were called with
464464
/// a zero argument.
465+
///
466+
/// This setting should not be used if a custom ExecutionSession or
467+
/// ExecutorProcessControl object is set: in those cases a custom
468+
/// TaskDispatcher should be used instead.
465469
SetterImpl &setNumCompileThreads(unsigned NumCompileThreads) {
466470
impl().NumCompileThreads = NumCompileThreads;
467471
return impl();
468472
}
469473

470-
/// Set an ExecutorProcessControl object.
474+
/// If set, this forces LLJIT concurrent compilation support to be either on
475+
/// or off. This controls the selection of compile function (concurrent vs
476+
/// single threaded) and whether or not sub-modules are cloned to new
477+
/// contexts for lazy emission.
471478
///
472-
/// If the platform uses ObjectLinkingLayer by default and no
473-
/// ObjectLinkingLayerCreator has been set then the ExecutorProcessControl
474-
/// object will be used to supply the memory manager for the
475-
/// ObjectLinkingLayer.
476-
SetterImpl &setExecutorProcessControl(ExecutorProcessControl &EPC) {
477-
impl().EPC = &EPC;
479+
/// If not explicitly set then concurrency support will be turned on if
480+
/// NumCompileThreads is set to a non-zero value, or if a custom
481+
/// ExecutionSession or ExecutorProcessControl instance is provided.
482+
SetterImpl &setSupportConcurrentCompilation(
483+
std::optional<bool> SupportConcurrentCompilation) {
484+
impl().SupportConcurrentCompilation = SupportConcurrentCompilation;
478485
return impl();
479486
}
480487

llvm/include/llvm/ExecutionEngine/Orc/TaskDispatch.h

+7
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,20 @@ class InPlaceTaskDispatcher : public TaskDispatcher {
114114

115115
class DynamicThreadPoolTaskDispatcher : public TaskDispatcher {
116116
public:
117+
DynamicThreadPoolTaskDispatcher(
118+
std::optional<size_t> MaxMaterializationThreads)
119+
: MaxMaterializationThreads(MaxMaterializationThreads) {}
117120
void dispatch(std::unique_ptr<Task> T) override;
118121
void shutdown() override;
119122
private:
120123
std::mutex DispatchMutex;
121124
bool Running = true;
122125
size_t Outstanding = 0;
123126
std::condition_variable OutstandingCV;
127+
128+
std::optional<size_t> MaxMaterializationThreads;
129+
size_t NumMaterializationThreads = 0;
130+
std::deque<std::unique_ptr<Task>> MaterializationTaskQueue;
124131
};
125132

126133
#endif // LLVM_ENABLE_THREADS

llvm/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ SelfExecutorProcessControl::Create(
6363

6464
if (!D) {
6565
#if LLVM_ENABLE_THREADS
66-
D = std::make_unique<DynamicThreadPoolTaskDispatcher>();
66+
D = std::make_unique<DynamicThreadPoolTaskDispatcher>(std::nullopt);
6767
#else
6868
D = std::make_unique<InPlaceTaskDispatcher>();
6969
#endif

llvm/lib/ExecutionEngine/Orc/LLJIT.cpp

+53-24
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,37 @@ Error LLJITBuilderState::prepareForConstruction() {
667667
return JTMBOrErr.takeError();
668668
}
669669

670+
if ((ES || EPC) && NumCompileThreads)
671+
return make_error<StringError>(
672+
"NumCompileThreads cannot be used with a custom ExecutionSession or "
673+
"ExecutorProcessControl",
674+
inconvertibleErrorCode());
675+
676+
#if !LLVM_ENABLE_THREADS
677+
if (NumCompileThreads)
678+
return make_error<StringError>(
679+
"LLJIT num-compile-threads is " + Twine(NumCompileThreads) +
680+
" but LLVM was compiled with LLVM_ENABLE_THREADS=Off",
681+
inconvertibleErrorCode());
682+
#endif // !LLVM_ENABLE_THREADS
683+
684+
bool ConcurrentCompilationSettingDefaulted = !SupportConcurrentCompilation;
685+
if (!SupportConcurrentCompilation) {
686+
#if LLVM_ENABLE_THREADS
687+
SupportConcurrentCompilation = NumCompileThreads || ES || EPC;
688+
#else
689+
SupportConcurrentCompilation = false;
690+
#endif // LLVM_ENABLE_THREADS
691+
} else {
692+
#if !LLVM_ENABLE_THREADS
693+
if (*SupportConcurrentCompilation)
694+
return make_error<StringError>(
695+
"LLJIT concurrent compilation support requested, but LLVM was built "
696+
"with LLVM_ENABLE_THREADS=Off",
697+
inconvertibleErrorCode());
698+
#endif // !LLVM_ENABLE_THREADS
699+
}
700+
670701
LLVM_DEBUG({
671702
dbgs() << " JITTargetMachineBuilder is "
672703
<< JITTargetMachineBuilderPrinter(*JTMB, " ")
@@ -684,11 +715,13 @@ Error LLJITBuilderState::prepareForConstruction() {
684715
<< (CreateCompileFunction ? "Yes" : "No") << "\n"
685716
<< " Custom platform-setup function: "
686717
<< (SetUpPlatform ? "Yes" : "No") << "\n"
687-
<< " Number of compile threads: " << NumCompileThreads;
688-
if (!NumCompileThreads)
689-
dbgs() << " (code will be compiled on the execution thread)\n";
718+
<< " Support concurrent compilation: "
719+
<< (*SupportConcurrentCompilation ? "Yes" : "No");
720+
if (ConcurrentCompilationSettingDefaulted)
721+
dbgs() << " (defaulted based on ES / EPC)\n";
690722
else
691723
dbgs() << "\n";
724+
dbgs() << " Number of compile threads: " << NumCompileThreads << "\n";
692725
});
693726

694727
// Create DL if not specified.
@@ -705,7 +738,19 @@ Error LLJITBuilderState::prepareForConstruction() {
705738
dbgs() << "ExecutorProcessControl not specified, "
706739
"Creating SelfExecutorProcessControl instance\n";
707740
});
708-
if (auto EPCOrErr = SelfExecutorProcessControl::Create())
741+
742+
std::unique_ptr<TaskDispatcher> D = nullptr;
743+
#if LLVM_ENABLE_THREADS
744+
if (*SupportConcurrentCompilation) {
745+
std::optional<size_t> NumThreads = std ::nullopt;
746+
if (NumCompileThreads)
747+
NumThreads = NumCompileThreads;
748+
D = std::make_unique<DynamicThreadPoolTaskDispatcher>(NumThreads);
749+
} else
750+
D = std::make_unique<InPlaceTaskDispatcher>();
751+
#endif // LLVM_ENABLE_THREADS
752+
if (auto EPCOrErr =
753+
SelfExecutorProcessControl::Create(nullptr, std::move(D), nullptr))
709754
EPC = std::move(*EPCOrErr);
710755
else
711756
return EPCOrErr.takeError();
@@ -790,8 +835,6 @@ Error LLJITBuilderState::prepareForConstruction() {
790835
}
791836

792837
LLJIT::~LLJIT() {
793-
if (CompileThreads)
794-
CompileThreads->wait();
795838
if (auto Err = ES->endSession())
796839
ES->reportError(std::move(Err));
797840
}
@@ -916,9 +959,8 @@ LLJIT::createCompileFunction(LLJITBuilderState &S,
916959
if (S.CreateCompileFunction)
917960
return S.CreateCompileFunction(std::move(JTMB));
918961

919-
// Otherwise default to creating a SimpleCompiler, or ConcurrentIRCompiler,
920-
// depending on the number of threads requested.
921-
if (S.NumCompileThreads > 0)
962+
// If using a custom EPC then use a ConcurrentIRCompiler by default.
963+
if (*S.SupportConcurrentCompilation)
922964
return std::make_unique<ConcurrentIRCompiler>(std::move(JTMB));
923965

924966
auto TM = JTMB.createTargetMachine();
@@ -970,21 +1012,8 @@ LLJIT::LLJIT(LLJITBuilderState &S, Error &Err)
9701012
std::make_unique<IRTransformLayer>(*ES, *TransformLayer);
9711013
}
9721014

973-
if (S.NumCompileThreads > 0) {
1015+
if (*S.SupportConcurrentCompilation)
9741016
InitHelperTransformLayer->setCloneToNewContextOnEmit(true);
975-
CompileThreads = std::make_unique<DefaultThreadPool>(
976-
hardware_concurrency(S.NumCompileThreads));
977-
ES->setDispatchTask([this](std::unique_ptr<Task> T) {
978-
// FIXME: We should be able to use move-capture here, but ThreadPool's
979-
// AsyncTaskTys are std::functions rather than unique_functions
980-
// (because MSVC's std::packaged_tasks don't support move-only types).
981-
// Fix this when all the above gets sorted out.
982-
CompileThreads->async([UnownedT = T.release()]() mutable {
983-
std::unique_ptr<Task> T(UnownedT);
984-
T->run();
985-
});
986-
});
987-
}
9881017

9891018
if (S.SetupProcessSymbolsJITDylib) {
9901019
if (auto ProcSymsJD = S.SetupProcessSymbolsJITDylib(*this)) {
@@ -1240,7 +1269,7 @@ LLLazyJIT::LLLazyJIT(LLLazyJITBuilderState &S, Error &Err) : LLJIT(S, Err) {
12401269
CODLayer = std::make_unique<CompileOnDemandLayer>(
12411270
*ES, *InitHelperTransformLayer, *LCTMgr, std::move(ISMBuilder));
12421271

1243-
if (S.NumCompileThreads > 0)
1272+
if (*S.SupportConcurrentCompilation)
12441273
CODLayer->setCloneToNewContextOnEmit(true);
12451274
}
12461275

llvm/lib/ExecutionEngine/Orc/TaskDispatch.cpp

+42-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/ExecutionEngine/Orc/TaskDispatch.h"
10+
#include "llvm/ExecutionEngine/Orc/Core.h"
1011

1112
namespace llvm {
1213
namespace orc {
@@ -24,16 +25,52 @@ void InPlaceTaskDispatcher::shutdown() {}
2425

2526
#if LLVM_ENABLE_THREADS
2627
void DynamicThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
28+
bool IsMaterializationTask = isa<MaterializationTask>(*T);
29+
2730
{
2831
std::lock_guard<std::mutex> Lock(DispatchMutex);
32+
33+
if (IsMaterializationTask) {
34+
35+
// If this is a materialization task and there are too many running
36+
// already then queue this one up and return early.
37+
if (MaxMaterializationThreads &&
38+
NumMaterializationThreads == *MaxMaterializationThreads) {
39+
MaterializationTaskQueue.push_back(std::move(T));
40+
return;
41+
}
42+
43+
// Otherwise record that we have a materialization task running.
44+
++NumMaterializationThreads;
45+
}
46+
2947
++Outstanding;
3048
}
3149

32-
std::thread([this, T = std::move(T)]() mutable {
33-
T->run();
34-
std::lock_guard<std::mutex> Lock(DispatchMutex);
35-
--Outstanding;
36-
OutstandingCV.notify_all();
50+
std::thread([this, T = std::move(T), IsMaterializationTask]() mutable {
51+
while (true) {
52+
53+
// Run the task.
54+
T->run();
55+
56+
std::lock_guard<std::mutex> Lock(DispatchMutex);
57+
if (!MaterializationTaskQueue.empty()) {
58+
// If there are any materialization tasks running then steal that work.
59+
T = std::move(MaterializationTaskQueue.front());
60+
MaterializationTaskQueue.pop_front();
61+
if (!IsMaterializationTask) {
62+
++NumMaterializationThreads;
63+
IsMaterializationTask = true;
64+
}
65+
} else {
66+
// Otherwise decrement work counters.
67+
if (IsMaterializationTask)
68+
--NumMaterializationThreads;
69+
--Outstanding;
70+
OutstandingCV.notify_all();
71+
return;
72+
}
73+
}
3774
}).detach();
3875
}
3976

llvm/tools/llvm-jitlink/llvm-jitlink.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -807,8 +807,8 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> launchExecutor() {
807807
S.CreateMemoryManager = createSharedMemoryManager;
808808

809809
return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
810-
std::make_unique<DynamicThreadPoolTaskDispatcher>(), std::move(S),
811-
FromExecutor[ReadEnd], ToExecutor[WriteEnd]);
810+
std::make_unique<DynamicThreadPoolTaskDispatcher>(std::nullopt),
811+
std::move(S), FromExecutor[ReadEnd], ToExecutor[WriteEnd]);
812812
#endif
813813
}
814814

@@ -897,7 +897,7 @@ static Expected<std::unique_ptr<ExecutorProcessControl>> connectToExecutor() {
897897
S.CreateMemoryManager = createSharedMemoryManager;
898898

899899
return SimpleRemoteEPC::Create<FDSimpleRemoteEPCTransport>(
900-
std::make_unique<DynamicThreadPoolTaskDispatcher>(),
900+
std::make_unique<DynamicThreadPoolTaskDispatcher>(std::nullopt),
901901
std::move(S), *SockFD, *SockFD);
902902
#endif
903903
}

llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -1005,11 +1005,11 @@ TEST_F(CoreAPIsStandardTest, RedefineBoundWeakSymbol) {
10051005

10061006
TEST_F(CoreAPIsStandardTest, DefineMaterializingSymbol) {
10071007
bool ExpectNoMoreMaterialization = false;
1008-
ES.setDispatchTask([&](std::unique_ptr<Task> T) {
1008+
DispatchOverride = [&](std::unique_ptr<Task> T) {
10091009
if (ExpectNoMoreMaterialization && isa<MaterializationTask>(*T))
10101010
ADD_FAILURE() << "Unexpected materialization";
10111011
T->run();
1012-
});
1012+
};
10131013

10141014
auto MU = std::make_unique<SimpleMaterializationUnit>(
10151015
SymbolFlagsMap({{Foo, FooSym.getFlags()}}),
@@ -1403,7 +1403,7 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
14031403

14041404
std::mutex WorkThreadsMutex;
14051405
std::vector<std::thread> WorkThreads;
1406-
ES.setDispatchTask([&](std::unique_ptr<Task> T) {
1406+
DispatchOverride = [&](std::unique_ptr<Task> T) {
14071407
std::promise<void> WaitP;
14081408
std::lock_guard<std::mutex> Lock(WorkThreadsMutex);
14091409
WorkThreads.push_back(
@@ -1412,7 +1412,7 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) {
14121412
T->run();
14131413
}));
14141414
WaitP.set_value();
1415-
});
1415+
};
14161416

14171417
cantFail(JD.define(absoluteSymbols({{Foo, FooSym}})));
14181418

0 commit comments

Comments
 (0)