Skip to content

Commit 232ac23

Browse files
authored
Merge pull request #2867 from TheChymera/order
FIX: Sort conditions in bids_gen_info to ensure consistent order
2 parents ef02ff0 + 5fa95f4 commit 232ac23

File tree

3 files changed

+32
-10
lines changed

3 files changed

+32
-10
lines changed

nipype/algorithms/modelgen.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
"""
1414
from __future__ import (print_function, division, unicode_literals,
1515
absolute_import)
16+
17+
str_basetype = str
1618
from builtins import range, str, bytes, int
1719

1820
from copy import deepcopy
19-
import os, math, csv
21+
import csv, math, os
2022

2123
from nibabel import load
2224
import numpy as np
@@ -145,7 +147,7 @@ def scale_timings(timelist, input_units, output_units, time_repetition):
145147
return timelist
146148

147149
def bids_gen_info(bids_event_files,
148-
condition_column='trial_type',
150+
condition_column='',
149151
amplitude_column=None,
150152
time_repetition=False,
151153
):
@@ -173,9 +175,13 @@ def bids_gen_info(bids_event_files,
173175
info = []
174176
for bids_event_file in bids_event_files:
175177
with open(bids_event_file) as f:
176-
f_events = csv.DictReader(f, skipinitialspace=True, delimiter='\t')
178+
f_events = csv.DictReader(f, skipinitialspace=True, delimiter=str_basetype('\t'))
177179
events = [{k: v for k, v in row.items()} for row in f_events]
178-
conditions = list(set([i[condition_column] for i in events]))
180+
if not condition_column:
181+
condition_column = '_trial_type'
182+
for i in events:
183+
i.update({condition_column: 'ev0'})
184+
conditions = sorted(set([i[condition_column] for i in events]))
179185
runinfo = Bunch(conditions=[], onsets=[], durations=[], amplitudes=[])
180186
for condition in conditions:
181187
selected_events = [i for i in events if i[condition_column]==condition]
@@ -185,10 +191,7 @@ def bids_gen_info(bids_event_files,
185191
decimals = math.ceil(-math.log10(time_repetition))
186192
onsets = [np.round(i, decimals) for i in onsets]
187193
durations = [np.round(i ,decimals) for i in durations]
188-
if condition:
189-
runinfo.conditions.append(condition)
190-
else:
191-
runinfo.conditions.append('e0')
194+
runinfo.conditions.append(condition)
192195
runinfo.onsets.append(onsets)
193196
runinfo.durations.append(durations)
194197
try:

nipype/algorithms/tests/test_modelgen.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,19 @@
1111

1212
import pytest
1313
import numpy.testing as npt
14+
from nipype.testing import example_data
1415
from nipype.interfaces.base import Bunch, TraitError
15-
from nipype.algorithms.modelgen import (SpecifyModel, SpecifySparseModel,
16-
SpecifySPMModel)
16+
from nipype.algorithms.modelgen import (bids_gen_info, SpecifyModel,
17+
SpecifySparseModel, SpecifySPMModel)
18+
19+
20+
def test_bids_gen_info():
21+
fname = example_data('events.tsv')
22+
res = bids_gen_info([fname])
23+
assert res[0].onsets == [[183.75, 313.75, 483.75, 633.75, 783.75, 933.75, 1083.75, 1233.75]]
24+
assert res[0].durations == [[20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0]]
25+
assert res[0].amplitudes ==[[1, 1, 1, 1, 1, 1, 1, 1]]
26+
assert res[0].conditions == ['ev0']
1727

1828

1929
def test_modelgen1(tmpdir):

nipype/testing/data/events.tsv

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
onset duration frequency pulse_width amplitude
2+
183.75 20.0 20.0 0.005 1.0
3+
313.75 20.0 20.0 0.005 1.0
4+
483.75 20.0 20.0 0.005 1.0
5+
633.75 20.0 20.0 0.005 1.0
6+
783.75 20.0 20.0 0.005 1.0
7+
933.75 20.0 20.0 0.005 1.0
8+
1083.75 20.0 20.0 0.005 1.0
9+
1233.75 20.0 20.0 0.005 1.0

0 commit comments

Comments
 (0)