Skip to content

Commit aec1b98

Browse files
committed
TST: Parametrize over config needed_outputs JoinNode expansion tests
1 parent 6f2f94f commit aec1b98

File tree

1 file changed

+32
-26
lines changed

1 file changed

+32
-26
lines changed

nipype/pipeline/engine/tests/test_join.py

+32-26
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from __future__ import (print_function, division, unicode_literals,
77
absolute_import)
88
from builtins import open
9+
import pytest
910

11+
from .... import config
1012
from ... import engine as pe
1113
from ....interfaces import base as nib
1214
from ....interfaces.utility import IdentityInterface, Function, Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
4547
output1 = nib.traits.Int(desc='ouput')
4648

4749

48-
class IncrementInterface(nib.BaseInterface):
50+
class IncrementInterface(nib.SimpleInterface):
4951
input_spec = IncrementInputSpec
5052
output_spec = IncrementOutputSpec
5153

5254
def _run_interface(self, runtime):
5355
runtime.returncode = 0
56+
self._results['output1'] = self.inputs.input1 + self.inputs.inc
5457
return runtime
5558

56-
def _list_outputs(self):
57-
outputs = self._outputs().get()
58-
outputs['output1'] = self.inputs.input1 + self.inputs.inc
59-
return outputs
60-
6159

6260
_sums = []
6361

@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
7371
operands = nib.traits.List(nib.traits.Int, desc='operands')
7472

7573

76-
class SumInterface(nib.BaseInterface):
74+
class SumInterface(nib.SimpleInterface):
7775
input_spec = SumInputSpec
7876
output_spec = SumOutputSpec
7977

8078
def _run_interface(self, runtime):
81-
runtime.returncode = 0
82-
return runtime
83-
84-
def _list_outputs(self):
8579
global _sum
8680
global _sum_operands
87-
outputs = self._outputs().get()
88-
outputs['operands'] = self.inputs.input1
89-
_sum_operands.append(outputs['operands'])
90-
outputs['output1'] = sum(self.inputs.input1)
91-
_sums.append(outputs['output1'])
92-
return outputs
81+
runtime.returncode = 0
82+
self._results['operands'] = self.inputs.input1
83+
self._results['output1'] = sum(self.inputs.input1)
84+
_sum_operands.append(self.inputs.input1)
85+
_sums.append(sum(self.inputs.input1))
86+
return runtime
9387

9488

9589
_set_len = None
@@ -148,35 +142,42 @@ def _list_outputs(self):
148142
return outputs
149143

150144

151-
def test_join_expansion(tmpdir):
145+
@pytest.mark.parametrize('needed_outputs', [True, False])
146+
def test_join_expansion(tmpdir, needed_outputs):
147+
global _sums
148+
global _sum_operands
149+
global _products
152150
tmpdir.chdir()
153151

152+
config.set('execution', 'remove_unnecessary_outputs', ['false', 'true'][needed_outputs])
154153
# Make the workflow.
155154
wf = pe.Workflow(name='test')
156155
# the iterated input node
157156
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
158157
inputspec.iterables = [('n', [1, 2])]
159158
# a pre-join node in the iterated path
160159
pre_join1 = pe.Node(IncrementInterface(), name='pre_join1')
161-
wf.connect(inputspec, 'n', pre_join1, 'input1')
162160
# another pre-join node in the iterated path
163161
pre_join2 = pe.Node(IncrementInterface(), name='pre_join2')
164-
wf.connect(pre_join1, 'output1', pre_join2, 'input1')
165162
# the join node
166163
join = pe.JoinNode(
167164
SumInterface(),
168165
joinsource='inputspec',
169166
joinfield='input1',
170167
name='join')
171-
wf.connect(pre_join2, 'output1', join, 'input1')
172168
# an uniterated post-join node
173169
post_join1 = pe.Node(IncrementInterface(), name='post_join1')
174-
wf.connect(join, 'output1', post_join1, 'input1')
175170
# a post-join node in the iterated path
176171
post_join2 = pe.Node(ProductInterface(), name='post_join2')
177-
wf.connect(join, 'output1', post_join2, 'input1')
178-
wf.connect(pre_join1, 'output1', post_join2, 'input2')
179172

173+
wf.connect([
174+
(inputspec, pre_join1, [('n', 'input1')]),
175+
(pre_join1, pre_join2, [('output1', 'input1')]),
176+
(pre_join1, post_join2, [('output1', 'input2')]),
177+
(pre_join2, join, [('output1', 'input1')]),
178+
(join, post_join1, [('output1', 'input1')]),
179+
(join, post_join2, [('output1', 'input1')]),
180+
])
180181
result = wf.run()
181182

182183
# the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +186,8 @@ def test_join_expansion(tmpdir):
185186
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186187
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187188
# Nipype factors away the IdentityInterface.
188-
assert len(
189-
result.nodes()) == 8, "The number of expanded nodes is incorrect."
189+
assert len(result.nodes()) == 8, "The number of expanded nodes is incorrect."
190+
190191
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191192
assert len(_sums) == 1, "The number of join outputs is incorrect"
192193
assert _sums[
@@ -198,6 +199,11 @@ def test_join_expansion(tmpdir):
198199
assert len(_products) == 2,\
199200
"The number of iterated post-join outputs is incorrect"
200201

202+
# Clean up so that parametrized tests work out
203+
_products = []
204+
_sum_operands = []
205+
_sums = []
206+
201207

202208
def test_node_joinsource(tmpdir):
203209
"""Test setting the joinsource to a Node."""

0 commit comments

Comments
 (0)