Add DetectAndRemoveBadChannelsRecording and DetectAndInterpolateBadChannelsRecording classes #3685

Open
wants to merge 18 commits into
base: main
201 changes: 138 additions & 63 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
@@ -4,8 +4,142 @@
import numpy as np
from typing import Literal

from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .filter import highpass_filter
from spikeinterface.core import get_random_data_chunks, order_channels_by_depth, BaseRecording
from spikeinterface.core.channelslice import ChannelSliceRecording

from inspect import signature

_bad_channel_detection_kwargs_doc = """Different methods are implemented:

* std : threshold on channel standard deviations
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
channels' standard deviations, the channel is flagged as noisy
* mad : same as std, but using median absolute deviations instead
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
and a high coherence with "far away" channels
* Out of brain channels are contiguous regions of channels dissimilar to the median of all channels
at the top end of the probe (i.e. large channel number)
* neighborhood_r2
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
neighbors. This method estimates the correlation of each channel with the median of its spatial
neighbors, and considers channels bad when this correlation is too small.

Parameters
----------
recording : BaseRecording
The recording for which bad channels are detected
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
The method to be used for bad channel detection
std_mad_threshold : float, default: 5
The standard deviation/mad multiplier threshold
psd_hf_threshold : float, default: 0.02
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
Channels with average power at >80% Nyquist larger than this threshold
will be labeled as noise
dead_channel_threshold : float, default: -0.5
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
noisy_channel_threshold : float, default: 1
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
outside_channel_threshold : float, default: -0.75
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location : "top" | "bottom" | "both", default: "top"
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors : int, default: 11
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold : float, default: 0.8
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
with psd_hf_threshold
direction : "x" | "y" | "z", default: "y"
For coherence+psd - the depth dimension
highpass_filter_cutoff : float, default: 300
If the recording is not filtered, the cutoff frequency of the highpass filter
chunk_duration_s : float, default: 0.5
Duration of each chunk
num_random_chunks : int, default: 100
Number of random chunks
Having many chunks is important for reproducibility.
welch_window_ms : float, default: 10
Window size for the scipy.signal.welch that will be converted to nperseg
neighborhood_r2_threshold : float, default: 0.95
R^2 threshold for the neighborhood_r2 method.
neighborhood_r2_radius_um : float, default: 30
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
seed : int or None, default: None
The random seed to extract chunks
"""
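
A note on the `std` rule described in the docstring above: it can be sketched in a few lines of plain Python. This is only an illustration of the stated rule (flag a channel when its standard deviation exceeds `std_mad_threshold` times the median of all channel standard deviations), not the library's implementation. The function name `flag_noisy_channels` and the toy traces are made up for the example.

```python
from statistics import median, pstdev


def flag_noisy_channels(traces_by_channel, std_mad_threshold=5.0):
    """Return ids of channels whose std exceeds threshold * median std.

    Hypothetical helper: mirrors the rule in the docstring, nothing more.
    """
    # One standard deviation per channel
    stds = {ch: pstdev(samples) for ch, samples in traces_by_channel.items()}
    # Reference level: the median over all channels
    median_std = median(stds.values())
    return [ch for ch, s in stds.items() if s > std_mad_threshold * median_std]


# Toy data: three similar channels and one with a much larger amplitude
traces = {
    "ch0": [0.0, 1.0, -1.0, 1.0, -1.0],
    "ch1": [0.0, 1.1, -0.9, 1.0, -1.0],
    "ch2": [0.0, 0.9, -1.1, 1.0, -1.0],
    "ch3": [0.0, 20.0, -20.0, 20.0, -20.0],
}
noisy = flag_noisy_channels(traces)
```

Using the median as the reference keeps a single very noisy channel from raising the bar for all the others, which a mean-based reference would do.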


class DetectAndRemoveBadChannelsRecording(ChannelSliceRecording):
    """
    Detects and removes bad channels. If `bad_channel_ids` are given,
    detection is skipped and these are used instead.

    {}
    bad_channel_ids : np.array | list | None, default: None
        If given, these are used rather than being detected.

    Returns
    -------
    removed_bad_channels_recording : DetectAndRemoveBadChannelsRecording
        The recording with bad channels removed
    """

    _precomputable_kwarg_names = ["bad_channel_ids"]

    def __init__(
        self,
        parent_recording: BaseRecording,
        bad_channel_ids=None,
        **detect_bad_channels_kwargs,
    ):

        if bad_channel_ids is None:
            bad_channel_ids, channel_labels = detect_bad_channels(
                recording=parent_recording, **detect_bad_channels_kwargs
            )
        else:
            channel_labels = None

        self._main_ids = parent_recording.get_channel_ids()
        new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, bad_channel_ids)]

        ChannelSliceRecording.__init__(
            self,
            parent_recording=parent_recording,
            channel_ids=new_channel_ids,
        )

        self._kwargs.update({"bad_channel_ids": bad_channel_ids})
        if channel_labels is not None:
            self._kwargs.update({"channel_labels": channel_labels})

        all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs)
        self._kwargs.update(all_bad_channels_kwargs)


detect_and_remove_bad_channels = define_function_handling_dict_from_class(
    source_class=DetectAndRemoveBadChannelsRecording, name="detect_and_remove_bad_channels"
)
DetectAndRemoveBadChannelsRecording.__doc__ = DetectAndRemoveBadChannelsRecording.__doc__.format(
    _bad_channel_detection_kwargs_doc
)
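
The shared-docstring pattern used here (a `{}` placeholder filled in with `__doc__.format(...)` after the definition) can be tried in isolation. The sketch below uses made-up names (`_shared_params_doc`, `detect`, `DetectAndRemove`) and only shows the mechanism; it is not code from this PR.

```python
# Hypothetical shared parameter description, written once
_shared_params_doc = """Parameters
----------
threshold : float, default: 5
    Hypothetical shared parameter."""


def detect(threshold=5):
    """Detect things.

    {}
    """


class DetectAndRemove:
    """Detect and remove things.

    {}
    extra : int
        Hypothetical class-specific parameter.
    """


# Fill the placeholder in both docstrings from the single shared source
detect.__doc__ = detect.__doc__.format(_shared_params_doc)
DetectAndRemove.__doc__ = DetectAndRemove.__doc__.format(_shared_params_doc)
```

One thing to keep in mind with this approach: `str.format` treats every `{` and `}` in the template as a placeholder, so a docstring containing literal braces (for example a dict in an example) would need them doubled.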


def _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs):
    """Get the default parameters from `detect_bad_channels`, and update with any user-specified parameters."""

    sig = signature(detect_bad_channels)
    all_detect_bad_channels_kwargs = {k: v.default for k, v in sig.parameters.items() if k != "recording"}
    all_detect_bad_channels_kwargs.update(detect_bad_channels_kwargs)
    return all_detect_bad_channels_kwargs
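
For readers unfamiliar with `inspect.signature`, the helper above can be reproduced on a toy function. The names `detect` and `resolve_kwargs` below are hypothetical stand-ins; the sketch assumes every parameter other than the skipped one has a default.

```python
from inspect import signature


def detect(recording, method="std", threshold=5.0, seed=None):
    """Hypothetical stand-in for a detection function with keyword defaults."""
    return method, threshold, seed


def resolve_kwargs(func, user_kwargs, skip=("recording",)):
    """Collect the defaults of `func`, then override with user-specified values."""
    defaults = {
        name: param.default
        for name, param in signature(func).parameters.items()
        if name not in skip
    }
    defaults.update(user_kwargs)
    return defaults


resolved = resolve_kwargs(detect, {"threshold": 3.0})
```

Storing the fully resolved dict, rather than only what the user passed, means the saved kwargs still describe the run if the function's defaults change later. A parameter without a default would show up as `inspect.Parameter.empty`, which is why the required `recording` argument is skipped.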


def detect_bad_channels(
@@ -32,69 +166,7 @@ def detect_bad_channels(
Perform bad channel detection.
The recording is assumed to be filtered. If not, a highpass filter is applied on the fly.

Different methods are implemented:

* std : threshold on channel standard deviations
If the standard deviation of a channel is greater than `std_mad_threshold` times the median of all
channels' standard deviations, the channel is flagged as noisy
* mad : same as std, but using median absolute deviations instead
* coeherence+psd : method developed by the International Brain Laboratory that detects bad channels of three types:
* Dead channels are those with low similarity to the surrounding channels (n=`n_neighbors` median)
* Noise channels are those with power at >80% Nyquist above the psd_hf_threshold (default 0.02 uV^2 / Hz)
and a high coherence with "far away" channels
* Out of brain channels are contiguous regions of channels dissimilar to the median of all channels
at the top end of the probe (i.e. large channel number)
* neighborhood_r2
A method tuned for LFP use-cases, where channels should be highly correlated with their spatial
neighbors. This method estimates the correlation of each channel with the median of its spatial
neighbors, and considers channels bad when this correlation is too small.

Parameters
----------
recording : BaseRecording
The recording for which bad channels are detected
method : "coeherence+psd" | "std" | "mad" | "neighborhood_r2", default: "coeherence+psd"
The method to be used for bad channel detection
std_mad_threshold : float, default: 5
The standard deviation/mad multiplier threshold
psd_hf_threshold : float, default: 0.02
For coherence+psd - an absolute threshold (uV^2/Hz) used as a cutoff for noise channels.
Channels with average power at >80% Nyquist larger than this threshold
will be labeled as noise
dead_channel_threshold : float, default: -0.5
For coherence+psd - threshold for channel coherence below which channels are labeled as dead
noisy_channel_threshold : float, default: 1
Threshold for channel coherence above which channels are labeled as noisy (together with psd condition)
outside_channel_threshold : float, default: -0.75
For coherence+psd - threshold for channel coherence above which channels at the edge of the recording are marked as outside
of the brain
outside_channels_location : "top" | "bottom" | "both", default: "top"
For coherence+psd - location of the outside channels. If "top", only the channels at the top of the probe can be
marked as outside channels. If "bottom", only the channels at the bottom of the probe can be
marked as outside channels. If "both", both the channels at the top and bottom of the probe can be
marked as outside channels
n_neighbors : int, default: 11
For coeherence+psd - number of channel neighbors to compute median filter (needs to be odd)
nyquist_threshold : float, default: 0.8
For coherence+psd - frequency with respect to Nyquist (Fn=1) above which the mean of the PSD is calculated and compared
with psd_hf_threshold
direction : "x" | "y" | "z", default: "y"
For coherence+psd - the depth dimension
highpass_filter_cutoff : float, default: 300
If the recording is not filtered, the cutoff frequency of the highpass filter
chunk_duration_s : float, default: 0.5
Duration of each chunk
num_random_chunks : int, default: 100
Number of random chunks
Having many chunks is important for reproducibility.
welch_window_ms : float, default: 10
Window size for the scipy.signal.welch that will be converted to nperseg
neighborhood_r2_threshold : float, default: 0.95
R^2 threshold for the neighborhood_r2 method.
neighborhood_r2_radius_um : float, default: 30
Spatial radius below which two channels are considered neighbors in the neighborhood_r2 method.
seed : int or None, default: None
The random seed to extract chunks
{}

Returns
-------
@@ -269,6 +341,9 @@ def detect_bad_channels(
return bad_channel_ids, channel_labels


detect_bad_channels.__doc__ = detect_bad_channels.__doc__.format(_bad_channel_detection_kwargs_doc)


# ----------------------------------------------------------------------------------------------
# IBL Detect Bad Channels
# ----------------------------------------------------------------------------------------------
58 changes: 57 additions & 1 deletion src/spikeinterface/preprocessing/interpolate_bad_channels.py
@@ -2,9 +2,15 @@

import numpy as np

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from spikeinterface.preprocessing import preprocessing_tools
from .detect_bad_channels import (
    _bad_channel_detection_kwargs_doc,
    detect_bad_channels,
    _get_all_detect_bad_channel_kwargs,
)
from inspect import signature


class InterpolateBadChannelsRecording(BasePreprocessor):
@@ -82,6 +88,56 @@ def check_inputs(self, recording, bad_channel_ids):
raise NotImplementedError("Channel spacing units must be um")


class DetectAndInterpolateBadChannelsRecording(InterpolateBadChannelsRecording):
    """
    Detects and interpolates bad channels. If `bad_channel_ids` are given,
    detection is skipped and these are used instead.

    {}
    bad_channel_ids : np.array | list | None, default: None
        If given, these are used rather than being detected.

    Returns
    -------
    interpolated_bad_channels_recording : DetectAndInterpolateBadChannelsRecording
        The recording with bad channels interpolated
    """

    _precomputable_kwarg_names = ["bad_channel_ids"]

    def __init__(
        self,
        recording: BaseRecording,
        bad_channel_ids=None,
        **detect_bad_channels_kwargs,
    ):
        if bad_channel_ids is None:
            bad_channel_ids, channel_labels = detect_bad_channels(recording=recording, **detect_bad_channels_kwargs)
        else:
            channel_labels = None

        InterpolateBadChannelsRecording.__init__(
            self,
            recording,
            bad_channel_ids=bad_channel_ids,
        )

        self._kwargs.update({"bad_channel_ids": bad_channel_ids})
        if channel_labels is not None:
            self._kwargs.update({"channel_labels": channel_labels})

        all_bad_channels_kwargs = _get_all_detect_bad_channel_kwargs(detect_bad_channels_kwargs)
        self._kwargs.update(all_bad_channels_kwargs)


detect_and_interpolate_bad_channels = define_function_handling_dict_from_class(
    source_class=DetectAndInterpolateBadChannelsRecording, name="detect_and_interpolate_bad_channels"
)
DetectAndInterpolateBadChannelsRecording.__doc__ = DetectAndInterpolateBadChannelsRecording.__doc__.format(
    _bad_channel_detection_kwargs_doc
)


class InterpolateBadChannelsSegment(BasePreprocessorSegment):
    def __init__(self, parent_recording_segment, good_channel_indices, bad_channel_indices, weights):
        BasePreprocessorSegment.__init__(self, parent_recording_segment)
10 changes: 9 additions & 1 deletion src/spikeinterface/preprocessing/preprocessinglist.py
@@ -38,7 +38,13 @@
from .zero_channel_pad import ZeroChannelPaddedRecording, zero_channel_pad
from .deepinterpolation import DeepInterpolatedRecording, deepinterpolate, train_deepinterpolation
from .highpass_spatial_filter import HighpassSpatialFilterRecording, highpass_spatial_filter
from .interpolate_bad_channels import InterpolateBadChannelsRecording, interpolate_bad_channels
from .interpolate_bad_channels import (
    DetectAndInterpolateBadChannelsRecording,
    detect_and_interpolate_bad_channels,
    InterpolateBadChannelsRecording,
    interpolate_bad_channels,
)
from .detect_bad_channels import DetectAndRemoveBadChannelsRecording, detect_and_remove_bad_channels
from .average_across_direction import AverageAcrossDirectionRecording, average_across_direction
from .directional_derivative import DirectionalDerivativeRecording, directional_derivative
from .depth_order import DepthOrderRecording, depth_order
@@ -80,6 +86,8 @@
    DirectionalDerivativeRecording,
    AstypeRecording,
    UnsignedToSignedRecording,
    DetectAndRemoveBadChannelsRecording,
    DetectAndInterpolateBadChannelsRecording,
]

preprocesser_dict = {pp_class.name: pp_class for pp_class in preprocessers_full_list}
@@ -4,8 +4,10 @@
from spikeinterface import NumpyRecording, get_random_data_chunks
from probeinterface import generate_linear_probe

from spikeinterface.generation import generate_recording

from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import detect_bad_channels, highpass_filter
from spikeinterface.preprocessing import detect_bad_channels, highpass_filter, detect_and_remove_bad_channels

try:
# WARNING : this is not this package https://pypi.org/project/neurodsp/
@@ -18,6 +20,45 @@
HAVE_NPIX = False


def test_remove_bad_channel():
    """
    Generate a recording, then remove bad channels with a low noise threshold,
    so that some channels are removed. Then check that the new recording has none
    of the bad channels still in it and that the one changed kwarg is successfully
    propagated to the new recording.
    """

    recording = generate_recording(durations=[5, 6], seed=1205, num_channels=8)
    recording.set_channel_offsets(0)
    recording.set_channel_gains(1)

    # set noisy_channel_threshold so that we do detect some bad channels
    new_rec = detect_and_remove_bad_channels(recording, noisy_channel_threshold=0, seed=1205)

    # make sure they are removed
    bad_channel_ids = new_rec._kwargs["bad_channel_ids"]
    assert len(set(bad_channel_ids).intersection(new_rec.channel_ids)) == 0
    # and the good ones are kept
    good_channel_ids = recording.channel_ids[~np.isin(recording.channel_ids, bad_channel_ids)]
    assert set(good_channel_ids) == set(new_rec.channel_ids)

    # and that the kwarg is propagated to the kwargs of new_rec.
    assert set(new_rec._kwargs["channel_ids"]) == set(good_channel_ids)
    assert new_rec._kwargs["noisy_channel_threshold"] == 0
Collaborator

Is there anything to worry about regarding channel ordering here? I guess not (and this is a question for the ChannelSliceRecording tests anyway). But as I have no understanding of how the ordering works, I thought it worth asking 😆

Member Author

The ordering of the channel ids?

Collaborator

Erm, like the order of the channels on the recording itself (I'm not sure exactly how this is represented 😅). Like the default order when you do plot_traces without order_channel_by_depth.
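
On the ordering question, one part can be checked independently of spikeinterface: the PR builds the kept channel list with a boolean mask (`channel_ids[~np.isin(channel_ids, bad_channel_ids)]`), and NumPy boolean indexing returns the surviving elements in their original order. The sketch below uses made-up channel ids and says nothing about how `ChannelSliceRecording` itself treats the ids it is given.

```python
import numpy as np

# Made-up ids, deliberately not sorted, to make any reordering visible
channel_ids = np.array(["c3", "c0", "c2", "c1"])
bad_channel_ids = ["c2"]

# Same expression shape as in the PR: mask out the bad ids
kept = channel_ids[~np.isin(channel_ids, bad_channel_ids)]
```

Here `kept` holds `c3`, `c0`, `c1` in that order, i.e. the parent order with the bad channel dropped. Note that the test compares `set(...)` values in a few places, so those particular assertions would not catch a reordering.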


    # now apply `detect_bad_channels` directly and see that the outputs match
    bad_channel_ids_from_function, channel_labels_from_function = detect_bad_channels(
        recording, noisy_channel_threshold=0, seed=1205
    )

    assert np.all(new_rec._kwargs["bad_channel_ids"] == bad_channel_ids_from_function)
    assert np.all(new_rec._kwargs["channel_labels"] == channel_labels_from_function)

    new_rec_from_function = recording.remove_channels(remove_channel_ids=bad_channel_ids_from_function)

    assert np.all(new_rec_from_function.channel_ids == new_rec.channel_ids)


def test_detect_bad_channels_std_mad():
    num_channels = 4
    sampling_frequency = 30000.0