Skip to content

Commit e581b5a

Browse files
committed
[SYSTEMDS-2575] Fix eval function calls (incorrect pinning of inputs)
This patch fixes an issue of indirect eval function calls where wrong input variable names led to missing pinning of inputs and thus too eager cleanup of these variables (which causes crashes if the inputs are used in other operations of the eval call). The fix is simple. We avoid such inconsistent construction and invocation of fcall instructions by using a narrower interface and constructing the materialized names internally in the fcall.
1 parent 586e910 commit e581b5a

File tree

4 files changed

+22
-31
lines changed

4 files changed

+22
-31
lines changed

src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,10 @@ protected void setupUpdateFunction(String updFunc, ExecutionContext ec) {
7777
CPOperand[] boundInputs = inputs.stream()
7878
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
7979
.toArray(CPOperand[]::new);
80-
ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
81-
.collect(Collectors.toCollection(ArrayList::new));
8280
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
8381
.collect(Collectors.toCollection(ArrayList::new));
8482
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
85-
inputNames, func.getInputParamNames(), outputNames, "update function");
83+
func.getInputParamNames(), outputNames, "update function");
8684

8785
// Check the inputs of the update function
8886
checkInput(false, inputs, DataType.MATRIX, Statement.PS_FEATURES);

src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,10 @@ protected void setupAggFunc(ExecutionContext ec, String aggFunc) {
104104
CPOperand[] boundInputs = inputs.stream()
105105
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
106106
.toArray(CPOperand[]::new);
107-
ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
108-
.collect(Collectors.toCollection(ArrayList::new));
109107
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
110108
.collect(Collectors.toCollection(ArrayList::new));
111109
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs,
112-
inputNames, func.getInputParamNames(), outputNames, "aggregate function");
110+
func.getInputParamNames(), outputNames, "aggregate function");
113111
}
114112

115113
public abstract void push(int workerID, ListObject value);

src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java

+14-20
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ public void processInstruction(ExecutionContext ec) {
6767
CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, inputs.length);
6868
List<String> boundOutputNames = new ArrayList<>();
6969
boundOutputNames.add(output.getName());
70-
List<String> boundInputNames = new ArrayList<>();
71-
for (CPOperand input : boundInputs) {
72-
boundInputNames.add(input.getName());
73-
}
7470

7571
//2. copy the created output matrix
7672
MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));
@@ -103,32 +99,30 @@ public void processInstruction(ExecutionContext ec) {
10399
ec.getVariables().put(varName, in);
104100
boundInputs2[i] = new CPOperand(varName, in);
105101
}
106-
boundInputNames = lo.isNamedList() ? lo.getNames() : fpb.getInputParamNames();
107102
boundInputs = boundInputs2;
108103
}
109104

110105
//5. call the function
111106
FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(null, funcName,
112-
boundInputs, boundInputNames, fpb.getInputParamNames(), boundOutputNames, "eval func");
107+
boundInputs, fpb.getInputParamNames(), boundOutputNames, "eval func");
113108
fcpi.processInstruction(ec);
114109

115110
//6. convert the result to matrix
116111
Data newOutput = ec.getVariable(output);
117-
if (newOutput instanceof MatrixObject) {
118-
return;
119-
}
120-
MatrixBlock mb = null;
121-
if (newOutput instanceof ScalarObject) {
122-
//convert scalar to matrix
123-
mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
124-
} else if (newOutput instanceof FrameObject) {
125-
//convert frame to matrix
126-
mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
127-
ec.cleanupCacheableData((FrameObject) newOutput);
112+
if (!(newOutput instanceof MatrixObject)) {
113+
MatrixBlock mb = null;
114+
if (newOutput instanceof ScalarObject) {
115+
//convert scalar to matrix
116+
mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
117+
} else if (newOutput instanceof FrameObject) {
118+
//convert frame to matrix
119+
mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
120+
ec.cleanupCacheableData((FrameObject) newOutput);
121+
}
122+
outputMO.acquireModify(mb);
123+
outputMO.release();
124+
ec.setVariable(output.getName(), outputMO);
128125
}
129-
outputMO.acquireModify(mb);
130-
outputMO.release();
131-
ec.setVariable(output.getName(), outputMO);
132126

133127
//7. cleanup of variable expanded from list
134128
if( boundInputs2 != null ) {

src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
package org.apache.sysds.runtime.instructions.cp;
2121

2222
import java.util.ArrayList;
23+
import java.util.Arrays;
2324
import java.util.HashSet;
2425
import java.util.List;
26+
import java.util.stream.Collectors;
2527

2628
import org.apache.sysds.api.DMLScript;
2729
import org.apache.sysds.lops.Lop;
@@ -55,12 +57,13 @@ public class FunctionCallCPInstruction extends CPInstruction {
5557
private final List<String> _boundOutputNames;
5658

5759
public FunctionCallCPInstruction(String namespace, String functName, CPOperand[] boundInputs,
58-
List<String> boundInputNames, List<String> funArgNames, List<String> boundOutputNames, String istr) {
60+
List<String> funArgNames, List<String> boundOutputNames, String istr) {
5961
super(CPType.External, null, functName, istr);
6062
_functionName = functName;
6163
_namespace = namespace;
6264
_boundInputs = boundInputs;
63-
_boundInputNames = boundInputNames;
65+
_boundInputNames = Arrays.stream(boundInputs).map(i -> i.getName())
66+
.collect(Collectors.toCollection(ArrayList::new));
6467
_funArgNames = funArgNames;
6568
_boundOutputNames = boundOutputNames;
6669
}
@@ -81,19 +84,17 @@ public static FunctionCallCPInstruction parseInstruction(String str) {
8184
int numInputs = Integer.valueOf(parts[3]);
8285
int numOutputs = Integer.valueOf(parts[4]);
8386
CPOperand[] boundInputs = new CPOperand[numInputs];
84-
List<String> boundInputNames = new ArrayList<>();
8587
List<String> funArgNames = new ArrayList<>();
8688
List<String> boundOutputNames = new ArrayList<>();
8789
for (int i = 0; i < numInputs; i++) {
8890
String[] nameValue = IOUtilFunctions.splitByFirst(parts[5 + i], "=");
8991
boundInputs[i] = new CPOperand(nameValue[1]);
9092
funArgNames.add(nameValue[0]);
91-
boundInputNames.add(boundInputs[i].getName());
9293
}
9394
for (int i = 0; i < numOutputs; i++)
9495
boundOutputNames.add(parts[5 + numInputs + i]);
9596
return new FunctionCallCPInstruction ( namespace, functionName,
96-
boundInputs, boundInputNames, funArgNames, boundOutputNames, str );
97+
boundInputs, funArgNames, boundOutputNames, str );
9798
}
9899

99100
@Override

0 commit comments

Comments
 (0)