
ENH: REF: SSHDataGrabber should grab related file #2104


Merged · 12 commits · Mar 2, 2018
2 changes: 1 addition & 1 deletion .travis.yml
@@ -12,7 +12,7 @@ python:
env:
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler" CI_SKIP_TEST=1
- INSTALL_DEB_DEPENDECIES=false NIPYPE_EXTRAS="doc,tests,fmri,profiler" CI_SKIP_TEST=1
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler,duecredit" CI_SKIP_TEST=1
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler,duecredit,ssh" CI_SKIP_TEST=1
- INSTALL_DEB_DEPENDECIES=true NIPYPE_EXTRAS="doc,tests,fmri,profiler" PIP_FLAGS="--pre" CI_SKIP_TEST=1

addons:
2 changes: 1 addition & 1 deletion docker/generate_dockerfiles.sh
@@ -103,7 +103,7 @@ function generate_main_dockerfile() {
--arg PYTHON_VERSION_MAJOR=3 PYTHON_VERSION_MINOR=6 BUILD_DATE VCS_REF VERSION \
--miniconda env_name=neuro \
conda_install='python=${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}
icu=58.1 libxml2 libxslt matplotlib mkl numpy
icu=58.1 libxml2 libxslt matplotlib mkl numpy paramiko
pandas psutil scikit-learn scipy traits=4.6.0' \
pip_opts="-e" \
pip_install="/src/nipype[all]" \
3 changes: 2 additions & 1 deletion nipype/info.py
@@ -163,7 +163,8 @@ def get_nipype_gitversion():
'profiler': ['psutil>=5.0'],
'duecredit': ['duecredit'],
'xvfbwrapper': ['xvfbwrapper'],
'pybids': ['pybids']
'pybids': ['pybids'],
'ssh': ['paramiko'],
# 'mesh': ['mayavi'] # Enable when it works
}

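The new 'ssh' entry in extras_require makes paramiko an optional dependency, so SSH support can be installed with something like pip install nipype[ssh]. As a minimal sketch (not part of this PR; the helper name is illustrative), the usual guard for such an optional import looks like this:

# Sketch of an optional-import guard for paramiko, mirroring how the
# test suite below detects whether the 'ssh' extra is available.
try:
    import paramiko
    NO_PARAMIKO = False
except ImportError:
    NO_PARAMIKO = True


def require_paramiko():
    # Illustrative helper: fail early with an actionable message when
    # the optional dependency is missing.
    if NO_PARAMIKO:
        raise ImportError("paramiko is required for SSH support; "
                          "install it with: pip install nipype[ssh]")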
134 changes: 74 additions & 60 deletions nipype/interfaces/io.py
@@ -31,14 +31,17 @@
import shutil
import subprocess
import re
import copy
import tempfile
from os.path import join, dirname
from warnings import warn

import sqlite3

from .. import config, logging
from ..utils.filemanip import copyfile, list_to_filename, filename_to_list
from ..utils.filemanip import (
copyfile, list_to_filename, filename_to_list,
get_related_files, related_filetype_sets)
from ..utils.misc import human_order_sorted, str2bool
from .base import (
TraitedSpec, traits, Str, File, Directory, BaseInterface, InputMultiPath,
@@ -2412,6 +2415,65 @@ def __init__(self, infields=None, outfields=None, **kwargs):
and self.inputs.template[-1] != '$'):
self.inputs.template += '$'

def _get_files_over_ssh(self, template):
"""Get the files matching template over an SSH connection."""
# Connect over SSH
client = self._get_ssh_client()
sftp = client.open_sftp()
sftp.chdir(self.inputs.base_directory)

# Get all files in the dir, and filter for desired files
template_dir = os.path.dirname(template)
template_base = os.path.basename(template)
every_file_in_dir = sftp.listdir(template_dir)
if self.inputs.template_expression == 'fnmatch':
outfiles = fnmatch.filter(every_file_in_dir, template_base)
elif self.inputs.template_expression == 'regexp':
regexp = re.compile(template_base)
outfiles = list(filter(regexp.match, every_file_in_dir))
else:
raise ValueError('template_expression value invalid')

if len(outfiles) == 0:
# no files
msg = 'Output template: %s returned no files' % template
if self.inputs.raise_on_empty:
raise IOError(msg)
else:
warn(msg)

# return value
outfiles = None

else:
# found files, sort and save to outputs
if self.inputs.sort_filelist:
outfiles = human_order_sorted(outfiles)

# actually download the files, if desired
if self.inputs.download_files:
files_to_download = copy.copy(outfiles) # make sure new list!

# check to see if there are any related files to download
for file_to_download in files_to_download:
related_to_current = get_related_files(
file_to_download, include_this_file=False)
existing_related_not_downloading = [
f for f in related_to_current
if f in every_file_in_dir and f not in files_to_download]
files_to_download.extend(existing_related_not_downloading)

for f in files_to_download:
try:
sftp.get(os.path.join(template_dir, f), f)
except IOError:
iflogger.info('remote file %s not found' % f)

# return value
outfiles = list_to_filename(outfiles)

return outfiles

def _list_outputs(self):
try:
paramiko
@@ -2439,32 +2501,10 @@ def _list_outputs(self):
isdefined(self.inputs.field_template) and \
key in self.inputs.field_template:
template = self.inputs.field_template[key]

if not args:
client = self._get_ssh_client()
sftp = client.open_sftp()
sftp.chdir(self.inputs.base_directory)
filelist = sftp.listdir()
if self.inputs.template_expression == 'fnmatch':
filelist = fnmatch.filter(filelist, template)
elif self.inputs.template_expression == 'regexp':
regexp = re.compile(template)
filelist = list(filter(regexp.match, filelist))
else:
raise ValueError('template_expression value invalid')
if len(filelist) == 0:
msg = 'Output key: %s Template: %s returned no files' % (
key, template)
if self.inputs.raise_on_empty:
raise IOError(msg)
else:
warn(msg)
else:
if self.inputs.sort_filelist:
filelist = human_order_sorted(filelist)
outputs[key] = list_to_filename(filelist)
if self.inputs.download_files:
for f in filelist:
sftp.get(f, f)
outputs[key] = self._get_files_over_ssh(template)

for argnum, arglist in enumerate(args):
maxlen = 1
for arg in arglist:
@@ -2498,44 +2538,18 @@ def _list_outputs(self):
e.message +
": Template %s failed to convert with args %s"
% (template, str(tuple(argtuple))))
client = self._get_ssh_client()
sftp = client.open_sftp()
sftp.chdir(self.inputs.base_directory)
filledtemplate_dir = os.path.dirname(filledtemplate)
filledtemplate_base = os.path.basename(filledtemplate)
filelist = sftp.listdir(filledtemplate_dir)
if self.inputs.template_expression == 'fnmatch':
outfiles = fnmatch.filter(filelist,
filledtemplate_base)
elif self.inputs.template_expression == 'regexp':
regexp = re.compile(filledtemplate_base)
outfiles = list(filter(regexp.match, filelist))
else:
raise ValueError('template_expression value invalid')
if len(outfiles) == 0:
msg = 'Output key: %s Template: %s returned no files' % (
key, filledtemplate)
if self.inputs.raise_on_empty:
raise IOError(msg)
else:
warn(msg)
outputs[key].append(None)
else:
if self.inputs.sort_filelist:
outfiles = human_order_sorted(outfiles)
outputs[key].append(list_to_filename(outfiles))
if self.inputs.download_files:
for f in outfiles:
try:
sftp.get(
os.path.join(filledtemplate_dir, f), f)
except IOError:
iflogger.info('remote file %s not found',
f)

outputs[key].append(self._get_files_over_ssh(filledtemplate))

# exclude results where there were any invalid matches
if any([val is None for val in outputs[key]]):
outputs[key] = []

# no outputs means None, not an empty list
if len(outputs[key]) == 0:
outputs[key] = None

# a single output is returned as the item itself, not a list
elif len(outputs[key]) == 1:
outputs[key] = outputs[key][0]

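The substance of the io.py change is that _get_files_over_ssh now expands every matched file to its related files via get_related_files (driven by related_filetype_sets, which groups companions such as .hdr/.img/.mat) and downloads those companions too when download_files is set. A rough, self-contained sketch of just that expansion step, assuming nipype is importable and noting that the exact filetype groupings may vary between versions:

# Sketch: expand template matches to include related files that exist
# remotely, as the new _get_files_over_ssh does. The file names here
# are hypothetical.
import copy

from nipype.utils.filemanip import get_related_files

every_file_in_dir = ['somedata.hdr', 'somedata.img', 'notes.txt']
matches = ['somedata.hdr']  # files that matched the template

files_to_download = copy.copy(matches)  # make sure it is a new list
for f in matches:
    related = get_related_files(f, include_this_file=False)
    # only schedule companions that actually exist in the remote
    # directory and are not already queued
    files_to_download.extend(
        r for r in related
        if r in every_file_in_dir and r not in files_to_download)

print(files_to_download)  # expect ['somedata.hdr', 'somedata.img']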
78 changes: 77 additions & 1 deletion nipype/interfaces/tests/test_io.py
@@ -5,6 +5,7 @@
from builtins import str, zip, range, open
from future import standard_library
import os
import copy
import simplejson
import glob
import shutil
@@ -37,6 +38,32 @@
except ImportError:
noboto3 = True

# Check for paramiko
try:
import paramiko
no_paramiko = False

# Check for localhost SSH Server
# FIXME: Tests requiring this are never run on CI
try:
proxy = None
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect('127.0.0.1', username=os.getenv('USER'), sock=proxy,
timeout=10)

no_local_ssh = False

except (paramiko.SSHException,
paramiko.ssh_exception.NoValidConnectionsError,
OSError):
no_local_ssh = True

except ImportError:
no_paramiko = True
no_local_ssh = True

# Check for fakes3
standard_library.install_aliases()
from subprocess import check_call, CalledProcessError
@@ -316,7 +343,7 @@ def test_datasink_to_s3(dummy_input, tmpdir):
aws_access_key_id='mykey',
aws_secret_access_key='mysecret',
service_name='s3',
endpoint_url='http://localhost:4567',
endpoint_url='http://127.0.0.1:4567',
use_ssl=False)
resource.meta.client.meta.events.unregister('before-sign.s3', fix_s3_host)

@@ -611,3 +638,52 @@ def test_bids_infields_outfields(tmpdir):
bg = nio.BIDSDataGrabber()
for outfield in ['anat', 'func']:
assert outfield in bg._outputs().traits()


@pytest.mark.skipif(no_paramiko, reason="paramiko library is not available")
@pytest.mark.skipif(no_local_ssh, reason="SSH Server is not running")
def test_SSHDataGrabber(tmpdir):
"""Test SSHDataGrabber by connecting to localhost and collecting some data.
"""
old_cwd = tmpdir.chdir()

source_dir = tmpdir.mkdir('source')
source_hdr = source_dir.join('somedata.hdr')
source_dat = source_dir.join('somedata.img')
source_hdr.ensure() # create
source_dat.ensure() # create

# ssh client that connects to localhost, current user, regardless of
# ~/.ssh/config
def _mock_get_ssh_client(self):
proxy = None
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect('127.0.0.1', username=os.getenv('USER'), sock=proxy,
timeout=10)
return client
MockSSHDataGrabber = copy.copy(nio.SSHDataGrabber)
MockSSHDataGrabber._get_ssh_client = _mock_get_ssh_client

# grabber to get files from source_dir matching test.hdr
ssh_grabber = MockSSHDataGrabber(infields=['test'],
outfields=['test_file'])
ssh_grabber.inputs.base_directory = str(source_dir)
ssh_grabber.inputs.hostname = '127.0.0.1'
ssh_grabber.inputs.field_template = dict(test_file='%s.hdr')
ssh_grabber.inputs.template = ''
ssh_grabber.inputs.template_args = dict(test_file=[['test']])
ssh_grabber.inputs.test = 'somedata'
ssh_grabber.inputs.sort_filelist = True

runtime = ssh_grabber.run()

# did we successfully get the header?
assert runtime.outputs.test_file == str(tmpdir.join(source_hdr.basename))
# did we successfully get the data?
assert (tmpdir.join(source_hdr.basename) # header file
.new(ext='.img') # data file
.check(file=True, exists=True)) # exists?

old_cwd.chdir()
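

For comparison with the mocked localhost test above, a configuration against a real host might look like the sketch below; the host name, directory layout, and key-based authentication are assumptions, not something exercised by this PR's test suite.

# Hypothetical SSHDataGrabber setup: fetch sub01's somedata.hdr from a
# remote host and, thanks to this PR, also pull the matching .img.
import nipype.interfaces.io as nio

grabber = nio.SSHDataGrabber(infields=['subject'], outfields=['anat'])
grabber.inputs.hostname = 'remote.example.org'   # hypothetical host
grabber.inputs.base_directory = '/data'          # hypothetical layout
grabber.inputs.template = ''
grabber.inputs.field_template = dict(anat='%s/somedata.hdr')
grabber.inputs.template_args = dict(anat=[['subject']])
grabber.inputs.subject = 'sub01'
grabber.inputs.sort_filelist = True
grabber.inputs.download_files = True   # related .img comes along

results = grabber.run()
print(results.outputs.anat)  # local path to the downloaded .hdr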