31
31
import org .apache .sysml .hops .Hop ;
32
32
import org .apache .sysml .hops .IndexingOp ;
33
33
import org .apache .sysml .hops .LiteralOp ;
34
+ import org .apache .sysml .hops .NaryOp ;
34
35
import org .apache .sysml .hops .ParameterizedBuiltinOp ;
35
36
import org .apache .sysml .hops .TernaryOp ;
36
37
import org .apache .sysml .hops .UnaryOp ;
@@ -92,6 +93,7 @@ public boolean open(Hop hop) {
92
93
|| (HopRewriteUtils .isBinary (hop , OpOp2 .CBIND ) && hop .getInput ().get (0 ).isMatrix () && hop .dimsKnown ())
93
94
|| HopRewriteUtils .isTernary (hop , OpOp3 .PLUS_MULT , OpOp3 .MINUS_MULT )
94
95
|| (HopRewriteUtils .isNary (hop , OpOpN .CBIND ) && hop .getInput ().get (0 ).isMatrix () && hop .dimsKnown ())
96
+ || (HopRewriteUtils .isNary (hop , OpOpN .MIN , OpOpN .MAX ) && hop .isMatrix ())
95
97
|| (hop instanceof AggBinaryOp && hop .dimsKnown () && hop .getDim2 ()==1 //MV
96
98
&& hop .getInput ().get (0 ).getDim1 ()>1 && hop .getInput ().get (0 ).getDim2 ()>1 )
97
99
|| (hop instanceof AggBinaryOp && hop .dimsKnown () && LibMatrixMult .isSkinnyRightHandSide (
@@ -117,6 +119,7 @@ public boolean fuse(Hop hop, Hop input) {
117
119
( (hop instanceof BinaryOp && isValidBinaryOperation (hop ))
118
120
|| (HopRewriteUtils .isBinary (hop , OpOp2 .CBIND ) && hop .getInput ().get (0 ).isMatrix () && hop .dimsKnown ())
119
121
|| (HopRewriteUtils .isNary (hop , OpOpN .CBIND ) && hop .getInput ().get (0 ).isMatrix () && hop .dimsKnown ())
122
+ || (HopRewriteUtils .isNary (hop , OpOpN .MIN , OpOpN .MAX ) && hop .isMatrix ())
120
123
|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp )
121
124
&& TemplateCell .isValidOperation (hop ))
122
125
|| HopRewriteUtils .isTernary (hop , OpOp3 .PLUS_MULT , OpOp3 .MINUS_MULT )
@@ -141,6 +144,7 @@ public boolean merge(Hop hop, Hop input) {
141
144
&& hop .getDim1 () > 1 && input .getDim1 ()>1 )
142
145
|| (HopRewriteUtils .isBinary (hop , OpOp2 .CBIND ) && hop .getInput ().get (0 ).isMatrix () && hop .dimsKnown ())
143
146
|| (HopRewriteUtils .isNary (hop , OpOpN .CBIND ) && hop .getInput ().get (0 ).isMatrix () && hop .dimsKnown ())
147
+ || (HopRewriteUtils .isNary (hop , OpOpN .MIN , OpOpN .MAX ) && hop .isMatrix ())
144
148
|| (HopRewriteUtils .isDataGenOpWithLiteralInputs (input , DataGenMethod .SEQ )
145
149
&& HopRewriteUtils .hasOnlyUnaryBinaryParents (input , false ))
146
150
|| (hop instanceof AggBinaryOp
@@ -408,19 +412,11 @@ else if(hop instanceof BinaryOp)
408
412
&& (cdata1 .getDataType ().isMatrix () || cdata2 .getDataType ().isMatrix ())))
409
413
{
410
414
if ( HopRewriteUtils .isBinary (hop , SUPPORTED_VECT_BINARY ) ) {
411
- if ( TemplateUtils .isMatrix (cdata1 ) && (TemplateUtils .isMatrix (cdata2 )
412
- || TemplateUtils .isRowVector (cdata2 )) ) {
413
- String opname = "VECT_" +((BinaryOp )hop ).getOp ().name ();
414
- out = new CNodeBinary (cdata1 , cdata2 , BinType .valueOf (opname ));
415
- }
416
- else {
417
- String opname = "VECT_" +((BinaryOp )hop ).getOp ().name ()+"_SCALAR" ;
418
- if ( TemplateUtils .isColVector (cdata1 ) )
419
- cdata1 = new CNodeUnary (cdata1 , UnaryType .LOOKUP_R );
420
- if ( TemplateUtils .isColVector (cdata2 ) )
421
- cdata2 = new CNodeUnary (cdata2 , UnaryType .LOOKUP_R );
422
- out = new CNodeBinary (cdata1 , cdata2 , BinType .valueOf (opname ));
423
- }
415
+ if ( TemplateUtils .isColVector (cdata1 ) )
416
+ cdata1 = new CNodeUnary (cdata1 , UnaryType .LOOKUP_R );
417
+ if ( TemplateUtils .isColVector (cdata2 ) )
418
+ cdata2 = new CNodeUnary (cdata2 , UnaryType .LOOKUP_R );
419
+ out = getVectorBinary (cdata1 , cdata2 , ((BinaryOp )hop ).getOp ().name ());
424
420
if ( cdata1 instanceof CNodeData && !inHops2 .containsKey ("X" )
425
421
&& !(cdata1 .getDataType ()==DataType .SCALAR ) ) {
426
422
inHops2 .put ("X" , hop .getInput ().get (0 ));
@@ -463,7 +459,7 @@ else if(hop instanceof TernaryOp) {
463
459
TernaryType .valueOf (top .getOp ().name ()));
464
460
}
465
461
}
466
- else if (HopRewriteUtils . isNary ( hop , OpOpN . CBIND ) ) {
462
+ else if ( hop instanceof NaryOp ) {
467
463
CNode [] inputs = new CNode [hop .getInput ().size ()];
468
464
for ( int i =0 ; i <hop .getInput ().size (); i ++ ) {
469
465
Hop c = hop .getInput ().get (i );
@@ -474,7 +470,14 @@ else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
474
470
if ( i ==0 && cdata instanceof CNodeData && !inHops2 .containsKey ("X" ) )
475
471
inHops2 .put ("X" , c );
476
472
}
477
- out = new CNodeNary (inputs , NaryType .VECT_CBIND );
473
+ if ( HopRewriteUtils .isNary (hop , OpOpN .CBIND ) ) {
474
+ out = new CNodeNary (inputs , NaryType .VECT_CBIND );
475
+ }
476
+ else if ( HopRewriteUtils .isNary (hop , OpOpN .MIN , OpOpN .MAX ) ) {
477
+ out = getVectorBinary (inputs [0 ], inputs [1 ], ((NaryOp )hop ).getOp ().name ());
478
+ for ( int i =2 ; i <hop .getInput ().size (); i ++ )
479
+ out = getVectorBinary (out , inputs [i ], ((NaryOp )hop ).getOp ().name ());
480
+ }
478
481
}
479
482
else if ( hop instanceof ParameterizedBuiltinOp ) {
480
483
CNode cdata1 = tmp .get (((ParameterizedBuiltinOp )hop ).getTargetHop ().getHopID ());
@@ -505,6 +508,16 @@ else if( hop instanceof IndexingOp ) {
505
508
tmp .put (hop .getHopID (), out );
506
509
}
507
510
511
+ private CNodeBinary getVectorBinary (CNode cdata1 , CNode cdata2 , String name ) {
512
+ if ( TemplateUtils .isMatrix (cdata1 ) && (TemplateUtils .isMatrix (cdata2 )
513
+ || TemplateUtils .isRowVector (cdata2 )) ) {
514
+ return new CNodeBinary (cdata1 , cdata2 , BinType .valueOf ("VECT_" +name ));
515
+ }
516
+ else {
517
+ return new CNodeBinary (cdata1 , cdata2 , BinType .valueOf ("VECT_" +name +"_SCALAR" ));
518
+ }
519
+ }
520
+
508
521
/**
509
522
* Comparator to order input hops of the row aggregate template. We try
510
523
* to order matrices-vectors-scalars via sorting by number of cells but
0 commit comments