Skip to content

Commit d40597a

Browse files
Minor changes to enable AD of Fortran. (rust-lang#295)
1 parent dce5f8b commit d40597a

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

enzyme/Enzyme/AdjointGenerator.h

+18-4
Original file line numberDiff line numberDiff line change
@@ -1014,10 +1014,20 @@ class AdjointGenerator
10141014
diffe(&SI, Builder2), "diffe" + op2->getName());
10151015

10161016
setDiffe(&SI, Constant::getNullValue(SI.getType()), Builder2);
1017-
if (dif1)
1018-
addToDiffe(orig_op1, dif1, Builder2, TR.addingType(size, orig_op1));
1019-
if (dif2)
1020-
addToDiffe(orig_op2, dif2, Builder2, TR.addingType(size, orig_op2));
1017+
if (dif1) {
1018+
Type *addingType = TR.addingType(size, orig_op1);
1019+
if (addingType || !looseTypeAnalysis)
1020+
addToDiffe(orig_op1, dif1, Builder2, addingType);
1021+
else
1022+
llvm::errs() << " warning: assuming integral for " << SI << "\n";
1023+
}
1024+
if (dif2) {
1025+
Type *addingType = TR.addingType(size, orig_op2);
1026+
if (addingType || !looseTypeAnalysis)
1027+
addToDiffe(orig_op2, dif2, Builder2, addingType);
1028+
else
1029+
llvm::errs() << " warning: assuming integral for " << SI << "\n";
1030+
}
10211031
}
10221032

10231033
void createSelectInstDual(llvm::SelectInst &SI) {
@@ -1710,8 +1720,12 @@ class AdjointGenerator
17101720
}
17111721
goto def;
17121722
}
1723+
case Instruction::Mul:
1724+
case Instruction::Sub:
17131725
case Instruction::Add: {
17141726
if (looseTypeAnalysis) {
1727+
llvm::errs() << "warning: binary operator is integer and constant: "
1728+
<< BO << "\n";
17151729
// if loose type analysis, assume this integer add is constant
17161730
return;
17171731
}

enzyme/Enzyme/GradientUtils.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -2509,12 +2509,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
25092509
Vals.push_back(cast<Constant>(
25102510
invertPointerM(CD->getElementAsConstant(i), BuilderM)));
25112511
}
2512-
return ConstantDataArray::get(CD->getContext(), Vals);
2512+
return ConstantArray::get(CD->getType(), Vals);
25132513
} else if (auto CD = dyn_cast<ConstantArray>(oval)) {
25142514
SmallVector<Constant *, 1> Vals;
25152515
for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) {
2516-
Vals.push_back(
2517-
cast<Constant>(invertPointerM(CD->getOperand(i), BuilderM)));
2516+
Value *val = invertPointerM(CD->getOperand(i), BuilderM);
2517+
Vals.push_back(cast<Constant>(val));
25182518
}
25192519
return ConstantArray::get(CD->getType(), Vals);
25202520
} else if (auto CD = dyn_cast<ConstantStruct>(oval)) {

0 commit comments

Comments
 (0)