@@ -275,9 +275,7 @@ class AdjointGenerator
275
275
Mode == DerivativeMode::ForwardMode
276
276
? false
277
277
: is_value_needed_in_reverse<ValueType::ShadowPtr>(
278
- TR, gutils, &I,
279
- /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
280
- oldUnreachable);
278
+ TR, gutils, &I, Mode, oldUnreachable);
281
279
282
280
switch (Mode) {
283
281
@@ -333,9 +331,7 @@ class AdjointGenerator
333
331
Mode == DerivativeMode::ForwardMode
334
332
? false
335
333
: is_value_needed_in_reverse<ValueType::Primal>(
336
- TR, gutils, &I,
337
- /* toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
338
- oldUnreachable);
334
+ TR, gutils, &I, Mode, oldUnreachable);
339
335
// ! Store loads that need to be cached for use in reverse pass
340
336
if (cache_reads_always ||
341
337
(!cache_reads_never && can_modref && primalNeededInReverse)) {
@@ -2535,7 +2531,6 @@ class AdjointGenerator
2535
2531
}
2536
2532
2537
2533
auto argType = argi->getType ();
2538
- bool fwdMode = Mode == DerivativeMode::ForwardMode;
2539
2534
2540
2535
if (!argType->isFPOrFPVectorTy () &&
2541
2536
TR.query (call.getArgOperand (i)).Inner0 ().isPossiblePointer ()) {
@@ -2567,13 +2562,13 @@ class AdjointGenerator
2567
2562
2568
2563
// Note sometimes whattype mistakenly says something should be constant
2569
2564
// [because composed of integer pointers alone]
2570
- assert (whatType (argType, fwdMode ) == DIFFE_TYPE::DUP_ARG ||
2571
- whatType (argType, fwdMode ) == DIFFE_TYPE::CONSTANT);
2565
+ assert (whatType (argType, Mode ) == DIFFE_TYPE::DUP_ARG ||
2566
+ whatType (argType, Mode ) == DIFFE_TYPE::CONSTANT);
2572
2567
} else {
2573
2568
assert (0 && " out for omp not handled" );
2574
2569
argsInverted.push_back (DIFFE_TYPE::OUT_DIFF);
2575
- assert (whatType (argType, fwdMode ) == DIFFE_TYPE::OUT_DIFF ||
2576
- whatType (argType, fwdMode ) == DIFFE_TYPE::CONSTANT);
2570
+ assert (whatType (argType, Mode ) == DIFFE_TYPE::OUT_DIFF ||
2571
+ whatType (argType, Mode ) == DIFFE_TYPE::CONSTANT);
2577
2572
}
2578
2573
}
2579
2574
@@ -2802,10 +2797,9 @@ class AdjointGenerator
2802
2797
newcalled = gutils->Logic .CreatePrimalAndGradient (
2803
2798
cast<Function>(called), subretType, argsInverted, gutils->TLI ,
2804
2799
TR.analysis , /* returnValue*/ false ,
2805
- /* subdretptr*/ false , /* topLevel */ false ,
2800
+ /* subdretptr*/ false , DerivativeMode::ReverseModeGradient ,
2806
2801
tape ? PointerType::getUnqual (tape->getType ()) : nullptr ,
2807
2802
nextTypeInfo, uncacheable_args, subdata, /* AtomicAdd*/ true ,
2808
- /* fwdMode*/ false ,
2809
2803
/* postopt*/ false , /* omp*/ true );
2810
2804
2811
2805
auto numargs = ConstantInt::get (Type::getInt32Ty (call.getContext ()),
@@ -4013,9 +4007,7 @@ class AdjointGenerator
4013
4007
// TO FREE'ing
4014
4008
if (Mode != DerivativeMode::ReverseModeCombined) {
4015
4009
if ((is_value_needed_in_reverse<ValueType::Primal>(
4016
- TR, gutils, orig,
4017
- /* topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
4018
- oldUnreachable) &&
4010
+ TR, gutils, orig, Mode, oldUnreachable) &&
4019
4011
!gutils->unnecessaryIntermediates .count (orig)) ||
4020
4012
hasMetadata (orig, " enzyme_fromstack" )) {
4021
4013
Value *nop = gutils->cacheForReverse (BuilderZ, op,
@@ -4301,7 +4293,6 @@ class AdjointGenerator
4301
4293
}
4302
4294
4303
4295
auto argType = argi->getType ();
4304
- bool fwdMode = Mode == DerivativeMode::ForwardMode;
4305
4296
4306
4297
if (!argType->isFPOrFPVectorTy () &&
4307
4298
(TR.query (orig->getArgOperand (i)).Inner0 ().isPossiblePointer () ||
@@ -4334,14 +4325,14 @@ class AdjointGenerator
4334
4325
4335
4326
// Note sometimes whattype mistakenly says something should be constant
4336
4327
// [because composed of integer pointers alone]
4337
- assert (whatType (argType, fwdMode ) == DIFFE_TYPE::DUP_ARG ||
4338
- whatType (argType, fwdMode ) == DIFFE_TYPE::CONSTANT);
4328
+ assert (whatType (argType, Mode ) == DIFFE_TYPE::DUP_ARG ||
4329
+ whatType (argType, Mode ) == DIFFE_TYPE::CONSTANT);
4339
4330
} else {
4340
4331
if (foreignFunction)
4341
4332
assert (!argType->isIntOrIntVectorTy ());
4342
4333
argsInverted.push_back (DIFFE_TYPE::OUT_DIFF);
4343
- assert (whatType (argType, fwdMode ) == DIFFE_TYPE::OUT_DIFF ||
4344
- whatType (argType, fwdMode ) == DIFFE_TYPE::CONSTANT);
4334
+ assert (whatType (argType, Mode ) == DIFFE_TYPE::OUT_DIFF ||
4335
+ whatType (argType, Mode ) == DIFFE_TYPE::CONSTANT);
4345
4336
}
4346
4337
}
4347
4338
if (called) {
@@ -4588,9 +4579,7 @@ class AdjointGenerator
4588
4579
4589
4580
if (Mode == DerivativeMode::ReverseModePrimal &&
4590
4581
is_value_needed_in_reverse<ValueType::Primal>(
4591
- TR, gutils, orig,
4592
- /* topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
4593
- oldUnreachable) &&
4582
+ TR, gutils, orig, Mode, oldUnreachable) &&
4594
4583
!gutils->unnecessaryIntermediates .count (orig)) {
4595
4584
gutils->cacheForReverse (BuilderZ, dcall,
4596
4585
getIndex (orig, CacheType::Self));
@@ -4623,8 +4612,7 @@ class AdjointGenerator
4623
4612
4624
4613
if (subretused) {
4625
4614
if (is_value_needed_in_reverse<ValueType::Primal>(
4626
- TR, gutils, orig, Mode == DerivativeMode::ReverseModeCombined,
4627
- oldUnreachable) &&
4615
+ TR, gutils, orig, Mode, oldUnreachable) &&
4628
4616
!gutils->unnecessaryIntermediates .count (orig)) {
4629
4617
cachereplace = BuilderZ.CreatePHI (orig->getType (), 1 ,
4630
4618
orig->getName () + " _tmpcacheB" );
@@ -4716,8 +4704,7 @@ class AdjointGenerator
4716
4704
if (/* !topLevel*/ Mode != DerivativeMode::ReverseModeCombined &&
4717
4705
subretused && !orig->doesNotAccessMemory ()) {
4718
4706
if (is_value_needed_in_reverse<ValueType::Primal>(
4719
- TR, gutils, orig, Mode == DerivativeMode::ReverseModeCombined,
4720
- oldUnreachable) &&
4707
+ TR, gutils, orig, Mode, oldUnreachable) &&
4721
4708
!gutils->unnecessaryIntermediates .count (orig)) {
4722
4709
assert (!replaceFunction);
4723
4710
cachereplace = BuilderZ.CreatePHI (orig->getType (), 1 ,
@@ -4751,19 +4738,21 @@ class AdjointGenerator
4751
4738
bool subdretptr = (subretType == DIFFE_TYPE::DUP_ARG ||
4752
4739
subretType == DIFFE_TYPE::DUP_NONEED) &&
4753
4740
replaceFunction && (call.getNumUses () != 0 );
4754
- bool subtopLevel = replaceFunction || !modifyPrimal;
4741
+ DerivativeMode subMode = (replaceFunction || !modifyPrimal)
4742
+ ? DerivativeMode::ReverseModeCombined
4743
+ : DerivativeMode::ReverseModeGradient;
4755
4744
if (called) {
4756
4745
newcalled = gutils->Logic .CreatePrimalAndGradient (
4757
4746
cast<Function>(called), subretType, argsInverted, gutils->TLI ,
4758
4747
TR.analysis , /* returnValue*/ retUsed,
4759
- /* subdretptr*/ subdretptr, /* topLevel */ subtopLevel ,
4760
- tape ? tape-> getType () : nullptr , nextTypeInfo, uncacheable_args,
4761
- subdata, gutils->AtomicAdd , /* fwdMode */ false ); // , LI, DT);
4748
+ /* subdretptr*/ subdretptr, subMode, tape ? tape-> getType () : nullptr ,
4749
+ nextTypeInfo, uncacheable_args, subdata ,
4750
+ gutils->AtomicAdd ); // , LI, DT);
4762
4751
if (!newcalled)
4763
4752
return ;
4764
4753
} else {
4765
4754
4766
- assert (!subtopLevel );
4755
+ assert (subMode != DerivativeMode::ReverseModeCombined );
4767
4756
4768
4757
#if LLVM_VERSION_MAJOR >= 11
4769
4758
auto callval = orig->getCalledOperand ();
0 commit comments