Skip to content

Commit d0e0331

Browse files
authored
ForwardMode: insertelement inst (rust-lang#355)
* ForwardMode insertelement inst
1 parent ef3a0ac commit d0e0331

File tree

1 file changed

+63
-32
lines changed

1 file changed

+63
-32
lines changed

enzyme/Enzyme/AdjointGenerator.h

+63-32
Original file line numberDiff line numberDiff line change
@@ -1166,45 +1166,76 @@ class AdjointGenerator
11661166
eraseIfUnused(IEI);
11671167
if (gutils->isConstantInstruction(&IEI))
11681168
return;
1169-
if (Mode == DerivativeMode::ReverseModePrimal)
1170-
return;
11711169

1172-
IRBuilder<> Builder2(IEI.getParent());
1173-
getReverseBuilder(Builder2);
1170+
switch (Mode) {
1171+
case DerivativeMode::ForwardMode: {
1172+
IRBuilder<> Builder2(&IEI);
1173+
getForwardBuilder(Builder2);
11741174

1175-
Value *dif1 = diffe(&IEI, Builder2);
1175+
Value *orig_vector = IEI.getOperand(0);
1176+
Value *orig_inserted = IEI.getOperand(1);
1177+
Value *orig_index = IEI.getOperand(2);
11761178

1177-
Value *orig_op0 = IEI.getOperand(0);
1178-
Value *orig_op1 = IEI.getOperand(1);
1179-
Value *op1 = gutils->getNewFromOriginal(orig_op1);
1180-
Value *op2 = gutils->getNewFromOriginal(IEI.getOperand(2));
1179+
Value *diff_inserted = gutils->isConstantValue(orig_inserted)
1180+
? ConstantFP::get(orig_inserted->getType(), 0)
1181+
: diffe(orig_inserted, Builder2);
11811182

1182-
size_t size0 = 1;
1183-
if (orig_op0->getType()->isSized())
1184-
size0 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1185-
orig_op0->getType()) +
1186-
7) /
1187-
8;
1188-
size_t size1 = 1;
1189-
if (orig_op1->getType()->isSized())
1190-
size1 = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1191-
orig_op1->getType()) +
1192-
7) /
1193-
8;
1183+
Value *prediff =
1184+
gutils->isConstantValue(orig_vector)
1185+
? diffe(orig_vector, Builder2)
1186+
: ConstantVector::getNullValue(orig_vector->getType());
1187+
auto dindex = Builder2.CreateInsertElement(
1188+
prediff, diff_inserted, gutils->getNewFromOriginal(orig_index));
1189+
setDiffe(&IEI, dindex, Builder2);
11941190

1195-
if (!gutils->isConstantValue(orig_op0))
1196-
addToDiffe(orig_op0,
1197-
Builder2.CreateInsertElement(
1198-
dif1, Constant::getNullValue(op1->getType()),
1199-
lookup(op2, Builder2)),
1200-
Builder2, TR.addingType(size0, orig_op0));
1191+
return;
1192+
}
1193+
case DerivativeMode::ReverseModeGradient:
1194+
case DerivativeMode::ReverseModeCombined: {
1195+
IRBuilder<> Builder2(IEI.getParent());
1196+
getReverseBuilder(Builder2);
12011197

1202-
if (!gutils->isConstantValue(orig_op1))
1203-
addToDiffe(orig_op1,
1204-
Builder2.CreateExtractElement(dif1, lookup(op2, Builder2)),
1205-
Builder2, TR.addingType(size1, orig_op1));
1198+
Value *dif1 = diffe(&IEI, Builder2);
1199+
1200+
Value *orig_op0 = IEI.getOperand(0);
1201+
Value *orig_op1 = IEI.getOperand(1);
1202+
Value *op1 = gutils->getNewFromOriginal(orig_op1);
1203+
Value *op2 = gutils->getNewFromOriginal(IEI.getOperand(2));
12061204

1207-
setDiffe(&IEI, Constant::getNullValue(IEI.getType()), Builder2);
1205+
size_t size0 = 1;
1206+
if (orig_op0->getType()->isSized())
1207+
size0 =
1208+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1209+
orig_op0->getType()) +
1210+
7) /
1211+
8;
1212+
size_t size1 = 1;
1213+
if (orig_op1->getType()->isSized())
1214+
size1 =
1215+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1216+
orig_op1->getType()) +
1217+
7) /
1218+
8;
1219+
1220+
if (!gutils->isConstantValue(orig_op0))
1221+
addToDiffe(orig_op0,
1222+
Builder2.CreateInsertElement(
1223+
dif1, Constant::getNullValue(op1->getType()),
1224+
lookup(op2, Builder2)),
1225+
Builder2, TR.addingType(size0, orig_op0));
1226+
1227+
if (!gutils->isConstantValue(orig_op1))
1228+
addToDiffe(orig_op1,
1229+
Builder2.CreateExtractElement(dif1, lookup(op2, Builder2)),
1230+
Builder2, TR.addingType(size1, orig_op1));
1231+
1232+
setDiffe(&IEI, Constant::getNullValue(IEI.getType()), Builder2);
1233+
return;
1234+
}
1235+
case DerivativeMode::ReverseModePrimal: {
1236+
return;
1237+
}
1238+
}
12081239
}
12091240

12101241
void visitShuffleVectorInst(llvm::ShuffleVectorInst &SVI) {

0 commit comments

Comments
 (0)