Skip to content

Commit 669b422

Browse files
authored
Simplify argument handling and add enzyme_nofree option (rust-lang#371)
* Add inner loop test * Abstract primal and reverse api * Add nofree option * Keep AtomicAdd in key * Preserve non-cache allocations
1 parent 717ae73 commit 669b422

11 files changed

+689
-591
lines changed

enzyme/Enzyme/AdjointGenerator.h

+165-106
Large diffs are not rendered by default.

enzyme/Enzyme/CApi.cpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,20 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
371371
argnum++;
372372
}
373373
return wrap(eunwrap(Logic).CreatePrimalAndGradient(
374-
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
375-
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
376-
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
377-
uncacheable_args, eunwrap(augmented), AtomicAdd, PostOpt));
374+
(ReverseCacheKey){
375+
.todiff = cast<Function>(unwrap(todiff)),
376+
.retType = (DIFFE_TYPE)retType,
377+
.constant_args = nconstant_args,
378+
.uncacheable_args = uncacheable_args,
379+
.returnUsed = returnValue,
380+
.shadowReturnUsed = dretUsed,
381+
.mode = (DerivativeMode)mode,
382+
.freeMemory = true,
383+
.AtomicAdd = AtomicAdd,
384+
.additionalType = unwrap(additionalArg),
385+
.typeInfo = eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
386+
},
387+
eunwrap(TA).TLI, eunwrap(TA), eunwrap(augmented), PostOpt));
378388
}
379389
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
380390
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,

0 commit comments

Comments
 (0)