Skip to content

Commit 6f78640

Browse files
authored
Refactor top level (rust-lang#212)
replace topLevel with mode
1 parent 10e61f3 commit 6f78640

12 files changed

+185
-188
lines changed

enzyme/Enzyme/AdjointGenerator.h

+22-33
Original file line number | Diff line number | Diff line change
@@ -275,9 +275,7 @@ class AdjointGenerator
275275
Mode == DerivativeMode::ForwardMode
276276
? false
277277
: is_value_needed_in_reverse<ValueType::ShadowPtr>(
278-
TR, gutils, &I,
279-
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
280-
oldUnreachable);
278+
TR, gutils, &I, Mode, oldUnreachable);
281279

282280
switch (Mode) {
283281

@@ -333,9 +331,7 @@ class AdjointGenerator
333331
Mode == DerivativeMode::ForwardMode
334332
? false
335333
: is_value_needed_in_reverse<ValueType::Primal>(
336-
TR, gutils, &I,
337-
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
338-
oldUnreachable);
334+
TR, gutils, &I, Mode, oldUnreachable);
339335
//! Store loads that need to be cached for use in reverse pass
340336
if (cache_reads_always ||
341337
(!cache_reads_never && can_modref && primalNeededInReverse)) {
@@ -2535,7 +2531,6 @@ class AdjointGenerator
25352531
}
25362532

25372533
auto argType = argi->getType();
2538-
bool fwdMode = Mode == DerivativeMode::ForwardMode;
25392534

25402535
if (!argType->isFPOrFPVectorTy() &&
25412536
TR.query(call.getArgOperand(i)).Inner0().isPossiblePointer()) {
@@ -2567,13 +2562,13 @@ class AdjointGenerator
25672562

25682563
// Note sometimes whattype mistakenly says something should be constant
25692564
// [because composed of integer pointers alone]
2570-
assert(whatType(argType, fwdMode) == DIFFE_TYPE::DUP_ARG ||
2571-
whatType(argType, fwdMode) == DIFFE_TYPE::CONSTANT);
2565+
assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
2566+
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
25722567
} else {
25732568
assert(0 && "out for omp not handled");
25742569
argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
2575-
assert(whatType(argType, fwdMode) == DIFFE_TYPE::OUT_DIFF ||
2576-
whatType(argType, fwdMode) == DIFFE_TYPE::CONSTANT);
2570+
assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
2571+
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
25772572
}
25782573
}
25792574

@@ -2802,10 +2797,9 @@ class AdjointGenerator
28022797
newcalled = gutils->Logic.CreatePrimalAndGradient(
28032798
cast<Function>(called), subretType, argsInverted, gutils->TLI,
28042799
TR.analysis, /*returnValue*/ false,
2805-
/*subdretptr*/ false, /*topLevel*/ false,
2800+
/*subdretptr*/ false, DerivativeMode::ReverseModeGradient,
28062801
tape ? PointerType::getUnqual(tape->getType()) : nullptr,
28072802
nextTypeInfo, uncacheable_args, subdata, /*AtomicAdd*/ true,
2808-
/*fwdMode*/ false,
28092803
/*postopt*/ false, /*omp*/ true);
28102804

28112805
auto numargs = ConstantInt::get(Type::getInt32Ty(call.getContext()),
@@ -4013,9 +4007,7 @@ class AdjointGenerator
40134007
// TO FREE'ing
40144008
if (Mode != DerivativeMode::ReverseModeCombined) {
40154009
if ((is_value_needed_in_reverse<ValueType::Primal>(
4016-
TR, gutils, orig,
4017-
/*topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
4018-
oldUnreachable) &&
4010+
TR, gutils, orig, Mode, oldUnreachable) &&
40194011
!gutils->unnecessaryIntermediates.count(orig)) ||
40204012
hasMetadata(orig, "enzyme_fromstack")) {
40214013
Value *nop = gutils->cacheForReverse(BuilderZ, op,
@@ -4301,7 +4293,6 @@ class AdjointGenerator
43014293
}
43024294

43034295
auto argType = argi->getType();
4304-
bool fwdMode = Mode == DerivativeMode::ForwardMode;
43054296

43064297
if (!argType->isFPOrFPVectorTy() &&
43074298
(TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer() ||
@@ -4334,14 +4325,14 @@ class AdjointGenerator
43344325

43354326
// Note sometimes whattype mistakenly says something should be constant
43364327
// [because composed of integer pointers alone]
4337-
assert(whatType(argType, fwdMode) == DIFFE_TYPE::DUP_ARG ||
4338-
whatType(argType, fwdMode) == DIFFE_TYPE::CONSTANT);
4328+
assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
4329+
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
43394330
} else {
43404331
if (foreignFunction)
43414332
assert(!argType->isIntOrIntVectorTy());
43424333
argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
4343-
assert(whatType(argType, fwdMode) == DIFFE_TYPE::OUT_DIFF ||
4344-
whatType(argType, fwdMode) == DIFFE_TYPE::CONSTANT);
4334+
assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
4335+
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
43454336
}
43464337
}
43474338
if (called) {
@@ -4588,9 +4579,7 @@ class AdjointGenerator
45884579

45894580
if (Mode == DerivativeMode::ReverseModePrimal &&
45904581
is_value_needed_in_reverse<ValueType::Primal>(
4591-
TR, gutils, orig,
4592-
/*topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
4593-
oldUnreachable) &&
4582+
TR, gutils, orig, Mode, oldUnreachable) &&
45944583
!gutils->unnecessaryIntermediates.count(orig)) {
45954584
gutils->cacheForReverse(BuilderZ, dcall,
45964585
getIndex(orig, CacheType::Self));
@@ -4623,8 +4612,7 @@ class AdjointGenerator
46234612

46244613
if (subretused) {
46254614
if (is_value_needed_in_reverse<ValueType::Primal>(
4626-
TR, gutils, orig, Mode == DerivativeMode::ReverseModeCombined,
4627-
oldUnreachable) &&
4615+
TR, gutils, orig, Mode, oldUnreachable) &&
46284616
!gutils->unnecessaryIntermediates.count(orig)) {
46294617
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
46304618
orig->getName() + "_tmpcacheB");
@@ -4716,8 +4704,7 @@ class AdjointGenerator
47164704
if (/*!topLevel*/ Mode != DerivativeMode::ReverseModeCombined &&
47174705
subretused && !orig->doesNotAccessMemory()) {
47184706
if (is_value_needed_in_reverse<ValueType::Primal>(
4719-
TR, gutils, orig, Mode == DerivativeMode::ReverseModeCombined,
4720-
oldUnreachable) &&
4707+
TR, gutils, orig, Mode, oldUnreachable) &&
47214708
!gutils->unnecessaryIntermediates.count(orig)) {
47224709
assert(!replaceFunction);
47234710
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
@@ -4751,19 +4738,21 @@ class AdjointGenerator
47514738
bool subdretptr = (subretType == DIFFE_TYPE::DUP_ARG ||
47524739
subretType == DIFFE_TYPE::DUP_NONEED) &&
47534740
replaceFunction && (call.getNumUses() != 0);
4754-
bool subtopLevel = replaceFunction || !modifyPrimal;
4741+
DerivativeMode subMode = (replaceFunction || !modifyPrimal)
4742+
? DerivativeMode::ReverseModeCombined
4743+
: DerivativeMode::ReverseModeGradient;
47554744
if (called) {
47564745
newcalled = gutils->Logic.CreatePrimalAndGradient(
47574746
cast<Function>(called), subretType, argsInverted, gutils->TLI,
47584747
TR.analysis, /*returnValue*/ retUsed,
4759-
/*subdretptr*/ subdretptr, /*topLevel*/ subtopLevel,
4760-
tape ? tape->getType() : nullptr, nextTypeInfo, uncacheable_args,
4761-
subdata, gutils->AtomicAdd, /*fwdMode*/ false); //, LI, DT);
4748+
/*subdretptr*/ subdretptr, subMode, tape ? tape->getType() : nullptr,
4749+
nextTypeInfo, uncacheable_args, subdata,
4750+
gutils->AtomicAdd); //, LI, DT);
47624751
if (!newcalled)
47634752
return;
47644753
} else {
47654754

4766-
assert(!subtopLevel);
4755+
assert(subMode != DerivativeMode::ReverseModeCombined);
47674756

47684757
#if LLVM_VERSION_MAJOR >= 11
47694758
auto callval = orig->getCalledOperand();

enzyme/Enzyme/CApi.cpp

+3-4
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
241241
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
242242
CDIFFE_TYPE *constant_args, size_t constant_args_size,
243243
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
244-
uint8_t topLevel, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
244+
CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
245245
uint8_t *_uncacheable_args, size_t uncacheable_args_size,
246246
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd, uint8_t PostOpt) {
247247
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
@@ -256,10 +256,9 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
256256
}
257257
return wrap(eunwrap(Logic).CreatePrimalAndGradient(
258258
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
259-
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, topLevel,
259+
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
260260
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
261-
uncacheable_args, eunwrap(augmented), AtomicAdd, /*fwdMode*/ false,
262-
PostOpt));
261+
uncacheable_args, eunwrap(augmented), AtomicAdd, PostOpt));
263262
}
264263
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
265264
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,

enzyme/Enzyme/CApi.h

+11-3
Original file line number | Diff line number | Diff line change
@@ -109,13 +109,21 @@ typedef enum {
109109
// but don't need the forward
110110
} CDIFFE_TYPE;
111111

112+
typedef enum {
113+
DEM_ForwardMode = 0,
114+
DEM_ReverseModePrimal = 1,
115+
DEM_ReverseModeGradient = 2,
116+
DEM_ReverseModeCombined = 3,
117+
} CDerivativeMode;
118+
112119
LLVMValueRef EnzymeCreatePrimalAndGradient(
113120
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
114121
CDIFFE_TYPE *constant_args, size_t constant_args_size,
115122
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
116-
uint8_t topLevel, LLVMTypeRef additionalArg, struct CFnTypeInfo typeInfo,
117-
uint8_t *_uncacheable_args, size_t uncacheable_args_size,
118-
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd, uint8_t PostOpt);
123+
CDerivativeMode mode, LLVMTypeRef additionalArg,
124+
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
125+
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented,
126+
uint8_t AtomicAdd, uint8_t PostOpt);
119127

120128
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
121129
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,

enzyme/Enzyme/DifferentialUseAnalysis.h

+7-7
Original file line number | Diff line number | Diff line change
@@ -141,7 +141,7 @@ static inline bool is_use_directly_needed_in_reverse(
141141
template <ValueType VT, bool OneLevel = false>
142142
static inline bool is_value_needed_in_reverse(
143143
TypeResults &TR, const GradientUtils *gutils, const Value *inst,
144-
bool topLevel, std::map<UsageKey, bool> &seen,
144+
DerivativeMode mode, std::map<UsageKey, bool> &seen,
145145
const SmallPtrSetImpl<BasicBlock *> &oldUnreachable) {
146146
auto idx = UsageKey(inst, VT);
147147
if (seen.find(idx) != seen.end())
@@ -221,7 +221,7 @@ static inline bool is_value_needed_in_reverse(
221221
continue;
222222

223223
if (!OneLevel && is_value_needed_in_reverse<ValueType::ShadowPtr>(
224-
TR, gutils, user, topLevel, seen, oldUnreachable)) {
224+
TR, gutils, user, mode, seen, oldUnreachable)) {
225225
return seen[idx] = true;
226226
}
227227
continue;
@@ -230,7 +230,7 @@ static inline bool is_value_needed_in_reverse(
230230
assert(VT == ValueType::Primal);
231231

232232
// If a sub user needs, we need
233-
if (!OneLevel && is_value_needed_in_reverse<VT>(TR, gutils, user, topLevel,
233+
if (!OneLevel && is_value_needed_in_reverse<VT>(TR, gutils, user, mode,
234234
seen, oldUnreachable)) {
235235
return seen[idx] = true;
236236
}
@@ -242,7 +242,7 @@ static inline bool is_value_needed_in_reverse(
242242
// otherwise it will use the local cache (rather than save for a separate
243243
// backwards cache)
244244
// We also don't need this if looking at the shadow rather than primal
245-
if (!topLevel) {
245+
if (mode != DerivativeMode::ReverseModeCombined) {
246246
// Proving that none of the uses (or uses' uses) are used in control flow
247247
// allows us to safely not do this load
248248

@@ -296,7 +296,7 @@ static inline bool is_value_needed_in_reverse(
296296
.Inner0()
297297
.isPossiblePointer()) {
298298
if (is_value_needed_in_reverse<ValueType::ShadowPtr>(
299-
TR, gutils, user, topLevel, seen, oldUnreachable)) {
299+
TR, gutils, user, mode, seen, oldUnreachable)) {
300300
return seen[idx] = true;
301301
}
302302
}
@@ -314,9 +314,9 @@ static inline bool is_value_needed_in_reverse(
314314
template <ValueType VT>
315315
static inline bool is_value_needed_in_reverse(
316316
TypeResults &TR, const GradientUtils *gutils, const Value *inst,
317-
bool topLevel, const SmallPtrSetImpl<BasicBlock *> &oldUnreachable) {
317+
DerivativeMode mode, const SmallPtrSetImpl<BasicBlock *> &oldUnreachable) {
318318
std::map<UsageKey, bool> seen;
319-
return is_value_needed_in_reverse<VT>(TR, gutils, inst, topLevel, seen,
319+
return is_value_needed_in_reverse<VT>(TR, gutils, inst, mode, seen,
320320
oldUnreachable);
321321
}
322322

enzyme/Enzyme/Enzyme.cpp

+12-15
Original file line number | Diff line number | Diff line change
@@ -215,7 +215,7 @@ class Enzyme : public ModulePass {
215215
cast<ConstantInt>(CI->getArgOperand(i))->getSExtValue();
216216
continue;
217217
} else {
218-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
218+
ty = whatType(PTy, mode);
219219
}
220220
} else if (isa<LoadInst>(res) &&
221221
isa<ConstantExpr>(cast<LoadInst>(res)->getOperand(0)) &&
@@ -250,7 +250,7 @@ class Enzyme : public ModulePass {
250250
cast<ConstantInt>(CI->getArgOperand(i))->getSExtValue();
251251
continue;
252252
} else {
253-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
253+
ty = whatType(PTy, mode);
254254
}
255255
} else if (isa<GlobalVariable>(res)) {
256256
auto gv = cast<GlobalVariable>(res);
@@ -272,7 +272,7 @@ class Enzyme : public ModulePass {
272272
++i;
273273
res = CI->getArgOperand(i);
274274
} else {
275-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
275+
ty = whatType(PTy, mode);
276276
}
277277
} else if (isa<ConstantExpr>(res) && cast<ConstantExpr>(res)->isCast() &&
278278
isa<GlobalVariable>(cast<ConstantExpr>(res)->getOperand(0))) {
@@ -295,7 +295,7 @@ class Enzyme : public ModulePass {
295295
++i;
296296
res = CI->getArgOperand(i);
297297
} else {
298-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
298+
ty = whatType(PTy, mode);
299299
}
300300
} else if (isa<CastInst>(res) && cast<CastInst>(res) &&
301301
isa<AllocaInst>(cast<CastInst>(res)->getOperand(0))) {
@@ -318,7 +318,7 @@ class Enzyme : public ModulePass {
318318
++i;
319319
res = CI->getArgOperand(i);
320320
} else {
321-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
321+
ty = whatType(PTy, mode);
322322
}
323323
} else if (isa<AllocaInst>(res)) {
324324
auto gv = cast<AllocaInst>(res);
@@ -340,10 +340,10 @@ class Enzyme : public ModulePass {
340340
++i;
341341
res = CI->getArgOperand(i);
342342
} else {
343-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
343+
ty = whatType(PTy, mode);
344344
}
345345
} else
346-
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
346+
ty = whatType(PTy, mode);
347347

348348
constants.push_back(ty);
349349

@@ -440,8 +440,7 @@ class Enzyme : public ModulePass {
440440
mode != DerivativeMode::ForwardMode &&
441441
cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();
442442

443-
DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType(),
444-
mode == DerivativeMode::ForwardMode);
443+
DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType(), mode);
445444

446445
std::map<Argument *, bool> volatile_args;
447446
FnTypeInfo type_args(cast<Function>(fn));
@@ -478,10 +477,9 @@ class Enzyme : public ModulePass {
478477
case DerivativeMode::ReverseModeCombined:
479478
newFunc = Logic.CreatePrimalAndGradient(
480479
cast<Function>(fn), retType, constants, TLI, TA,
481-
/*should return*/ false, /*dretPtr*/ false, /*topLevel*/ true,
480+
/*should return*/ false, /*dretPtr*/ false, mode,
482481
/*addedType*/ nullptr, type_args, volatile_args,
483-
/*index mapping*/ nullptr, AtomicAdd,
484-
mode == DerivativeMode::ForwardMode, PostOpt);
482+
/*index mapping*/ nullptr, AtomicAdd, PostOpt);
485483
break;
486484
case DerivativeMode::ReverseModePrimal:
487485
case DerivativeMode::ReverseModeGradient: {
@@ -515,9 +513,8 @@ class Enzyme : public ModulePass {
515513
else
516514
newFunc = Logic.CreatePrimalAndGradient(
517515
cast<Function>(fn), retType, constants, TLI, TA,
518-
/*should return*/ false, /*dretPtr*/ false, /*topLevel*/ false,
519-
tapeType, type_args, volatile_args, &aug, AtomicAdd,
520-
/*fwdMode*/ false, PostOpt);
516+
/*should return*/ false, /*dretPtr*/ false, mode, tapeType,
517+
type_args, volatile_args, &aug, AtomicAdd, PostOpt);
521518
}
522519
}
523520

0 commit comments

Comments
 (0)