diff --git a/.travis.yml b/.travis.yml index d57a4205d2..fa5c199a8f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,7 @@ python: env: - INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler" CI_SKIP_TEST=1 - INSTALL_DEB_DEPENDECIES=false NIPYPE_EXTRAS="doc,tests,fmri,profiler" CI_SKIP_TEST=1 -- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler,duecredit" CI_SKIP_TEST=1 +- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler,duecredit,ssh" CI_SKIP_TEST=1 - INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler" PIP_FLAGS="--pre" CI_SKIP_TEST=1 addons: diff --git a/docker/generate_dockerfiles.sh b/docker/generate_dockerfiles.sh index 2f0ed5eaa7..e1f21b3139 100755 --- a/docker/generate_dockerfiles.sh +++ b/docker/generate_dockerfiles.sh @@ -103,7 +103,7 @@ function generate_main_dockerfile() { --arg PYTHON_VERSION_MAJOR=3 PYTHON_VERSION_MINOR=6 BUILD_DATE VCS_REF VERSION \ --miniconda env_name=neuro \ conda_install='python=${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR} - icu=58.1 libxml2 libxslt matplotlib mkl numpy + icu=58.1 libxml2 libxslt matplotlib mkl numpy paramiko pandas psutil scikit-learn scipy traits=4.6.0' \ pip_opts="-e" \ pip_install="/src/nipype[all]" \ diff --git a/nipype/info.py b/nipype/info.py index 8a7e3b4348..93b5dd2eb8 100644 --- a/nipype/info.py +++ b/nipype/info.py @@ -163,7 +163,8 @@ def get_nipype_gitversion(): 'profiler': ['psutil>=5.0'], 'duecredit': ['duecredit'], 'xvfbwrapper': ['xvfbwrapper'], - 'pybids': ['pybids'] + 'pybids': ['pybids'], + 'ssh': ['paramiko'], # 'mesh': ['mayavi'] # Enable when it works } diff --git a/nipype/interfaces/io.py b/nipype/interfaces/io.py index c4fdd75216..ee8b1a28c7 100644 --- a/nipype/interfaces/io.py +++ b/nipype/interfaces/io.py @@ -31,6 +31,7 @@ import shutil import subprocess import re +import copy import tempfile from os.path import join, dirname from warnings import warn @@ -38,7 +39,9 @@ import sqlite3 from .. import config, logging -from ..utils.filemanip import copyfile, list_to_filename, filename_to_list +from ..utils.filemanip import ( + copyfile, list_to_filename, filename_to_list, + get_related_files, related_filetype_sets) from ..utils.misc import human_order_sorted, str2bool from .base import ( TraitedSpec, traits, Str, File, Directory, BaseInterface, InputMultiPath, @@ -2412,6 +2415,65 @@ def __init__(self, infields=None, outfields=None, **kwargs): and self.inputs.template[-1] != '$'): self.inputs.template += '$' + def _get_files_over_ssh(self, template): + """Get the files matching template over an SSH connection.""" + # Connect over SSH + client = self._get_ssh_client() + sftp = client.open_sftp() + sftp.chdir(self.inputs.base_directory) + + # Get all files in the dir, and filter for desired files + template_dir = os.path.dirname(template) + template_base = os.path.basename(template) + every_file_in_dir = sftp.listdir(template_dir) + if self.inputs.template_expression == 'fnmatch': + outfiles = fnmatch.filter(every_file_in_dir, template_base) + elif self.inputs.template_expression == 'regexp': + regexp = re.compile(template_base) + outfiles = list(filter(regexp.match, every_file_in_dir)) + else: + raise ValueError('template_expression value invalid') + + if len(outfiles) == 0: + # no files + msg = 'Output template: %s returned no files' % template + if self.inputs.raise_on_empty: + raise IOError(msg) + else: + warn(msg) + + # return value + outfiles = None + + else: + # found files, sort and save to outputs + if self.inputs.sort_filelist: + outfiles = human_order_sorted(outfiles) + + # actually download the files, if desired + if self.inputs.download_files: + files_to_download = copy.copy(outfiles) # make sure new list! + + # check to see if there are any related files to download + for file_to_download in files_to_download: + related_to_current = get_related_files( + file_to_download, include_this_file=False) + existing_related_not_downloading = [ + f for f in related_to_current + if f in every_file_in_dir and f not in files_to_download] + files_to_download.extend(existing_related_not_downloading) + + for f in files_to_download: + try: + sftp.get(os.path.join(template_dir, f), f) + except IOError: + iflogger.info('remote file %s not found' % f) + + # return value + outfiles = list_to_filename(outfiles) + + return outfiles + def _list_outputs(self): try: paramiko @@ -2439,32 +2501,10 @@ def _list_outputs(self): isdefined(self.inputs.field_template) and \ key in self.inputs.field_template: template = self.inputs.field_template[key] + if not args: - client = self._get_ssh_client() - sftp = client.open_sftp() - sftp.chdir(self.inputs.base_directory) - filelist = sftp.listdir() - if self.inputs.template_expression == 'fnmatch': - filelist = fnmatch.filter(filelist, template) - elif self.inputs.template_expression == 'regexp': - regexp = re.compile(template) - filelist = list(filter(regexp.match, filelist)) - else: - raise ValueError('template_expression value invalid') - if len(filelist) == 0: - msg = 'Output key: %s Template: %s returned no files' % ( - key, template) - if self.inputs.raise_on_empty: - raise IOError(msg) - else: - warn(msg) - else: - if self.inputs.sort_filelist: - filelist = human_order_sorted(filelist) - outputs[key] = list_to_filename(filelist) - if self.inputs.download_files: - for f in filelist: - sftp.get(f, f) + outputs[key] = self._get_files_over_ssh(template) + for argnum, arglist in enumerate(args): maxlen = 1 for arg in arglist: @@ -2498,44 +2538,18 @@ def _list_outputs(self): e.message + ": Template %s failed to convert with args %s" % (template, str(tuple(argtuple)))) - client = self._get_ssh_client() - sftp = client.open_sftp() - sftp.chdir(self.inputs.base_directory) - filledtemplate_dir = os.path.dirname(filledtemplate) - filledtemplate_base = os.path.basename(filledtemplate) - filelist = sftp.listdir(filledtemplate_dir) - if self.inputs.template_expression == 'fnmatch': - outfiles = fnmatch.filter(filelist, - filledtemplate_base) - elif self.inputs.template_expression == 'regexp': - regexp = re.compile(filledtemplate_base) - outfiles = list(filter(regexp.match, filelist)) - else: - raise ValueError('template_expression value invalid') - if len(outfiles) == 0: - msg = 'Output key: %s Template: %s returned no files' % ( - key, filledtemplate) - if self.inputs.raise_on_empty: - raise IOError(msg) - else: - warn(msg) - outputs[key].append(None) - else: - if self.inputs.sort_filelist: - outfiles = human_order_sorted(outfiles) - outputs[key].append(list_to_filename(outfiles)) - if self.inputs.download_files: - for f in outfiles: - try: - sftp.get( - os.path.join(filledtemplate_dir, f), f) - except IOError: - iflogger.info('remote file %s not found', - f) + + outputs[key].append(self._get_files_over_ssh(filledtemplate)) + + # disclude where there was any invalid matches if any([val is None for val in outputs[key]]): outputs[key] = [] + + # no outputs is None, not empty list if len(outputs[key]) == 0: outputs[key] = None + + # one output is the item, not a list elif len(outputs[key]) == 1: outputs[key] = outputs[key][0] diff --git a/nipype/interfaces/tests/test_io.py b/nipype/interfaces/tests/test_io.py index 6aafb5b6b9..1eea2b31f3 100644 --- a/nipype/interfaces/tests/test_io.py +++ b/nipype/interfaces/tests/test_io.py @@ -5,6 +5,7 @@ from builtins import str, zip, range, open from future import standard_library import os +import copy import simplejson import glob import shutil @@ -37,6 +38,32 @@ except ImportError: noboto3 = True +# Check for paramiko +try: + import paramiko + no_paramiko = False + + # Check for localhost SSH Server + # FIXME: Tests requiring this are never run on CI + try: + proxy = None + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect('127.0.0.1', username=os.getenv('USER'), sock=proxy, + timeout=10) + + no_local_ssh = False + + except (paramiko.SSHException, + paramiko.ssh_exception.NoValidConnectionsError, + OSError): + no_local_ssh = True + +except ImportError: + no_paramiko = True + no_local_ssh = True + # Check for fakes3 standard_library.install_aliases() from subprocess import check_call, CalledProcessError @@ -316,7 +343,7 @@ def test_datasink_to_s3(dummy_input, tmpdir): aws_access_key_id='mykey', aws_secret_access_key='mysecret', service_name='s3', - endpoint_url='http://localhost:4567', + endpoint_url='http://127.0.0.1:4567', use_ssl=False) resource.meta.client.meta.events.unregister('before-sign.s3', fix_s3_host) @@ -611,3 +638,52 @@ def test_bids_infields_outfields(tmpdir): bg = nio.BIDSDataGrabber() for outfield in ['anat', 'func']: assert outfield in bg._outputs().traits() + + +@pytest.mark.skipif(no_paramiko, reason="paramiko library is not available") +@pytest.mark.skipif(no_local_ssh, reason="SSH Server is not running") +def test_SSHDataGrabber(tmpdir): + """Test SSHDataGrabber by connecting to localhost and collecting some data. + """ + old_cwd = tmpdir.chdir() + + source_dir = tmpdir.mkdir('source') + source_hdr = source_dir.join('somedata.hdr') + source_dat = source_dir.join('somedata.img') + source_hdr.ensure() # create + source_dat.ensure() # create + + # ssh client that connects to localhost, current user, regardless of + # ~/.ssh/config + def _mock_get_ssh_client(self): + proxy = None + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect('127.0.0.1', username=os.getenv('USER'), sock=proxy, + timeout=10) + return client + MockSSHDataGrabber = copy.copy(nio.SSHDataGrabber) + MockSSHDataGrabber._get_ssh_client = _mock_get_ssh_client + + # grabber to get files from source_dir matching test.hdr + ssh_grabber = MockSSHDataGrabber(infields=['test'], + outfields=['test_file']) + ssh_grabber.inputs.base_directory = str(source_dir) + ssh_grabber.inputs.hostname = '127.0.0.1' + ssh_grabber.inputs.field_template = dict(test_file='%s.hdr') + ssh_grabber.inputs.template = '' + ssh_grabber.inputs.template_args = dict(test_file=[['test']]) + ssh_grabber.inputs.test = 'somedata' + ssh_grabber.inputs.sort_filelist = True + + runtime = ssh_grabber.run() + + # did we successfully get the header? + assert runtime.outputs.test_file == str(tmpdir.join(source_hdr.basename)) + # did we successfully get the data? + assert (tmpdir.join(source_hdr.basename) # header file + .new(ext='.img') # data file + .check(file=True, exists=True)) # exists? + + old_cwd.chdir()