6
6
from __future__ import (print_function , division , unicode_literals ,
7
7
absolute_import )
8
8
from builtins import open
9
+ import pytest
9
10
11
+ from .... import config
10
12
from ... import engine as pe
11
13
from ....interfaces import base as nib
12
14
from ....interfaces .utility import IdentityInterface , Function , Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
45
47
output1 = nib .traits .Int (desc = 'ouput' )
46
48
47
49
48
- class IncrementInterface (nib .BaseInterface ):
50
+ class IncrementInterface (nib .SimpleInterface ):
49
51
input_spec = IncrementInputSpec
50
52
output_spec = IncrementOutputSpec
51
53
52
54
def _run_interface (self , runtime ):
53
55
runtime .returncode = 0
56
+ self ._results ['output1' ] = self .inputs .input1 + self .inputs .inc
54
57
return runtime
55
58
56
- def _list_outputs (self ):
57
- outputs = self ._outputs ().get ()
58
- outputs ['output1' ] = self .inputs .input1 + self .inputs .inc
59
- return outputs
60
-
61
59
62
60
_sums = []
63
61
@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
73
71
operands = nib .traits .List (nib .traits .Int , desc = 'operands' )
74
72
75
73
76
- class SumInterface (nib .BaseInterface ):
74
+ class SumInterface (nib .SimpleInterface ):
77
75
input_spec = SumInputSpec
78
76
output_spec = SumOutputSpec
79
77
80
78
def _run_interface (self , runtime ):
81
- runtime .returncode = 0
82
- return runtime
83
-
84
- def _list_outputs (self ):
85
79
global _sum
86
80
global _sum_operands
87
- outputs = self . _outputs (). get ()
88
- outputs ['operands' ] = self .inputs .input1
89
- _sum_operands . append ( outputs [ 'operands' ] )
90
- outputs [ 'output1' ] = sum (self .inputs .input1 )
91
- _sums .append (outputs [ 'output1' ] )
92
- return outputs
81
+ runtime . returncode = 0
82
+ self . _results ['operands' ] = self .inputs .input1
83
+ self . _results [ 'output1' ] = sum ( self . inputs . input1 )
84
+ _sum_operands . append (self .inputs .input1 )
85
+ _sums .append (sum ( self . inputs . input1 ) )
86
+ return runtime
93
87
94
88
95
89
_set_len = None
@@ -148,35 +142,42 @@ def _list_outputs(self):
148
142
return outputs
149
143
150
144
151
- def test_join_expansion (tmpdir ):
145
+ @pytest .mark .parametrize ('needed_outputs' , [True , False ])
146
+ def test_join_expansion (tmpdir , needed_outputs ):
147
+ global _sums
148
+ global _sum_operands
149
+ global _products
152
150
tmpdir .chdir ()
153
151
152
+ config .set ('execution' , 'remove_unnecessary_outputs' , ['false' , 'true' ][needed_outputs ])
154
153
# Make the workflow.
155
154
wf = pe .Workflow (name = 'test' )
156
155
# the iterated input node
157
156
inputspec = pe .Node (IdentityInterface (fields = ['n' ]), name = 'inputspec' )
158
157
inputspec .iterables = [('n' , [1 , 2 ])]
159
158
# a pre-join node in the iterated path
160
159
pre_join1 = pe .Node (IncrementInterface (), name = 'pre_join1' )
161
- wf .connect (inputspec , 'n' , pre_join1 , 'input1' )
162
160
# another pre-join node in the iterated path
163
161
pre_join2 = pe .Node (IncrementInterface (), name = 'pre_join2' )
164
- wf .connect (pre_join1 , 'output1' , pre_join2 , 'input1' )
165
162
# the join node
166
163
join = pe .JoinNode (
167
164
SumInterface (),
168
165
joinsource = 'inputspec' ,
169
166
joinfield = 'input1' ,
170
167
name = 'join' )
171
- wf .connect (pre_join2 , 'output1' , join , 'input1' )
172
168
# an uniterated post-join node
173
169
post_join1 = pe .Node (IncrementInterface (), name = 'post_join1' )
174
- wf .connect (join , 'output1' , post_join1 , 'input1' )
175
170
# a post-join node in the iterated path
176
171
post_join2 = pe .Node (ProductInterface (), name = 'post_join2' )
177
- wf .connect (join , 'output1' , post_join2 , 'input1' )
178
- wf .connect (pre_join1 , 'output1' , post_join2 , 'input2' )
179
172
173
+ wf .connect ([
174
+ (inputspec , pre_join1 , [('n' , 'input1' )]),
175
+ (pre_join1 , pre_join2 , [('output1' , 'input1' )]),
176
+ (pre_join1 , post_join2 , [('output1' , 'input2' )]),
177
+ (pre_join2 , join , [('output1' , 'input1' )]),
178
+ (join , post_join1 , [('output1' , 'input1' )]),
179
+ (join , post_join2 , [('output1' , 'input1' )]),
180
+ ])
180
181
result = wf .run ()
181
182
182
183
# the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +186,8 @@ def test_join_expansion(tmpdir):
185
186
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186
187
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187
188
# Nipype factors away the IdentityInterface.
188
- assert len (
189
- result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
189
+ assert len (result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
190
+
190
191
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191
192
assert len (_sums ) == 1 , "The number of join outputs is incorrect"
192
193
assert _sums [
@@ -198,6 +199,11 @@ def test_join_expansion(tmpdir):
198
199
assert len (_products ) == 2 ,\
199
200
"The number of iterated post-join outputs is incorrect"
200
201
202
+ # Clean up so that parametrized tests work out
203
+ _products = []
204
+ _sum_operands = []
205
+ _sums = []
206
+
201
207
202
208
def test_node_joinsource (tmpdir ):
203
209
"""Test setting the joinsource to a Node."""
0 commit comments