Skip to content

TST: Parametrize JoinNode expansion tests over config needed_outputs #2981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 1, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions nipype/pipeline/engine/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from __future__ import (print_function, division, unicode_literals,
absolute_import)
from builtins import open
import pytest

from .... import config
from ... import engine as pe
from ....interfaces import base as nib
from ....interfaces.utility import IdentityInterface, Function, Merge
Expand Down Expand Up @@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
output1 = nib.traits.Int(desc='ouput')


class IncrementInterface(nib.BaseInterface):
class IncrementInterface(nib.SimpleInterface):
input_spec = IncrementInputSpec
output_spec = IncrementOutputSpec

def _run_interface(self, runtime):
runtime.returncode = 0
self._results['output1'] = self.inputs.input1 + self.inputs.inc
return runtime

def _list_outputs(self):
outputs = self._outputs().get()
outputs['output1'] = self.inputs.input1 + self.inputs.inc
return outputs


_sums = []

Expand All @@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
operands = nib.traits.List(nib.traits.Int, desc='operands')


class SumInterface(nib.BaseInterface):
class SumInterface(nib.SimpleInterface):
input_spec = SumInputSpec
output_spec = SumOutputSpec

def _run_interface(self, runtime):
runtime.returncode = 0
return runtime

def _list_outputs(self):
global _sum
global _sum_operands
outputs = self._outputs().get()
outputs['operands'] = self.inputs.input1
_sum_operands.append(outputs['operands'])
outputs['output1'] = sum(self.inputs.input1)
_sums.append(outputs['output1'])
return outputs
runtime.returncode = 0
self._results['operands'] = self.inputs.input1
self._results['output1'] = sum(self.inputs.input1)
_sum_operands.append(self.inputs.input1)
_sums.append(sum(self.inputs.input1))
return runtime


_set_len = None
Expand Down Expand Up @@ -148,35 +142,48 @@ def _list_outputs(self):
return outputs


def test_join_expansion(tmpdir):
@pytest.mark.parametrize('needed_outputs', ['true', 'false'])
def test_join_expansion(tmpdir, needed_outputs):
global _sums
global _sum_operands
global _products
tmpdir.chdir()

# Clean up, just in case some other test modified them
_products = []
_sum_operands = []
_sums = []

prev_state = config.get('execution', 'remove_unnecessary_outputs')
config.set('execution', 'remove_unnecessary_outputs', needed_outputs)
# Make the workflow.
wf = pe.Workflow(name='test')
# the iterated input node
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
inputspec.iterables = [('n', [1, 2])]
# a pre-join node in the iterated path
pre_join1 = pe.Node(IncrementInterface(), name='pre_join1')
wf.connect(inputspec, 'n', pre_join1, 'input1')
# another pre-join node in the iterated path
pre_join2 = pe.Node(IncrementInterface(), name='pre_join2')
wf.connect(pre_join1, 'output1', pre_join2, 'input1')
# the join node
join = pe.JoinNode(
SumInterface(),
joinsource='inputspec',
joinfield='input1',
name='join')
wf.connect(pre_join2, 'output1', join, 'input1')
# an uniterated post-join node
post_join1 = pe.Node(IncrementInterface(), name='post_join1')
wf.connect(join, 'output1', post_join1, 'input1')
# a post-join node in the iterated path
post_join2 = pe.Node(ProductInterface(), name='post_join2')
wf.connect(join, 'output1', post_join2, 'input1')
wf.connect(pre_join1, 'output1', post_join2, 'input2')

wf.connect([
(inputspec, pre_join1, [('n', 'input1')]),
(pre_join1, pre_join2, [('output1', 'input1')]),
(pre_join1, post_join2, [('output1', 'input2')]),
(pre_join2, join, [('output1', 'input1')]),
(join, post_join1, [('output1', 'input1')]),
(join, post_join2, [('output1', 'input1')]),
])
result = wf.run()

# the two expanded pre-join predecessor nodes feed into one join node
Expand All @@ -185,8 +192,8 @@ def test_join_expansion(tmpdir):
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
# Nipype factors away the IdentityInterface.
assert len(
result.nodes()) == 8, "The number of expanded nodes is incorrect."
assert len(result.nodes()) == 8, "The number of expanded nodes is incorrect."

# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
assert len(_sums) == 1, "The number of join outputs is incorrect"
assert _sums[
Expand All @@ -197,6 +204,7 @@ def test_join_expansion(tmpdir):
# there are two iterations of the post-join node in the iterable path
assert len(_products) == 2,\
"The number of iterated post-join outputs is incorrect"
config.set('execution', 'remove_unnecessary_outputs', prev_state)


def test_node_joinsource(tmpdir):
Expand Down