@@ -1166,45 +1166,76 @@ class AdjointGenerator
1166
1166
eraseIfUnused (IEI);
1167
1167
if (gutils->isConstantInstruction (&IEI))
1168
1168
return ;
1169
- if (Mode == DerivativeMode::ReverseModePrimal)
1170
- return ;
1171
1169
1172
- IRBuilder<> Builder2 (IEI.getParent ());
1173
- getReverseBuilder (Builder2);
1170
+ switch (Mode) {
1171
+ case DerivativeMode::ForwardMode: {
1172
+ IRBuilder<> Builder2 (&IEI);
1173
+ getForwardBuilder (Builder2);
1174
1174
1175
- Value *dif1 = diffe (&IEI, Builder2);
1175
+ Value *orig_vector = IEI.getOperand (0 );
1176
+ Value *orig_inserted = IEI.getOperand (1 );
1177
+ Value *orig_index = IEI.getOperand (2 );
1176
1178
1177
- Value *orig_op0 = IEI.getOperand (0 );
1178
- Value *orig_op1 = IEI.getOperand (1 );
1179
- Value *op1 = gutils->getNewFromOriginal (orig_op1);
1180
- Value *op2 = gutils->getNewFromOriginal (IEI.getOperand (2 ));
1179
+ Value *diff_inserted = gutils->isConstantValue (orig_inserted)
1180
+ ? ConstantFP::get (orig_inserted->getType (), 0 )
1181
+ : diffe (orig_inserted, Builder2);
1181
1182
1182
- size_t size0 = 1 ;
1183
- if (orig_op0->getType ()->isSized ())
1184
- size0 = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1185
- orig_op0->getType ()) +
1186
- 7 ) /
1187
- 8 ;
1188
- size_t size1 = 1 ;
1189
- if (orig_op1->getType ()->isSized ())
1190
- size1 = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1191
- orig_op1->getType ()) +
1192
- 7 ) /
1193
- 8 ;
1183
+ Value *prediff =
1184
+ gutils->isConstantValue (orig_vector)
1185
+ ? diffe (orig_vector, Builder2)
1186
+ : ConstantVector::getNullValue (orig_vector->getType ());
1187
+ auto dindex = Builder2.CreateInsertElement (
1188
+ prediff, diff_inserted, gutils->getNewFromOriginal (orig_index));
1189
+ setDiffe (&IEI, dindex, Builder2);
1194
1190
1195
- if (!gutils-> isConstantValue (orig_op0))
1196
- addToDiffe (orig_op0,
1197
- Builder2. CreateInsertElement (
1198
- dif1, Constant::getNullValue (op1-> getType ()),
1199
- lookup (op2, Builder2)),
1200
- Builder2, TR. addingType (size0, orig_op0) );
1191
+ return ;
1192
+ }
1193
+ case DerivativeMode::ReverseModeGradient:
1194
+ case DerivativeMode::ReverseModeCombined: {
1195
+ IRBuilder<> Builder2 (IEI. getParent ());
1196
+ getReverseBuilder (Builder2 );
1201
1197
1202
- if (!gutils->isConstantValue (orig_op1))
1203
- addToDiffe (orig_op1,
1204
- Builder2.CreateExtractElement (dif1, lookup (op2, Builder2)),
1205
- Builder2, TR.addingType (size1, orig_op1));
1198
+ Value *dif1 = diffe (&IEI, Builder2);
1199
+
1200
+ Value *orig_op0 = IEI.getOperand (0 );
1201
+ Value *orig_op1 = IEI.getOperand (1 );
1202
+ Value *op1 = gutils->getNewFromOriginal (orig_op1);
1203
+ Value *op2 = gutils->getNewFromOriginal (IEI.getOperand (2 ));
1206
1204
1207
- setDiffe (&IEI, Constant::getNullValue (IEI.getType ()), Builder2);
1205
+ size_t size0 = 1 ;
1206
+ if (orig_op0->getType ()->isSized ())
1207
+ size0 =
1208
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1209
+ orig_op0->getType ()) +
1210
+ 7 ) /
1211
+ 8 ;
1212
+ size_t size1 = 1 ;
1213
+ if (orig_op1->getType ()->isSized ())
1214
+ size1 =
1215
+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1216
+ orig_op1->getType ()) +
1217
+ 7 ) /
1218
+ 8 ;
1219
+
1220
+ if (!gutils->isConstantValue (orig_op0))
1221
+ addToDiffe (orig_op0,
1222
+ Builder2.CreateInsertElement (
1223
+ dif1, Constant::getNullValue (op1->getType ()),
1224
+ lookup (op2, Builder2)),
1225
+ Builder2, TR.addingType (size0, orig_op0));
1226
+
1227
+ if (!gutils->isConstantValue (orig_op1))
1228
+ addToDiffe (orig_op1,
1229
+ Builder2.CreateExtractElement (dif1, lookup (op2, Builder2)),
1230
+ Builder2, TR.addingType (size1, orig_op1));
1231
+
1232
+ setDiffe (&IEI, Constant::getNullValue (IEI.getType ()), Builder2);
1233
+ return ;
1234
+ }
1235
+ case DerivativeMode::ReverseModePrimal: {
1236
+ return ;
1237
+ }
1238
+ }
1208
1239
}
1209
1240
1210
1241
void visitShuffleVectorInst (llvm::ShuffleVectorInst &SVI) {
0 commit comments