Skip to content

Commit 87bc358

Browse files
committed
[HOTFIX] Fix validation of scalar-scalar binary min/max operations
This recent introduction of nary min/max operations corrupted the language validation path for scalar-scalar operations. This patch fixes various issues related to (1) value type inference, (2) output dimension/blocksize propagation, and (3) the handling of all scalar nary min/max operations.
1 parent 4d370a8 commit 87bc358

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -574,18 +574,19 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
574574
case MIN:
575575
case MAX:
576576
//min(X), min(X,s), min(s,X), min(s,r), min(X,Y)
577-
//unary
578-
if (getSecondExpr() == null) {
577+
if (getSecondExpr() == null) { //unary
579578
checkNumParameters(1);
580579
checkMatrixParam(getFirstExpr());
581580
output.setDataType(DataType.SCALAR);
582581
output.setValueType(id.getValueType());
583582
output.setDimensions(0, 0);
584583
output.setBlockDimensions (0, 0);
585584
}
586-
587-
//nary operation
588-
else {
585+
else if( getAllExpr().length == 2 ) { //binary
586+
checkNumParameters(2);
587+
setBinaryOutputProperties(output);
588+
}
589+
else { //nary
589590
for( Expression e : getAllExpr() )
590591
checkMatrixScalarParam(e);
591592
setNaryOutputProperties(output);
@@ -1463,17 +1464,18 @@ private void setNaryOutputProperties(DataIdentifier output) {
14631464
e -> e.getOutput().getDataType().isScalar()) ? DataType.SCALAR : DataType.MATRIX;
14641465
Expression firstM = dt.isMatrix() ? Arrays.stream(getAllExpr()).filter(
14651466
e -> e.getOutput().getDataType().isMatrix()).findFirst().get() : null;
1466-
ValueType vt = dt.isMatrix() ? ValueType.DOUBLE : ValueType.BOOLEAN;
1467+
ValueType vt = dt.isMatrix() ? ValueType.DOUBLE : ValueType.INT;
14671468
for( Expression e : getAllExpr() ) {
14681469
vt = computeValueType(e, e.getOutput().getValueType(), vt, true);
14691470
if( e.getOutput().getDataType().isMatrix() )
14701471
checkMatchingDimensions(firstM, e, true);
14711472
}
14721473
output.setDataType(dt);
14731474
output.setValueType(vt);
1474-
output.setDimensions(firstM.getOutput().getDim1(), firstM.getOutput().getDim2());
1475-
output.setBlockDimensions (
1476-
firstM.getOutput().getRowsInBlock(), firstM.getOutput().getColumnsInBlock());
1475+
output.setDimensions(dt.isMatrix() ? firstM.getOutput().getDim1() : 0,
1476+
dt.isMatrix() ? firstM.getOutput().getDim2() : 0);
1477+
output.setBlockDimensions (dt.isMatrix() ? firstM.getOutput().getRowsInBlock() : 0,
1478+
dt.isMatrix() ? firstM.getOutput().getColumnsInBlock() : 0);
14771479
}
14781480

14791481
private void expandArguments() {

0 commit comments

Comments
 (0)