Skip to content

Commit 4d370a8

Browse files
committed
[SYSTEMML-2273] Codegen support for nary min/max in cell/row templates
This patch adds operator fusion and code generation support for the new nary min/max operator in cell and row templates. In both cases, we map the nary operator to a chain of binary codegen operations, which are fused at the cell or row level anyway without additional overhead; hence, no additional template or runtime primitives are needed.
1 parent 8d32079 commit 4d370a8

File tree

8 files changed

+207
-22
lines changed

8 files changed

+207
-22
lines changed

src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java

+16-3
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@
3737
import org.apache.sysml.hops.Hop.DataGenMethod;
3838
import org.apache.sysml.hops.Hop.OpOp2;
3939
import org.apache.sysml.hops.Hop.OpOp3;
40+
import org.apache.sysml.hops.Hop.OpOpN;
4041
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
4142
import org.apache.sysml.hops.IndexingOp;
4243
import org.apache.sysml.hops.LiteralOp;
44+
import org.apache.sysml.hops.NaryOp;
4345
import org.apache.sysml.hops.ParameterizedBuiltinOp;
4446
import org.apache.sysml.hops.TernaryOp;
4547
import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -82,7 +84,8 @@ public boolean open(Hop hop) {
8284
|| (hop instanceof IndexingOp && hop.getInput().get(0).getDim2() >= 0
8385
&& (((IndexingOp)hop).isColLowerEqualsUpper() || hop.getDim2()==1))
8486
|| (HopRewriteUtils.isDataGenOpWithLiteralInputs(hop, DataGenMethod.SEQ)
85-
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(hop, true));
87+
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(hop, true))
88+
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix());
8689
}
8790

8891
@Override
@@ -93,7 +96,8 @@ public boolean fuse(Hop hop, Hop input) {
9396
&& hop.getDim1()==1 && hop.getDim2()==1)
9497
&& HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))
9598
|| (HopRewriteUtils.isTransposeOperation(hop)
96-
&& hop.getDim1()==1 && hop.getDim2()>1));
99+
&& hop.getDim1()==1 && hop.getDim2()>1))
100+
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix());
97101
}
98102

99103
@Override
@@ -103,7 +107,8 @@ public boolean merge(Hop hop, Hop input) {
103107
|| (hop instanceof AggBinaryOp && hop.getInput().indexOf(input)==0
104108
&& HopRewriteUtils.isTransposeOperation(input))))
105109
|| (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, DataGenMethod.SEQ)
106-
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false));
110+
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
111+
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix());
107112
}
108113

109114
@Override
@@ -221,6 +226,14 @@ else if(hop instanceof TernaryOp) {
221226
out = new CNodeTernary(cdata1, cdata2, cdata3,
222227
TernaryType.valueOf(top.getOp().name()));
223228
}
229+
else if(HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX)) {
230+
String op = ((NaryOp)hop).getOp().name();
231+
CNode[] inputs = hop.getInput().stream().map(c ->
232+
TemplateUtils.wrapLookupIfNecessary(tmp.get(c.getHopID()), c)).toArray(CNode[]::new);
233+
out = new CNodeBinary(inputs[0], inputs[1], BinType.valueOf(op));
234+
for( int i=2; i<hop.getInput().size(); i++ )
235+
out = new CNodeBinary(out, inputs[i], BinType.valueOf(op));
236+
}
224237
else if( hop instanceof ParameterizedBuiltinOp ) {
225238
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
226239
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));

src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java

+28-15
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.apache.sysml.hops.Hop;
3232
import org.apache.sysml.hops.IndexingOp;
3333
import org.apache.sysml.hops.LiteralOp;
34+
import org.apache.sysml.hops.NaryOp;
3435
import org.apache.sysml.hops.ParameterizedBuiltinOp;
3536
import org.apache.sysml.hops.TernaryOp;
3637
import org.apache.sysml.hops.UnaryOp;
@@ -92,6 +93,7 @@ public boolean open(Hop hop) {
9293
|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
9394
|| HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, OpOp3.MINUS_MULT)
9495
|| (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
96+
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix())
9597
|| (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV
9698
&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
9799
|| (hop instanceof AggBinaryOp && hop.dimsKnown() && LibMatrixMult.isSkinnyRightHandSide(
@@ -117,6 +119,7 @@ public boolean fuse(Hop hop, Hop input) {
117119
( (hop instanceof BinaryOp && isValidBinaryOperation(hop))
118120
|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
119121
|| (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
122+
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix())
120123
|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
121124
&& TemplateCell.isValidOperation(hop))
122125
|| HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, OpOp3.MINUS_MULT)
@@ -141,6 +144,7 @@ public boolean merge(Hop hop, Hop input) {
141144
&& hop.getDim1() > 1 && input.getDim1()>1)
142145
|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
143146
|| (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
147+
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix())
144148
|| (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, DataGenMethod.SEQ)
145149
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
146150
|| (hop instanceof AggBinaryOp
@@ -408,19 +412,11 @@ else if(hop instanceof BinaryOp)
408412
&& (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())))
409413
{
410414
if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) {
411-
if( TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2)
412-
|| TemplateUtils.isRowVector(cdata2)) ) {
413-
String opname = "VECT_"+((BinaryOp)hop).getOp().name();
414-
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
415-
}
416-
else {
417-
String opname = "VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR";
418-
if( TemplateUtils.isColVector(cdata1) )
419-
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
420-
if( TemplateUtils.isColVector(cdata2) )
421-
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
422-
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
423-
}
415+
if( TemplateUtils.isColVector(cdata1) )
416+
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
417+
if( TemplateUtils.isColVector(cdata2) )
418+
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
419+
out = getVectorBinary(cdata1, cdata2, ((BinaryOp)hop).getOp().name());
424420
if( cdata1 instanceof CNodeData && !inHops2.containsKey("X")
425421
&& !(cdata1.getDataType()==DataType.SCALAR) ) {
426422
inHops2.put("X", hop.getInput().get(0));
@@ -463,7 +459,7 @@ else if(hop instanceof TernaryOp) {
463459
TernaryType.valueOf(top.getOp().name()));
464460
}
465461
}
466-
else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
462+
else if( hop instanceof NaryOp ) {
467463
CNode[] inputs = new CNode[hop.getInput().size()];
468464
for( int i=0; i<hop.getInput().size(); i++ ) {
469465
Hop c = hop.getInput().get(i);
@@ -474,7 +470,14 @@ else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
474470
if( i==0 && cdata instanceof CNodeData && !inHops2.containsKey("X") )
475471
inHops2.put("X", c);
476472
}
477-
out = new CNodeNary(inputs, NaryType.VECT_CBIND);
473+
if( HopRewriteUtils.isNary(hop, OpOpN.CBIND) ) {
474+
out = new CNodeNary(inputs, NaryType.VECT_CBIND);
475+
}
476+
else if( HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) ) {
477+
out = getVectorBinary(inputs[0], inputs[1], ((NaryOp)hop).getOp().name());
478+
for( int i=2; i<hop.getInput().size(); i++ )
479+
out = getVectorBinary(out, inputs[i], ((NaryOp)hop).getOp().name());
480+
}
478481
}
479482
else if( hop instanceof ParameterizedBuiltinOp ) {
480483
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
@@ -505,6 +508,16 @@ else if( hop instanceof IndexingOp ) {
505508
tmp.put(hop.getHopID(), out);
506509
}
507510

511+
private CNodeBinary getVectorBinary(CNode cdata1, CNode cdata2, String name) {
512+
if( TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2)
513+
|| TemplateUtils.isRowVector(cdata2)) ) {
514+
return new CNodeBinary(cdata1, cdata2, BinType.valueOf("VECT_"+name));
515+
}
516+
else {
517+
return new CNodeBinary(cdata1, cdata2, BinType.valueOf("VECT_"+name+"_SCALAR"));
518+
}
519+
}
520+
508521
/**
509522
* Comparator to order input hops of the row aggregate template. We try
510523
* to order matrices-vectors-scalars via sorting by number of cells but

src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java

+21-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
5858
private static final String TEST_NAME20 = TEST_NAME+20; //bitwAnd() operation
5959
private static final String TEST_NAME21 = TEST_NAME+21; //relu operation, (X>0)*dout
6060
private static final String TEST_NAME22 = TEST_NAME+22; //sum(X * seq(1,N) + t(seq(M,1)))
61+
private static final String TEST_NAME23 = TEST_NAME+23; //sum(min(X,Y,Z))
6162

6263
private static final String TEST_DIR = "functions/codegen/";
6364
private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/";
@@ -70,7 +71,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
7071
@Override
7172
public void setUp() {
7273
TestUtils.clearAssertionInformation();
73-
for( int i=1; i<=22; i++ ) {
74+
for( int i=1; i<=23; i++ ) {
7475
addTestConfiguration( TEST_NAME+i, new TestConfiguration(
7576
TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
7677
}
@@ -382,6 +383,21 @@ public void testCodegenCellwise22() {
382383
public void testCodegenCellwiseRewrite22_sp() {
383384
testCodegenIntegration( TEST_NAME22, true, ExecType.SPARK );
384385
}
386+
387+
/** Runs the TEST_NAME23 scenario, sum(min(X,Y,Z)), with rewrites=true and ExecType.CP. */
@Test
388+
public void testCodegenCellwiseRewrite23() {
389+
testCodegenIntegration( TEST_NAME23, true, ExecType.CP );
390+
}
391+
392+
/** Runs the TEST_NAME23 scenario, sum(min(X,Y,Z)), with rewrites=false and ExecType.CP. */
@Test
393+
public void testCodegenCellwise23() {
394+
testCodegenIntegration( TEST_NAME23, false, ExecType.CP );
395+
}
396+
397+
/** Runs the TEST_NAME23 scenario, sum(min(X,Y,Z)), with rewrites=true and ExecType.SPARK. */
@Test
398+
public void testCodegenCellwiseRewrite23_sp() {
399+
testCodegenIntegration( TEST_NAME23, true, ExecType.SPARK );
400+
}
385401

386402
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
387403
{
@@ -409,7 +425,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType
409425

410426
String HOME = SCRIPT_DIR + TEST_DIR;
411427
fullDMLScriptName = HOME + testname + ".dml";
412-
programArgs = new String[]{"-explain", "hops", "-stats", "-args", output("S") };
428+
programArgs = new String[]{"-explain", "-stats", "-args", output("S") };
413429

414430
fullRScriptName = HOME + testname + ".R";
415431
rCmd = getRCmd(inputDir(), expectedDir());
@@ -420,7 +436,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType
420436
runRScript(true);
421437

422438
if(testname.equals(TEST_NAME6) || testname.equals(TEST_NAME7)
423-
|| testname.equals(TEST_NAME9) || testname.equals(TEST_NAME10) ) {
439+
|| testname.equals(TEST_NAME9) || testname.equals(TEST_NAME10)) {
424440
//compare scalars
425441
HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("S");
426442
HashMap<CellIndex, Double> rfile = readRScalarFromFS("S");
@@ -451,6 +467,8 @@ else if( testname.equals(TEST_NAME17) )
451467
Assert.assertTrue(!heavyHittersContainsSubString("xor"));
452468
else if( testname.equals(TEST_NAME22) )
453469
Assert.assertTrue(!heavyHittersContainsSubString("seq"));
470+
else if( testname.equals(TEST_NAME23) )
471+
Assert.assertTrue(!heavyHittersContainsSubString("min","nmin"));
454472
}
455473
finally {
456474
rtplatform = platformOld;

src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ public class RowAggTmplTest extends AutomatedTestBase
7878
private static final String TEST_NAME39 = TEST_NAME+"39"; //BitwAnd operation
7979
private static final String TEST_NAME40 = TEST_NAME+"40"; //relu operation -> (X>0)* dout
8080
private static final String TEST_NAME41 = TEST_NAME+"41"; //X*rowSums(X/seq(1,N)+t(seq(M,1)))
81+
private static final String TEST_NAME42 = TEST_NAME+"42"; //X/rowSums(min(X, Y, Z))
8182

8283
private static final String TEST_DIR = "functions/codegen/";
8384
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -89,7 +90,7 @@ public class RowAggTmplTest extends AutomatedTestBase
8990
@Override
9091
public void setUp() {
9192
TestUtils.clearAssertionInformation();
92-
for(int i=1; i<=41; i++)
93+
for(int i=1; i<=42; i++)
9394
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
9495
}
9596

@@ -707,6 +708,21 @@ public void testCodegenRowAgg41CP() {
707708
public void testCodegenRowAgg41SP() {
708709
testCodegenIntegration( TEST_NAME41, false, ExecType.SPARK );
709710
}
711+
712+
/** Runs the TEST_NAME42 scenario, X/rowSums(min(X, Y, Z)), with rewrites=true and ExecType.CP. */
@Test
713+
public void testCodegenRowAggRewrite42CP() {
714+
testCodegenIntegration( TEST_NAME42, true, ExecType.CP );
715+
}
716+
717+
/** Runs the TEST_NAME42 scenario, X/rowSums(min(X, Y, Z)), with rewrites=false and ExecType.CP. */
@Test
718+
public void testCodegenRowAgg42CP() {
719+
testCodegenIntegration( TEST_NAME42, false, ExecType.CP );
720+
}
721+
722+
/** Runs the TEST_NAME42 scenario, X/rowSums(min(X, Y, Z)), with rewrites=false and ExecType.SPARK. */
@Test
723+
public void testCodegenRowAgg42SP() {
724+
testCodegenIntegration( TEST_NAME42, false, ExecType.SPARK );
725+
}
710726

711727
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
712728
{
@@ -766,6 +782,9 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType
766782
Assert.assertTrue(!heavyHittersContainsSubString("xor"));
767783
if( testname.equals(TEST_NAME41) )
768784
Assert.assertTrue(!heavyHittersContainsSubString("seq"));
785+
if( testname.equals(TEST_NAME42) )
786+
Assert.assertTrue(!heavyHittersContainsSubString("min","nmin")
787+
&& !heavyHittersContainsSubString("spoof", 2));
769788
}
770789
finally {
771790
rtplatform = platformOld;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# R reference script: computes sum(pmin(X,Y,Z)) over three constant matrices
# and writes the result in MatrixMarket format.
# Read the trailing command-line arguments; args[2] is used below as the output path prefix.
args<-commandArgs(TRUE)
23+
options(digits=22)
24+
library("Matrix")
25+
26+
# Constant 500 x 2 input matrices filled with 6, 7, and 8, respectively.
X = matrix(6, 500, 2);
27+
Y = matrix(7, 500, 2);
28+
Z = matrix(8, 500, 2);
29+
# Element-wise minimum across the three matrices, summed and wrapped into a 1 x 1 matrix.
R = as.matrix(sum(pmin(X,Y,Z)));
30+
31+
# Write the result to the file named by args[2] with "S" appended.
writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# DML script: computes sum(min(X,Y,Z)) over three constant matrices.
# Constant 500 x 2 input matrices filled with 6, 7, and 8, respectively.
X = matrix(6, 500, 2);
23+
Y = matrix(7, 500, 2);
24+
Z = matrix(8, 500, 2);
25+
26+
# NOTE(review): the empty loop presumably separates the input creation from
# the expression below into different program blocks -- confirm.
while(FALSE){}
27+
28+
# Nary min over the three matrices, summed and converted into a matrix.
R = as.matrix(sum(min(X,Y,Z)));
29+
30+
# Write the result to the location passed as the first script argument.
write(R, $1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# R reference script: computes X / rowSums(pmin(X,Y,Z)) over three constant
# matrices and writes the result in MatrixMarket format.
# Read the trailing command-line arguments; args[2] is used below as the output path prefix.
args<-commandArgs(TRUE)
23+
options(digits=22)
24+
library("Matrix")
25+
26+
# Constant 500 x 2 input matrices filled with 6, 7, and 8, respectively.
X = matrix(6, 500, 2);
27+
Y = matrix(7, 500, 2);
28+
Z = matrix(8, 500, 2);
29+
# Multiplying the row sums by a 1 x 2 matrix of ones replicates them across
# both columns so that the division is element-wise with matching dimensions.
R = X / (rowSums(pmin(X, Y, Z)) %*% matrix(1,1,2));
30+
31+
# Write the result to the file named by args[2] with "S" appended.
writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# DML script: computes X / rowSums(min(X,Y,Z)) over three constant matrices.
# Constant 500 x 2 input matrices filled with 6, 7, and 8, respectively.
X = matrix(6, 500, 2);
23+
Y = matrix(7, 500, 2);
24+
Z = matrix(8, 500, 2);
25+
26+
# NOTE(review): the empty loop presumably separates the input creation from
# the expression below into different program blocks -- confirm.
while(FALSE){}
27+
28+
# Divide X by the row sums of the nary min over the three matrices.
R = X / rowSums(min(X, Y, Z));
29+
30+
# Write the result to the location passed as the first script argument.
write(R, $1)

0 commit comments

Comments
 (0)