Skip to content

Commit b2e67b0

Browse files
committed
Rather than gather all the parse trees, then score them all, score them all while processings them. Saves a significant amount of memory. This is especially relevant in the PCFG version (hence not having noticed in two years, since no one retrains that any more these days)
1 parent f636673 commit b2e67b0

File tree

1 file changed

+29
-30
lines changed

1 file changed

+29
-30
lines changed

src/edu/stanford/nlp/parser/metrics/EvaluateTreebank.java

+29-30
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
import java.util.Collections;
1212
import java.util.LinkedList;
1313
import java.util.List;
14+
import java.util.function.BiConsumer;
15+
import java.util.function.Function;
1416

1517
import edu.stanford.nlp.io.NullOutputStream;
18+
import edu.stanford.nlp.io.RuntimeIOException;
1619
import edu.stanford.nlp.ling.*;
1720
import edu.stanford.nlp.ling.SentenceUtils;
1821
import edu.stanford.nlp.math.ArrayMath;
@@ -36,7 +39,6 @@
3639
import edu.stanford.nlp.trees.TreebankLanguagePack;
3740
import edu.stanford.nlp.trees.TreePrint;
3841
import edu.stanford.nlp.trees.TreeTransformer;
39-
import java.util.function.Function;
4042
import edu.stanford.nlp.util.Generics;
4143
import edu.stanford.nlp.util.Pair;
4244
import edu.stanford.nlp.util.ScoredObject;
@@ -557,7 +559,7 @@ else if(pwFileOut != null) {
557559
* Wrapper for a way to pass in a dataset which may need reprocessing to get parse results
558560
*/
559561
public static interface EvaluationDataset {
560-
List<Pair<ParserQuery, Tree>> dataset(PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint);
562+
void processDataset(PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint, BiConsumer<ParserQuery, Tree> processResults);
561563

562564
void summarize(PrintWriter pwErr, TreebankLanguagePack tlp);
563565
}
@@ -603,8 +605,7 @@ private List<CoreLabel> getInputSentence(Tree t) {
603605
}
604606
}
605607

606-
public List<Pair<ParserQuery, Tree>> dataset(PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint) {
607-
List<Pair<ParserQuery, Tree>> results = new ArrayList<>();
608+
public void processDataset(PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint, BiConsumer<ParserQuery, Tree> processResults) {
608609

609610
if (op.testOptions.testingThreads != 1) {
610611
MulticoreWrapper<List<? extends HasWord>, ParserQuery> wrapper = new MulticoreWrapper<>(op.testOptions.testingThreads, new ParsingThreadsafeProcessor(pqFactory, pwErr));
@@ -619,28 +620,26 @@ public List<Pair<ParserQuery, Tree>> dataset(PrintWriter pwErr, PrintWriter pwOu
619620
while (wrapper.peek()) {
620621
ParserQuery pq = wrapper.poll();
621622
goldTree = goldTrees.poll();
622-
results.add(new Pair<>(pq, goldTree));
623+
processResults.accept(pq, goldTree);
623624
}
624625
} // for tree iterator
625626
wrapper.join();
626627
while (wrapper.peek()) {
627628
ParserQuery pq = wrapper.poll();
628629
Tree goldTree = goldTrees.poll();
629-
results.add(new Pair<>(pq, goldTree));
630+
processResults.accept(pq, goldTree);
630631
}
631632
} else {
633+
ParserQuery pq = pqFactory.parserQuery();
632634
for (Tree goldTree : testTreebank) {
633635
final List<CoreLabel> sentence = getInputSentence(goldTree);
634636

635637
pwErr.println("Parsing [len. " + sentence.size() + "]: " + SentenceUtils.listToString(sentence));
636638

637-
ParserQuery pq = pqFactory.parserQuery();
638639
pq.parseAndReport(sentence, pwErr);
639-
results.add(new Pair<>(pq, goldTree));
640+
processResults.accept(pq, goldTree);
640641
} // for tree iterator
641642
}
642-
643-
return results;
644643
}
645644

646645
public void summarize(PrintWriter pwErr, TreebankLanguagePack tlp) {
@@ -656,8 +655,10 @@ public PreparsedEvaluationDataset(List<Pair<ParserQuery, Tree>> testTreebank) {
656655
this.testTreebank = testTreebank;
657656
}
658657

659-
public List<Pair<ParserQuery, Tree>> dataset(PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint) {
660-
return testTreebank;
658+
public void processDataset(PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint, BiConsumer<ParserQuery, Tree> processResults) {
659+
for (Pair<ParserQuery, Tree> result : testTreebank) {
660+
processResults.accept(result.first, result.second);
661+
}
661662
}
662663

663664
public void summarize(PrintWriter pwErr, TreebankLanguagePack tlp) {
@@ -689,55 +690,53 @@ public double testOnTreebank(EvaluationDataset testTreebank) {
689690
TreePrint treePrint = op.testOptions.treePrint(op.tlpParams);
690691
TreebankLangParserParams tlpParams = op.tlpParams;
691692
TreebankLanguagePack tlp = op.langpack();
692-
PrintWriter pwOut, pwErr;
693+
PrintWriter pwOut, pwEvalErr;
693694
if (op.testOptions.quietEvaluation) {
694695
NullOutputStream quiet = new NullOutputStream();
695696
pwOut = tlpParams.pw(quiet);
696-
pwErr = tlpParams.pw(quiet);
697+
pwEvalErr = tlpParams.pw(quiet);
697698
} else {
698699
pwOut = tlpParams.pw();
699-
pwErr = tlpParams.pw(System.err);
700+
pwEvalErr = tlpParams.pw(System.err);
700701
}
701702
if (op.testOptions.verbose) {
702-
testTreebank.summarize(pwErr, tlp);
703+
testTreebank.summarize(pwEvalErr, tlp);
703704
}
704705
if (op.testOptions.evalb) {
705706
EvalbFormatWriter.initEVALBfiles(tlpParams);
706707
}
707708

708-
PrintWriter pwFileOut = null;
709+
final PrintWriter pwFileOut;
709710
if (op.testOptions.writeOutputFiles) {
710711
String fname = op.testOptions.outputFilesPrefix + "." + op.testOptions.outputFilesExtension;
711712
try {
712713
pwFileOut = op.tlpParams.pw(new FileOutputStream(fname));
713714
} catch (IOException ioe) {
714-
ioe.printStackTrace();
715+
throw new RuntimeIOException(ioe);
715716
}
717+
} else {
718+
pwFileOut = null;
716719
}
717720

718-
PrintWriter pwStats = null;
719-
if(op.testOptions.outputkBestEquivocation != null) {
721+
final PrintWriter pwStats;
722+
if (op.testOptions.outputkBestEquivocation != null) {
720723
try {
721724
pwStats = op.tlpParams.pw(new FileOutputStream(op.testOptions.outputkBestEquivocation));
722725
} catch(IOException ioe) {
723-
ioe.printStackTrace();
726+
throw new RuntimeIOException(ioe);
724727
}
728+
} else {
729+
pwStats = null;
725730
}
726731

727-
List<Pair<ParserQuery, Tree>> results = testTreebank.dataset(pwErr, pwOut, pwFileOut, pwStats, treePrint);
728-
for (Pair<ParserQuery, Tree> result : results) {
729-
ParserQuery pq = result.first;
730-
Tree goldTree = result.second;
731-
processResults(pq, goldTree, pwErr, pwOut, pwFileOut, pwStats, treePrint);
732-
}
732+
testTreebank.processDataset(pwEvalErr, pwOut, pwFileOut, pwStats, treePrint,
733+
(pq, goldTree) -> processResults(pq, goldTree, pwEvalErr, pwOut, pwFileOut, pwStats, treePrint));
733734

734735
//Done parsing...print the results of the evaluations
735736
if (treebankTotalTimer != null) {
736737
treebankTotalTimer.done("Testing on treebank");
737738
}
738-
if (op.testOptions.quietEvaluation) {
739-
pwErr = tlpParams.pw(System.err);
740-
}
739+
PrintWriter pwErr = tlpParams.pw(System.err);
741740
if (saidMemMessage) {
742741
ParserUtils.printOutOfMemory(pwErr);
743742
}

0 commit comments

Comments
 (0)