@@ -143,7 +143,6 @@ namespace {
143
143
Value *buildMinimalMultiplyDAG (IRBuilder<> &Builder,
144
144
SmallVectorImpl<Factor> &Factors);
145
145
Value *OptimizeMul (BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
146
- void LinearizeExprTree (BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
147
146
Value *RemoveFactorFromExpression (Value *V, Value *Factor);
148
147
void EraseInst (Instruction *I);
149
148
void OptimizeInst (Instruction *I);
@@ -251,10 +250,148 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
251
250
return Res;
252
251
}
253
252
253
+ // / CarmichaelShift - Returns k such that lambda(2^Bitwidth) = 2^k, where lambda
254
+ // / is the Carmichael function. This means that x^(2^k) === 1 mod 2^Bitwidth for
255
+ // / every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.
256
+ // / Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every
257
+ // / even x in Bitwidth-bit arithmetic.
258
+ static unsigned CarmichaelShift (unsigned Bitwidth) {
259
+ if (Bitwidth < 3 )
260
+ return Bitwidth - 1 ;
261
+ return Bitwidth - 2 ;
262
+ }
263
+
264
+ // / IncorporateWeight - Add the extra weight 'RHS' to the existing weight 'LHS',
265
+ // / reducing the combined weight using any special properties of the operation.
266
+ // / The existing weight LHS represents the computation X op X op ... op X where
267
+ // / X occurs LHS times. The combined weight represents X op X op ... op X with
268
+ // / X occurring LHS + RHS times. If op is "Xor" for example then the combined
269
+ // / operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
270
+ // / the routine returns 1 in LHS in the first case, and 0 in LHS in the second.
271
+ static void IncorporateWeight (APInt &LHS, const APInt &RHS, unsigned Opcode) {
272
+ // If we were working with infinite precision arithmetic then the combined
273
+ // weight would be LHS + RHS. But we are using finite precision arithmetic,
274
+ // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
275
+ // for nilpotent operations and addition, but not for idempotent operations
276
+ // and multiplication), so it is important to correctly reduce the combined
277
+ // weight back into range if wrapping would be wrong.
278
+
279
+ // If RHS is zero then the weight didn't change.
280
+ if (RHS.isMinValue ())
281
+ return ;
282
+ // If LHS is zero then the combined weight is RHS.
283
+ if (LHS.isMinValue ()) {
284
+ LHS = RHS;
285
+ return ;
286
+ }
287
+ // From this point on we know that neither LHS nor RHS is zero.
288
+
289
+ if (Instruction::isIdempotent (Opcode)) {
290
+ // Idempotent means X op X === X, so any non-zero weight is equivalent to a
291
+ // weight of 1. Keeping weights at zero or one also means that wrapping is
292
+ // not a problem.
293
+ assert (LHS == 1 && RHS == 1 && " Weights not reduced!" );
294
+ return ; // Return a weight of 1.
295
+ }
296
+ if (Instruction::isNilpotent (Opcode)) {
297
+ // Nilpotent means X op X === 0, so reduce weights modulo 2.
298
+ assert (LHS == 1 && RHS == 1 && " Weights not reduced!" );
299
+ LHS = 0 ; // 1 + 1 === 0 modulo 2.
300
+ return ;
301
+ }
302
+ if (Opcode == Instruction::Add) {
303
+ // TODO: Reduce the weight by exploiting nsw/nuw?
304
+ LHS += RHS;
305
+ return ;
306
+ }
307
+
308
+ assert (Opcode == Instruction::Mul && " Unknown associative operation!" );
309
+ unsigned Bitwidth = LHS.getBitWidth ();
310
+ // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth
311
+ // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth
312
+ // bit number x, since either x is odd in which case x^CM = 1, or x is even in
313
+ // which case both x^W and x^(W - CM) are zero. By subtracting off multiples
314
+ // of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
315
+ // which by a happy accident means that they can always be represented using
316
+ // Bitwidth bits.
317
+ // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than
318
+ // the Carmichael number).
319
+ if (Bitwidth > 3 ) {
320
+ // / CM - The value of Carmichael's lambda function.
321
+ APInt CM = APInt::getOneBitSet (Bitwidth, CarmichaelShift (Bitwidth));
322
+ // Any weight W >= Threshold can be replaced with W - CM.
323
+ APInt Threshold = CM + Bitwidth;
324
+ assert (LHS.ult (Threshold) && RHS.ult (Threshold) && " Weights not reduced!" );
325
+ // For Bitwidth 4 or more the following sum does not overflow.
326
+ LHS += RHS;
327
+ while (LHS.uge (Threshold))
328
+ LHS -= CM;
329
+ } else {
330
+ // To avoid problems with overflow do everything the same as above but using
331
+ // a larger type.
332
+ unsigned CM = 1U << CarmichaelShift (Bitwidth);
333
+ unsigned Threshold = CM + Bitwidth;
334
+ assert (LHS.getZExtValue () < Threshold && RHS.getZExtValue () < Threshold &&
335
+ " Weights not reduced!" );
336
+ unsigned Total = LHS.getZExtValue () + RHS.getZExtValue ();
337
+ while (Total >= Threshold)
338
+ Total -= CM;
339
+ LHS = Total;
340
+ }
341
+ }
342
+
343
+ // / EvaluateRepeatedConstant - Compute C op C op ... op C where the constant C
344
+ // / is repeated Weight times.
345
+ static Constant *EvaluateRepeatedConstant (unsigned Opcode, Constant *C,
346
+ APInt Weight) {
347
+ // For addition the result can be efficiently computed as the product of the
348
+ // constant and the weight.
349
+ if (Opcode == Instruction::Add)
350
+ return ConstantExpr::getMul (C, ConstantInt::get (C->getContext (), Weight));
351
+
352
+ // The weight might be huge, so compute by repeated squaring to ensure that
353
+ // compile time is proportional to the logarithm of the weight.
354
+ Constant *Result = 0 ;
355
+ Constant *Power = C; // Successively C, C op C, (C op C) op (C op C) etc.
356
+ // Visit the bits in Weight.
357
+ while (Weight != 0 ) {
358
+ // If the current bit in Weight is non-zero do Result = Result op Power.
359
+ if (Weight[0 ])
360
+ Result = Result ? ConstantExpr::get (Opcode, Result, Power) : Power;
361
+ // Move on to the next bit if any more are non-zero.
362
+ Weight = Weight.lshr (1 );
363
+ if (Weight.isMinValue ())
364
+ break ;
365
+ // Square the power.
366
+ Power = ConstantExpr::get (Opcode, Power, Power);
367
+ }
368
+
369
+ assert (Result && " Only positive weights supported!" );
370
+ return Result;
371
+ }
372
+
373
+ typedef std::pair<Value*, APInt> RepeatedValue;
374
+
254
375
// / LinearizeExprTree - Given an associative binary expression, return the leaf
255
- // / nodes in Ops. The original expression is the same as Ops[0] op ... Ops[N].
256
- // / Note that a node may occur multiple times in Ops, but if so all occurrences
257
- // / are consecutive in the vector.
376
+ // / nodes in Ops along with their weights (how many times the leaf occurs). The
377
+ // / original expression is the same as
378
+ // / (Ops[0].first op Ops[0].first op ... Ops[0].first) <- Ops[0].second times
379
+ // / op
380
+ // / (Ops[1].first op Ops[1].first op ... Ops[1].first) <- Ops[1].second times
381
+ // / op
382
+ // / ...
383
+ // / op
384
+ // / (Ops[N].first op Ops[N].first op ... Ops[N].first) <- Ops[N].second times
385
+ // /
386
+ // / Note that the values Ops[0].first, ..., Ops[N].first are all distinct, and
387
+ // / they are all non-constant except possibly for the last one, which if it is
388
+ // / constant will have weight one (Ops[N].second === 1).
389
+ // /
390
+ // / This routine may modify the function, in which case it returns 'true'. The
391
+ // / changes it makes may well be destructive, changing the value computed by 'I'
392
+ // / to something completely different. Thus if the routine returns 'true' then
393
+ // / you MUST either replace I with a new expression computed from the Ops array,
394
+ // / or use RewriteExprTree to put the values back in.
258
395
// /
259
396
// / A leaf node is either not a binary operation of the same kind as the root
260
397
// / node 'I' (i.e. is not a binary operator at all, or is, but with a different
@@ -276,7 +413,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
276
413
// / + * | F, G
277
414
// /
278
415
// / The leaf nodes are C, E, F and G. The Ops array will contain (maybe not in
279
- // / that order) C, E, F, F, G, G .
416
+ // / that order) ( C, 1), ( E, 1), ( F, 2), (G, 2) .
280
417
// /
281
418
// / The expression is maximal: if some instruction is a binary operator of the
282
419
// / same kind as 'I', and all of its uses are non-leaf nodes of the expression,
@@ -287,7 +424,8 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
287
424
// / order to ensure that every non-root node in the expression has *exactly one*
288
425
// / use by a non-leaf node of the expression. This destruction means that the
289
426
// / caller MUST either replace 'I' with a new expression or use something like
290
- // / RewriteExprTree to put the values back in.
427
+ // / RewriteExprTree to put the values back in if the routine indicates that it
428
+ // / made a change by returning 'true'.
291
429
// /
292
430
// / In the above example either the right operand of A or the left operand of B
293
431
// / will be replaced by undef. If it is B's operand then this gives:
@@ -310,9 +448,14 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
310
448
// / of the expression) if it can turn them into binary operators of the right
311
449
// / type and thus make the expression bigger.
312
450
313
- void Reassociate:: LinearizeExprTree (BinaryOperator *I,
314
- SmallVectorImpl<ValueEntry > &Ops) {
451
+ static bool LinearizeExprTree (BinaryOperator *I,
452
+ SmallVectorImpl<RepeatedValue > &Ops) {
315
453
DEBUG (dbgs () << " LINEARIZE: " << *I << ' \n ' );
454
+ unsigned Bitwidth = I->getType ()->getScalarType ()->getPrimitiveSizeInBits ();
455
+ unsigned Opcode = I->getOpcode ();
456
+ assert (Instruction::isAssociative (Opcode) &&
457
+ Instruction::isCommutative (Opcode) &&
458
+ " Expected an associative and commutative operation!" );
316
459
317
460
// Visit all operands of the expression, keeping track of their weight (the
318
461
// number of paths from the expression root to the operand, or if you like
@@ -324,9 +467,9 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
324
467
// with their weights, representing a certain number of paths to the operator.
325
468
// If an operator occurs in the worklist multiple times then we found multiple
326
469
// ways to get to it.
327
- SmallVector<std::pair<BinaryOperator*, unsigned >, 8 > Worklist; // (Op, Weight)
328
- Worklist.push_back (std::make_pair (I, 1 ));
329
- unsigned Opcode = I-> getOpcode () ;
470
+ SmallVector<std::pair<BinaryOperator*, APInt >, 8 > Worklist; // (Op, Weight)
471
+ Worklist.push_back (std::make_pair (I, APInt (Bitwidth, 1 ) ));
472
+ bool MadeChange = false ;
330
473
331
474
// Leaves of the expression are values that either aren't the right kind of
332
475
// operation (eg: a constant, or a multiply in an add tree), or are, but have
@@ -343,21 +486,20 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
343
486
344
487
// Leaves - Keeps track of the set of putative leaves as well as the number of
345
488
// paths to each leaf seen so far.
346
- typedef SmallMap<Value*, unsigned , 8 > LeafMap;
489
+ typedef SmallMap<Value*, APInt , 8 > LeafMap;
347
490
LeafMap Leaves; // Leaf -> Total weight so far.
348
491
SmallVector<Value*, 8 > LeafOrder; // Ensure deterministic leaf output order.
349
492
350
493
#ifndef NDEBUG
351
494
SmallPtrSet<Value*, 8 > Visited; // For sanity checking the iteration scheme.
352
495
#endif
353
496
while (!Worklist.empty ()) {
354
- std::pair<BinaryOperator*, unsigned > P = Worklist.pop_back_val ();
497
+ std::pair<BinaryOperator*, APInt > P = Worklist.pop_back_val ();
355
498
I = P.first ; // We examine the operands of this binary operator.
356
- assert (P.second >= 1 && " No paths to here, so how did we get here?!" );
357
499
358
500
for (unsigned OpIdx = 0 ; OpIdx < 2 ; ++OpIdx) { // Visit operands.
359
501
Value *Op = I->getOperand (OpIdx);
360
- unsigned Weight = P.second ; // Number of paths to this operand.
502
+ APInt Weight = P.second ; // Number of paths to this operand.
361
503
DEBUG (dbgs () << " OPERAND: " << *Op << " (" << Weight << " )\n " );
362
504
assert (!Op->use_empty () && " No uses, so how did we get to it?!" );
363
505
@@ -389,7 +531,7 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
389
531
assert (Visited.count (Op) && " In leaf map but not visited!" );
390
532
391
533
// Update the number of paths to the leaf.
392
- It->second += Weight;
534
+ IncorporateWeight ( It->second , Weight, Opcode) ;
393
535
394
536
// The leaf already has one use from inside the expression. As we want
395
537
// exactly one such use, drop this new use of the leaf.
@@ -450,21 +592,44 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
450
592
451
593
// The leaves, repeated according to their weights, represent the linearized
452
594
// form of the expression.
595
+ Constant *Cst = 0 ; // Accumulate constants here.
453
596
for (unsigned i = 0 , e = LeafOrder.size (); i != e; ++i) {
454
597
Value *V = LeafOrder[i];
455
598
LeafMap::iterator It = Leaves.find (V);
456
599
if (It == Leaves.end ())
457
- // Leaf already output, or node initially thought to be a leaf wasn't.
600
+ // Node initially thought to be a leaf wasn't.
458
601
continue ;
459
602
assert (!isReassociableOp (V, Opcode) && " Shouldn't be a leaf!" );
460
- unsigned Weight = It->second ;
461
- assert (Weight > 0 && " No paths to this value!" );
462
- // FIXME: Rather than repeating values Weight times, use a vector of
463
- // (ValueEntry, multiplicity) pairs.
464
- Ops.append (Weight, ValueEntry (getRank (V), V));
603
+ APInt Weight = It->second ;
604
+ if (Weight.isMinValue ())
605
+ // Leaf already output or weight reduction eliminated it.
606
+ continue ;
465
607
// Ensure the leaf is only output once.
466
- Leaves.erase (It);
608
+ It->second = 0 ;
609
+ // Glob all constants together into Cst.
610
+ if (Constant *C = dyn_cast<Constant>(V)) {
611
+ C = EvaluateRepeatedConstant (Opcode, C, Weight);
612
+ Cst = Cst ? ConstantExpr::get (Opcode, Cst, C) : C;
613
+ continue ;
614
+ }
615
+ // Add non-constant
616
+ Ops.push_back (std::make_pair (V, Weight));
617
+ }
618
+
619
+ // Add any constants back into Ops, all globbed together and reduced to having
620
+ // weight 1 for the convenience of users.
621
+ if (Cst && Cst != ConstantExpr::getBinOpIdentity (Opcode, I->getType ()))
622
+ Ops.push_back (std::make_pair (Cst, APInt (Bitwidth, 1 )));
623
+
624
+ // For nilpotent operations or addition there may be no operands, for example
625
+ // because the expression was "X xor X" or consisted of 2^Bitwidth additions:
626
+ // in both cases the weight reduces to 0 causing the value to be skipped.
627
+ if (Ops.empty ()) {
628
+ Constant *Identity = ConstantExpr::getBinOpIdentity (Opcode, I->getType ());
629
+ Ops.push_back (std::make_pair (Identity, APInt (Bitwidth, 1 )));
467
630
}
631
+
632
+ return MadeChange;
468
633
}
469
634
470
635
// RewriteExprTree - Now that the operands for this expression tree are
@@ -775,8 +940,15 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
775
940
BinaryOperator *BO = isReassociableOp (V, Instruction::Mul);
776
941
if (!BO) return 0 ;
777
942
943
+ SmallVector<RepeatedValue, 8 > Tree;
944
+ MadeChange |= LinearizeExprTree (BO, Tree);
778
945
SmallVector<ValueEntry, 8 > Factors;
779
- LinearizeExprTree (BO, Factors);
946
+ Factors.reserve (Tree.size ());
947
+ for (unsigned i = 0 , e = Tree.size (); i != e; ++i) {
948
+ RepeatedValue E = Tree[i];
949
+ Factors.append (E.second .getZExtValue (),
950
+ ValueEntry (getRank (E.first ), E.first ));
951
+ }
780
952
781
953
bool FoundFactor = false ;
782
954
bool NeedsNegate = false ;
@@ -1439,8 +1611,15 @@ Value *Reassociate::ReassociateExpression(BinaryOperator *I) {
1439
1611
1440
1612
// First, walk the expression tree, linearizing the tree, collecting the
1441
1613
// operand information.
1614
+ SmallVector<RepeatedValue, 8 > Tree;
1615
+ MadeChange |= LinearizeExprTree (I, Tree);
1442
1616
SmallVector<ValueEntry, 8 > Ops;
1443
- LinearizeExprTree (I, Ops);
1617
+ Ops.reserve (Tree.size ());
1618
+ for (unsigned i = 0 , e = Tree.size (); i != e; ++i) {
1619
+ RepeatedValue E = Tree[i];
1620
+ Ops.append (E.second .getZExtValue (),
1621
+ ValueEntry (getRank (E.first ), E.first ));
1622
+ }
1444
1623
1445
1624
DEBUG (dbgs () << " RAIn:\t " ; PrintOps (I, Ops); dbgs () << ' \n ' );
1446
1625
0 commit comments