Skip to content

[ENH] Add all DIPY workflows dynamically #2905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 18, 2019
Merged
4 changes: 4 additions & 0 deletions .zenodo.json
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@
{
"name": "Haselgrove, Christian"
},
{
"name": "Koudoro, Serge",
"affiliation": "Indiana University, IN, USA"
},
{
"affiliation": "1 McGill Centre for Integrative Neuroscience (MCIN), Ludmer Centre for Neuroinformatics and Mental Health, Montreal Neurological Institute (MNI), McGill University, Montr\u00e9al, 3801 University Street, WB-208, H3A 2B4, Qu\u00e9bec, Canada. 2 University of Lyon, CNRS, INSERM, CREATIS., Villeurbanne, 7, avenue Jean Capelle, 69621, France.",
"name": "Glatard, Tristan",
Expand Down
30 changes: 30 additions & 0 deletions nipype/interfaces/dipy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from ..base import (traits, File, isdefined, LibraryBaseInterface,
BaseInterfaceInputSpec, TraitedSpec)

# List of workflows to ignore
SKIP_WORKFLOWS_LIST = ['Workflow', 'CombinedWorkflow']

HAVE_DIPY = True

try:
import dipy
from dipy.workflows.base import IntrospectiveArgumentParser
Expand Down Expand Up @@ -211,3 +215,29 @@ def _list_outputs(self):
"_run_interface": _run_interface,
"_list_outputs:": _list_outputs})
return newclass


def get_dipy_workflows(module):
"""Search for DIPY workflow class.

Parameters
----------
module : object
module object

Returns
-------
l_wkflw : list of tuple
This a list of tuple containing 2 elements:
Worflow name, Workflow class obj

Examples
--------
>>> from dipy.workflows import align # doctest: +SKIP
>>> get_dipy_workflows(align) # doctest: +SKIP

"""
return [(m, obj) for m, obj in inspect.getmembers(module)
if inspect.isclass(obj) and
issubclass(obj, module.Workflow) and
m not in SKIP_WORKFLOWS_LIST]
17 changes: 16 additions & 1 deletion nipype/interfaces/dipy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,30 @@
import os.path as op
import nibabel as nb
import numpy as np
from distutils.version import LooseVersion

from ...utils import NUMPY_MMAP

from ... import logging
from ..base import (traits, TraitedSpec, File, isdefined)
from .base import DipyBaseInterface
from .base import (HAVE_DIPY, dipy_version, dipy_to_nipype_interface,
get_dipy_workflows, DipyBaseInterface)

IFLOGGER = logging.getLogger('nipype.interface')

if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):
from dipy.workflows import denoise, mask

l_wkflw = get_dipy_workflows(denoise) + get_dipy_workflows(mask)
for name, obj in l_wkflw:
new_name = name.replace('Flow', '')
globals()[new_name] = dipy_to_nipype_interface(new_name, obj)
del l_wkflw

else:
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
" open access to more function")


class ResampleInputSpec(TraitedSpec):
in_file = File(
Expand Down
22 changes: 11 additions & 11 deletions nipype/interfaces/dipy/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,24 @@
from ... import logging
from ..base import TraitedSpec, File, traits, isdefined
from .base import (DipyDiffusionInterface, DipyBaseInterfaceInputSpec,
HAVE_DIPY, dipy_version, dipy_to_nipype_interface)
HAVE_DIPY, dipy_version, dipy_to_nipype_interface,
get_dipy_workflows)


IFLOGGER = logging.getLogger('nipype.interface')

if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):
from dipy.workflows.reconst import (ReconstDkiFlow, ReconstCSAFlow,
ReconstCSDFlow, ReconstMAPMRIFlow,
ReconstDtiFlow)

DKIModel = dipy_to_nipype_interface("DKIModel", ReconstDkiFlow)
MapmriModel = dipy_to_nipype_interface("MapmriModel", ReconstMAPMRIFlow)
DTIModel = dipy_to_nipype_interface("DTIModel", ReconstDtiFlow)
CSAModel = dipy_to_nipype_interface("CSAModel", ReconstCSAFlow)
CSDModel = dipy_to_nipype_interface("CSDModel", ReconstCSDFlow)
from dipy.workflows import reconst

l_wkflw = get_dipy_workflows(reconst)
for name, obj in l_wkflw:
new_name = name.replace('Flow', '')
globals()[new_name] = dipy_to_nipype_interface(new_name, obj)
del l_wkflw

else:
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
" activate DKIModel, MapmriModel, DTIModel, CSAModel, CSDModel.")
" open access to more models")


class RESTOREInputSpec(DipyBaseInterfaceInputSpec):
Expand Down
17 changes: 10 additions & 7 deletions nipype/interfaces/dipy/registration.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@

from distutils.version import LooseVersion
from ... import logging
from .base import HAVE_DIPY, dipy_version, dipy_to_nipype_interface
from .base import (HAVE_DIPY, dipy_version, dipy_to_nipype_interface,
get_dipy_workflows)

IFLOGGER = logging.getLogger('nipype.interface')

if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):

from dipy.workflows.align import ResliceFlow, SlrWithQbxFlow
if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):
from dipy.workflows import align

Reslice = dipy_to_nipype_interface("Reslice", ResliceFlow)
StreamlineRegistration = dipy_to_nipype_interface("StreamlineRegistration",
SlrWithQbxFlow)
l_wkflw = get_dipy_workflows(align)
for name, obj in l_wkflw:
new_name = name.replace('Flow', '')
globals()[new_name] = dipy_to_nipype_interface(new_name, obj)
del l_wkflw

else:
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
" activate Reslice, StreamlineRegistration.")
" open access to more function")
20 changes: 20 additions & 0 deletions nipype/interfaces/dipy/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

from distutils.version import LooseVersion
from ... import logging
from .base import (HAVE_DIPY, dipy_version, dipy_to_nipype_interface,
get_dipy_workflows)

IFLOGGER = logging.getLogger('nipype.interface')

if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.16'):
from dipy.workflows import stats

l_wkflw = get_dipy_workflows(stats)
for name, obj in l_wkflw:
new_name = name.replace('Flow', '')
globals()[new_name] = dipy_to_nipype_interface(new_name, obj)
del l_wkflw

else:
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
" open access to more function")
13 changes: 12 additions & 1 deletion nipype/interfaces/dipy/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from collections import namedtuple
from ...base import traits, TraitedSpec, BaseInterfaceInputSpec
from ..base import (convert_to_traits_type, create_interface_specs,
dipy_to_nipype_interface, DipyBaseInterface, no_dipy)
dipy_to_nipype_interface, DipyBaseInterface, no_dipy,
get_dipy_workflows)


def test_convert_to_traits_type():
Expand Down Expand Up @@ -136,6 +137,16 @@ def run(self, in_files, param1=1, out_dir='', out_ref='out1.txt'):
new_specs().run()


@pytest.mark.skipif(no_dipy(), reason="DIPY is not installed")
def test_get_dipy_workflows():
from dipy.workflows import align

l_wkflw = get_dipy_workflows(align)
for name, obj in l_wkflw:
assert name.endswith('Flow')
assert issubclass(obj, align.Workflow)


if __name__ == "__main__":
test_convert_to_traits_type()
test_create_interface_specs()
Expand Down
20 changes: 12 additions & 8 deletions nipype/interfaces/dipy/tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,32 @@
from ..base import (TraitedSpec, BaseInterfaceInputSpec, File, isdefined,
traits)
from .base import (DipyBaseInterface, HAVE_DIPY, dipy_version,
dipy_to_nipype_interface)
dipy_to_nipype_interface, get_dipy_workflows)

IFLOGGER = logging.getLogger('nipype.interface')


if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):

from dipy.workflows.segment import RecoBundlesFlow, LabelsBundlesFlow
if HAVE_DIPY and (LooseVersion('0.15') >= LooseVersion(dipy_version()) >= LooseVersion('0.16')):
try:
from dipy.workflows.tracking import LocalFiberTrackingPAMFlow as DetTrackFlow
except ImportError: # different name in 0.15
from dipy.workflows.tracking import DetTrackPAMFlow as DetTrackFlow

RecoBundles = dipy_to_nipype_interface("RecoBundles", RecoBundlesFlow)
LabelsBundles = dipy_to_nipype_interface("LabelsBundles",
LabelsBundlesFlow)
DeterministicTracking = dipy_to_nipype_interface("DeterministicTracking",
DetTrackFlow)

if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):
from dipy.workflows import segment, tracking

l_wkflw = get_dipy_workflows(segment) + get_dipy_workflows(tracking)
for name, obj in l_wkflw:
new_name = name.replace('Flow', '')
globals()[new_name] = dipy_to_nipype_interface(new_name, obj)
del l_wkflw

else:
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
" activate RecoBundles, LabelsBundles, DeterministicTracking.")
" open access to more function")


class TrackDensityMapInputSpec(BaseInterfaceInputSpec):
Expand Down