From 8a7c145a8a1abd4c8d63c55eabb32910205053ab Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 12 Jun 2024 13:30:06 +0100 Subject: [PATCH 1/2] Add peaks_on_probe widget and tests. --- src/spikeinterface/widgets/peaks_on_probe.py | 218 +++++++++++++ .../widgets/tests/test_peaks_on_probe.py | 304 ++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 3 + 3 files changed, 525 insertions(+) create mode 100644 src/spikeinterface/widgets/peaks_on_probe.py create mode 100644 src/spikeinterface/widgets/tests/test_peaks_on_probe.py diff --git a/src/spikeinterface/widgets/peaks_on_probe.py b/src/spikeinterface/widgets/peaks_on_probe.py new file mode 100644 index 0000000000..0d23b6c67e --- /dev/null +++ b/src/spikeinterface/widgets/peaks_on_probe.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import numpy as np + + +from .base import BaseWidget, to_attr + + +class PeaksOnProbeWidget(BaseWidget): + """ + Generate a plot of spike peaks showing their location on a plot + of the probe. Color scaling represents spike amplitude. + + The generated plot overlays the estimated position of a spike peak + (as a single point for each peak) onto a plot of the probe. The + dimensions of the plot are x axis: probe width, y axis: probe depth. + + Plots of different sets of peaks can be created on subplots, by + passing a list of peaks and corresponding peak locations. + + Parameters + ---------- + recording : Recording + A SpikeInterface recording object. + peaks : np.array | list[np.ndarray] + SpikeInterface 'peaks' array created with `detect_peaks()`, + an array of length num_peaks with entries: + (sample_index, channel_index, amplitude, segment_index) + To plot different sets of peaks in subplots, pass a list of peaks, each + with a corresponding entry in a list passed to `peak_locations`. + peak_locations : np.array | list[np.ndarray] + A SpikeInterface 'peak_locations' array created with `localize_peaks()`. + an array of length num_peaks with entries: (x, y) + To plot multiple peaks in subplots, pass a list of `peak_locations` + here with each entry having a corresponding `peaks`. + segment_index : None | int, default: None + If set, only peaks from this recording segment will be used. + time_range : None | Tuple, default: None + The time period over which to include peaks. If `None`, peaks + across the entire recording will be shown. + ylim : None | Tuple, default: None + The y-axis limits (i.e. the probe depth). If `None`, the entire + probe will be displayed. + decimate : int, default: 5 + For performance reasons, every nth peak is shown on the plot, + where n is set by decimate. To plot all peaks, set `decimate=1`. + """ + + def __init__( + self, + recording, + peaks, + peak_locations, + segment_index=None, + time_range=None, + ylim=None, + decimate=5, + backend=None, + **backend_kwargs, + ): + data_plot = dict( + recording=recording, + peaks=peaks, + peak_locations=peak_locations, + segment_index=segment_index, + time_range=time_range, + ylim=ylim, + decimate=decimate, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from spikeinterface.widgets import plot_probe_map + + dp = to_attr(data_plot) + + peaks, peak_locations = self._check_and_format_inputs( + dp.peaks, + dp.peak_locations, + ) + fs = dp.recording.get_sampling_frequency() + num_plots = len(peaks) + + # Set the maximum time to the end time of the longest segment + if dp.time_range is None: + + time_range = self._get_min_and_max_times_in_recording(dp.recording) + else: + time_range = dp.time_range + + ## Create the figure and axes + if backend_kwargs["figsize"] is None: + backend_kwargs.update(dict(figsize=(12, 8))) + + self.figure, self.axes, self.ax = make_mpl_figure(num_axes=num_plots, **backend_kwargs) + self.axes = self.axes[0] + + # Plot each passed peaks / peak_locations over the probe on a separate subplot + for ax_idx, (peaks_to_plot, peak_locs_to_plot) in enumerate(zip(peaks, peak_locations)): + + ax = self.axes[ax_idx] + plot_probe_map(dp.recording, ax=ax) + + time_mask = self._get_peaks_time_mask(dp.recording, time_range, peaks_to_plot) + + if dp.segment_index is not None: + segment_mask = peaks_to_plot["segment_index"] == dp.segment_index + mask = time_mask & segment_mask + else: + mask = time_mask + + if not any(mask): + raise ValueError( + "No peaks within the time and segment mask found. Change `time_range` or `segment_index`" + ) + + # only plot every nth peak + peak_slice = slice(None, None, dp.decimate) + + # Find the amplitudes for the colormap scaling + # (intensity represents amplitude) + amps = np.abs(peaks_to_plot["amplitude"][mask][peak_slice]) + amps /= np.quantile(amps, 0.95) + cmap = plt.get_cmap("inferno")(amps) + color_kwargs = dict(alpha=0.2, s=2, c=cmap) + + # Plot the peaks over the plot, and set the y-axis limits. + ax.scatter( + peak_locs_to_plot["x"][mask][peak_slice], peak_locs_to_plot["y"][mask][peak_slice], **color_kwargs + ) + + if dp.ylim is None: + padding = 25 # arbitary padding just to give some space around highests and lowest peaks on the plot + ylim = (np.min(peak_locs_to_plot["y"]) - padding, np.max(peak_locs_to_plot["y"]) + padding) + else: + ylim = dp.ylim + + ax.set_ylim(ylim[0], ylim[1]) + + self.figure.suptitle(f"Peaks on Probe Plot") + + def _get_peaks_time_mask(self, recording, time_range, peaks_to_plot): + """ + Return a mask of `True` where the peak is within the given time range + and `False` otherwise. + + This is a little complex, as each segment can have different start / + end times. For each segment, find the time bounds relative to that + segment time and fill the `time_mask` one segment at a time. + """ + time_mask = np.zeros(peaks_to_plot.size, dtype=bool) + + for seg_idx in range(recording.get_num_segments()): + + segment = recording.select_segments(seg_idx) + + t_start_sample = segment.time_to_sample_index(time_range[0]) + t_stop_sample = segment.time_to_sample_index(time_range[1]) + + seg_mask = peaks_to_plot["segment_index"] == seg_idx + + time_mask[seg_mask] = (t_start_sample < peaks_to_plot[seg_mask]["sample_index"]) & ( + peaks_to_plot[seg_mask]["sample_index"] < t_stop_sample + ) + + return time_mask + + def _get_min_and_max_times_in_recording(self, recording): + """ + Find the maximum and minimum time across all segments in the recording. + For example if the segment times are (10-100 s, 0 - 50s) the + min and max times are (0, 100) + """ + t_starts = [] + t_stops = [] + for seg_idx in range(recording.get_num_segments()): + + segment = recording.select_segments(seg_idx) + + t_starts.append(segment.sample_index_to_time(0)) + + t_stops.append(segment.sample_index_to_time(segment.get_num_samples() - 1)) + + time_range = (np.min(t_starts), np.max(t_stops)) + + return time_range + + def _check_and_format_inputs(self, peaks, peak_locations): + """ + Check that the inpust are in expected form. Corresponding peaks + and peak_locations of same size and format must be provided. + """ + types_are_list = [isinstance(peaks, list), isinstance(peak_locations, list)] + + if not all(types_are_list): + if any(types_are_list): + raise ValueError("`peaks` and `peak_locations` must either be both lists or both not lists.") + peaks = [peaks] + peak_locations = [peak_locations] + + if len(peaks) != len(peak_locations): + raise ValueError( + "If `peaks` and `peak_locations` are lists, they must contain " + "the same number of (corresponding) peaks and peak locations." + ) + + for idx, (peak, peak_loc) in enumerate(zip(peaks, peak_locations)): + if peak.size != peak_loc.size: + raise ValueError( + f"The number of peaks and peak_locations do not " + f"match for the {idx} input. For each spike peak, there " + f"must be a corresponding peak location" + ) + + return peaks, peak_locations diff --git a/src/spikeinterface/widgets/tests/test_peaks_on_probe.py b/src/spikeinterface/widgets/tests/test_peaks_on_probe.py new file mode 100644 index 0000000000..9820ee5e72 --- /dev/null +++ b/src/spikeinterface/widgets/tests/test_peaks_on_probe.py @@ -0,0 +1,304 @@ +import pytest +from spikeinterface.sortingcomponents.peak_localization import localize_peaks +from spikeinterface.sortingcomponents.peak_detection import detect_peaks +from spikeinterface.widgets import plot_peaks_on_probe +from spikeinterface import generate_ground_truth_recording # TODO: think about imports +import numpy as np + + +class TestPeaksOnProbe: + + @pytest.fixture(scope="session") + def peak_info(self): + """ + Fixture (created only once per test run) of a small + ground truth recording with peaks and peak locations calculated. + """ + recording, _ = generate_ground_truth_recording(num_units=5, num_channels=16, durations=[20, 9], seed=0) + peaks = detect_peaks(recording) + + peak_locations = localize_peaks( + recording, + peaks, + ms_before=0.3, + ms_after=0.6, + method="center_of_mass", + ) + + return (recording, peaks, peak_locations) + + def data_from_widget(self, widget, axes_idx): + """ + Convenience function to get the data of the peaks + that are on the plot (not sure why they are in the + second 'collections'). + """ + return widget.axes[axes_idx].collections[2].get_offsets().data + + def test_peaks_on_probe_main(self, peak_info): + """ + Plot all peaks, and check every peak is plot. + Check the labels are corect. + """ + recording, peaks, peak_locations = peak_info + + widget = plot_peaks_on_probe(recording, peaks, peak_locations, decimate=1) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + ax_y_pos = peak_locations["y"] + + assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) + assert widget.axes[0].get_ylabel() == "y ($\\mu m$)" + assert widget.axes[0].get_xlabel() == "x ($\\mu m$)" + + @pytest.mark.parametrize("segment_index", [0, 1]) + def test_segment_selection(self, peak_info, segment_index): + """ + Check that that when specifying only to plot peaks + from a sepecific segment, that only peaks + from that segment are plot. + """ + recording, peaks, peak_locations = peak_info + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + segment_index=segment_index, + ) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + ax_y_pos = peak_locations["y"][peaks["segment_index"] == segment_index] + + assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) + + def test_multiple_inputs(self, peak_info): + """ + Check that multiple inputs are correctly plot + on separate axes. Do this my creating a copy + of the peaks / peak locations with less peaks + and different locations, for good measure. + Check that these separate peaks / peak locations + are plot on different axes. + """ + recording, peaks, peak_locations = peak_info + + half_num_peaks = int(peaks.shape[0] / 2) + + peaks_change = peaks.copy()[:half_num_peaks] + locs_change = peak_locations.copy()[:half_num_peaks] + locs_change["y"] += 1 + + widget = plot_peaks_on_probe( + recording, + [peaks, peaks_change], + [peak_locations, locs_change], + decimate=1, + ) + + # Test the first entry, axis 0 + ax_0_y_data = self.data_from_widget(widget, 0)[:, 1] + + assert np.array_equal(np.sort(peak_locations["y"]), np.sort(ax_0_y_data)) + + # Test the second entry, axis 1. + ax_1_y_data = self.data_from_widget(widget, 1)[:, 1] + + assert np.array_equal(np.sort(locs_change["y"]), np.sort(ax_1_y_data)) + + def test_times_all(self, peak_info): + """ + Check that when the times of peaks to plot is restricted, + only peaks within the given time range are plot. Set the + limits just before and after the second peak, and check only + that peak is plot. + """ + recording, peaks, peak_locations = peak_info + + peak_idx = 1 + peak_cutoff_low = peaks["sample_index"][peak_idx] - 1 + peak_cutoff_high = peaks["sample_index"][peak_idx] + 1 + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + time_range=( + peak_cutoff_low / recording.get_sampling_frequency(), + peak_cutoff_high / recording.get_sampling_frequency(), + ), + ) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + + assert np.array_equal([peak_locations[peak_idx]["y"]], ax_y_data) + + def test_times_per_segment(self, peak_info): + """ + Test that the time bounds for multi-segment recordings + with different times are handled properly. The time bounds + given must respect the times for each segment. Here, we build + two segments with times 0-100s and 100-200s. We set the + time limits for peaks to plot as 50-150 i.e. all peaks + from the second half of the first segment, and the first half + of the second segment, should be plotted. + + Recompute peaks here for completeness even though this does + duplicate the fixture. + """ + recording, _, _ = peak_info + + first_seg_times = np.linspace(0, 100, recording.get_num_samples(0)) + second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) + + recording.set_times(first_seg_times, segment_index=0) + recording.set_times(second_seg_times, segment_index=1) + + # After setting the peak times above, re-detect peaks and plot + # with a time range 50-150 s + peaks = detect_peaks(recording) + + peak_locations = localize_peaks( + recording, + peaks, + ms_before=0.3, + ms_after=0.6, + method="center_of_mass", + ) + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + time_range=( + 50, + 150, + ), + ) + + # Find the peaks that are expected to be plot given the time + # restriction (second half of first segment, first half of + # second segment) and check that indeed the expected locations + # are displayed. + seg_one_num_samples = recording.get_num_samples(0) + seg_two_num_samples = recording.get_num_samples(1) + + okay_peaks_one = np.logical_and( + peaks["segment_index"] == 0, peaks["sample_index"] > int(seg_one_num_samples / 2) + ) + okay_peaks_two = np.logical_and( + peaks["segment_index"] == 1, peaks["sample_index"] < int(seg_two_num_samples / 2) + ) + okay_peaks = np.logical_or(okay_peaks_one, okay_peaks_two) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + + assert any(okay_peaks), "someting went wrong in test generation, no peaks within the set time bounds detected" + + assert np.array_equal(np.sort(ax_y_data), np.sort(peak_locations[okay_peaks]["y"])) + + def test_get_min_and_max_times_in_recording(self, peak_info): + """ + Check that the function which finds the minimum and maximum times + across all segments in the recording returns correctly. First + set times of the segments such that the earliest time is 50s and + latest 200s. Check the function returns (50, 200). + """ + recording, peaks, peak_locations = peak_info + + first_seg_times = np.linspace(50, 100, recording.get_num_samples(0)) + second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) + + recording.set_times(first_seg_times, segment_index=0) + recording.set_times(second_seg_times, segment_index=1) + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + ) + + min_max_times = widget._get_min_and_max_times_in_recording(recording) + + assert min_max_times == (50, 200) + + def test_ylim(self, peak_info): + """ + Specify some y-axis limits (which is the probe height + to show) and check that the plot is restricted to + these limits. + """ + recording, peaks, peak_locations = peak_info + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=1, + ylim=(300, 600), + ) + + assert widget.axes[0].get_ylim() == (300, 600) + + def test_decimate(self, peak_info): + """ + By default, only a subset of peaks are shown for + performance reasons. In tests, decimate is set to 1 + to ensure all peaks are plot. This tests now + checks the decimate argument, to ensure peaks that are + plot are correctly decimated. + """ + recording, peaks, peak_locations = peak_info + + decimate = 5 + + widget = plot_peaks_on_probe( + recording, + peaks, + peak_locations, + decimate=decimate, + ) + + ax_y_data = self.data_from_widget(widget, 0)[:, 1] + ax_y_pos = peak_locations["y"][::decimate] + + assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) + + def test_errors(self, peak_info): + """ + Test all validation errors are raised when data in + incorrect form is passed to the plotting function. + """ + recording, peaks, peak_locations = peak_info + + # All lists must be same length + with pytest.raises(ValueError) as e: + plot_peaks_on_probe( + recording, + [peaks, peaks], + [peak_locations], + ) + + # peaks and corresponding peak locations must be same size + with pytest.raises(ValueError) as e: + plot_peaks_on_probe( + recording, + [peaks[:-1]], + [peak_locations], + ) + + # if one is list, both must be lists + with pytest.raises(ValueError) as e: + plot_peaks_on_probe( + recording, + peaks, + [peak_locations], + ) + + # must have some peaks within the given time / segment + with pytest.raises(ValueError) as e: + plot_peaks_on_probe(recording, [peaks[:-1]], [peak_locations], time_range=(0, 0.001)) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index d6df59b0f3..6367e098ea 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -13,6 +13,7 @@ from .motion import MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget +from .peaks_on_probe import PeaksOnProbeWidget from .potential_merges import PotentialMergesWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -50,6 +51,7 @@ MultiCompAgreementBySorterWidget, MultiCompGraphWidget, PeakActivityMapWidget, + PeaksOnProbeWidget, PotentialMergesWidget, ProbeMapWidget, QualityMetricsWidget, @@ -123,6 +125,7 @@ plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget +plot_peaks_on_probe = PeaksOnProbeWidget plot_potential_merges = PotentialMergesWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget From 7ab068b36bcc55d1efd6966051f54b152c4321e2 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 19 Jun 2024 20:00:17 +0100 Subject: [PATCH 2/2] Remove tests. --- .../widgets/tests/test_peaks_on_probe.py | 304 ------------------ 1 file changed, 304 deletions(-) delete mode 100644 src/spikeinterface/widgets/tests/test_peaks_on_probe.py diff --git a/src/spikeinterface/widgets/tests/test_peaks_on_probe.py b/src/spikeinterface/widgets/tests/test_peaks_on_probe.py deleted file mode 100644 index 9820ee5e72..0000000000 --- a/src/spikeinterface/widgets/tests/test_peaks_on_probe.py +++ /dev/null @@ -1,304 +0,0 @@ -import pytest -from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.widgets import plot_peaks_on_probe -from spikeinterface import generate_ground_truth_recording # TODO: think about imports -import numpy as np - - -class TestPeaksOnProbe: - - @pytest.fixture(scope="session") - def peak_info(self): - """ - Fixture (created only once per test run) of a small - ground truth recording with peaks and peak locations calculated. - """ - recording, _ = generate_ground_truth_recording(num_units=5, num_channels=16, durations=[20, 9], seed=0) - peaks = detect_peaks(recording) - - peak_locations = localize_peaks( - recording, - peaks, - ms_before=0.3, - ms_after=0.6, - method="center_of_mass", - ) - - return (recording, peaks, peak_locations) - - def data_from_widget(self, widget, axes_idx): - """ - Convenience function to get the data of the peaks - that are on the plot (not sure why they are in the - second 'collections'). - """ - return widget.axes[axes_idx].collections[2].get_offsets().data - - def test_peaks_on_probe_main(self, peak_info): - """ - Plot all peaks, and check every peak is plot. - Check the labels are corect. - """ - recording, peaks, peak_locations = peak_info - - widget = plot_peaks_on_probe(recording, peaks, peak_locations, decimate=1) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - ax_y_pos = peak_locations["y"] - - assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) - assert widget.axes[0].get_ylabel() == "y ($\\mu m$)" - assert widget.axes[0].get_xlabel() == "x ($\\mu m$)" - - @pytest.mark.parametrize("segment_index", [0, 1]) - def test_segment_selection(self, peak_info, segment_index): - """ - Check that that when specifying only to plot peaks - from a sepecific segment, that only peaks - from that segment are plot. - """ - recording, peaks, peak_locations = peak_info - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - segment_index=segment_index, - ) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - ax_y_pos = peak_locations["y"][peaks["segment_index"] == segment_index] - - assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) - - def test_multiple_inputs(self, peak_info): - """ - Check that multiple inputs are correctly plot - on separate axes. Do this my creating a copy - of the peaks / peak locations with less peaks - and different locations, for good measure. - Check that these separate peaks / peak locations - are plot on different axes. - """ - recording, peaks, peak_locations = peak_info - - half_num_peaks = int(peaks.shape[0] / 2) - - peaks_change = peaks.copy()[:half_num_peaks] - locs_change = peak_locations.copy()[:half_num_peaks] - locs_change["y"] += 1 - - widget = plot_peaks_on_probe( - recording, - [peaks, peaks_change], - [peak_locations, locs_change], - decimate=1, - ) - - # Test the first entry, axis 0 - ax_0_y_data = self.data_from_widget(widget, 0)[:, 1] - - assert np.array_equal(np.sort(peak_locations["y"]), np.sort(ax_0_y_data)) - - # Test the second entry, axis 1. - ax_1_y_data = self.data_from_widget(widget, 1)[:, 1] - - assert np.array_equal(np.sort(locs_change["y"]), np.sort(ax_1_y_data)) - - def test_times_all(self, peak_info): - """ - Check that when the times of peaks to plot is restricted, - only peaks within the given time range are plot. Set the - limits just before and after the second peak, and check only - that peak is plot. - """ - recording, peaks, peak_locations = peak_info - - peak_idx = 1 - peak_cutoff_low = peaks["sample_index"][peak_idx] - 1 - peak_cutoff_high = peaks["sample_index"][peak_idx] + 1 - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - time_range=( - peak_cutoff_low / recording.get_sampling_frequency(), - peak_cutoff_high / recording.get_sampling_frequency(), - ), - ) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - - assert np.array_equal([peak_locations[peak_idx]["y"]], ax_y_data) - - def test_times_per_segment(self, peak_info): - """ - Test that the time bounds for multi-segment recordings - with different times are handled properly. The time bounds - given must respect the times for each segment. Here, we build - two segments with times 0-100s and 100-200s. We set the - time limits for peaks to plot as 50-150 i.e. all peaks - from the second half of the first segment, and the first half - of the second segment, should be plotted. - - Recompute peaks here for completeness even though this does - duplicate the fixture. - """ - recording, _, _ = peak_info - - first_seg_times = np.linspace(0, 100, recording.get_num_samples(0)) - second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) - - recording.set_times(first_seg_times, segment_index=0) - recording.set_times(second_seg_times, segment_index=1) - - # After setting the peak times above, re-detect peaks and plot - # with a time range 50-150 s - peaks = detect_peaks(recording) - - peak_locations = localize_peaks( - recording, - peaks, - ms_before=0.3, - ms_after=0.6, - method="center_of_mass", - ) - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - time_range=( - 50, - 150, - ), - ) - - # Find the peaks that are expected to be plot given the time - # restriction (second half of first segment, first half of - # second segment) and check that indeed the expected locations - # are displayed. - seg_one_num_samples = recording.get_num_samples(0) - seg_two_num_samples = recording.get_num_samples(1) - - okay_peaks_one = np.logical_and( - peaks["segment_index"] == 0, peaks["sample_index"] > int(seg_one_num_samples / 2) - ) - okay_peaks_two = np.logical_and( - peaks["segment_index"] == 1, peaks["sample_index"] < int(seg_two_num_samples / 2) - ) - okay_peaks = np.logical_or(okay_peaks_one, okay_peaks_two) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - - assert any(okay_peaks), "someting went wrong in test generation, no peaks within the set time bounds detected" - - assert np.array_equal(np.sort(ax_y_data), np.sort(peak_locations[okay_peaks]["y"])) - - def test_get_min_and_max_times_in_recording(self, peak_info): - """ - Check that the function which finds the minimum and maximum times - across all segments in the recording returns correctly. First - set times of the segments such that the earliest time is 50s and - latest 200s. Check the function returns (50, 200). - """ - recording, peaks, peak_locations = peak_info - - first_seg_times = np.linspace(50, 100, recording.get_num_samples(0)) - second_seg_times = np.linspace(100, 200, recording.get_num_samples(1)) - - recording.set_times(first_seg_times, segment_index=0) - recording.set_times(second_seg_times, segment_index=1) - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - ) - - min_max_times = widget._get_min_and_max_times_in_recording(recording) - - assert min_max_times == (50, 200) - - def test_ylim(self, peak_info): - """ - Specify some y-axis limits (which is the probe height - to show) and check that the plot is restricted to - these limits. - """ - recording, peaks, peak_locations = peak_info - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=1, - ylim=(300, 600), - ) - - assert widget.axes[0].get_ylim() == (300, 600) - - def test_decimate(self, peak_info): - """ - By default, only a subset of peaks are shown for - performance reasons. In tests, decimate is set to 1 - to ensure all peaks are plot. This tests now - checks the decimate argument, to ensure peaks that are - plot are correctly decimated. - """ - recording, peaks, peak_locations = peak_info - - decimate = 5 - - widget = plot_peaks_on_probe( - recording, - peaks, - peak_locations, - decimate=decimate, - ) - - ax_y_data = self.data_from_widget(widget, 0)[:, 1] - ax_y_pos = peak_locations["y"][::decimate] - - assert np.array_equal(np.sort(ax_y_data), np.sort(ax_y_pos)) - - def test_errors(self, peak_info): - """ - Test all validation errors are raised when data in - incorrect form is passed to the plotting function. - """ - recording, peaks, peak_locations = peak_info - - # All lists must be same length - with pytest.raises(ValueError) as e: - plot_peaks_on_probe( - recording, - [peaks, peaks], - [peak_locations], - ) - - # peaks and corresponding peak locations must be same size - with pytest.raises(ValueError) as e: - plot_peaks_on_probe( - recording, - [peaks[:-1]], - [peak_locations], - ) - - # if one is list, both must be lists - with pytest.raises(ValueError) as e: - plot_peaks_on_probe( - recording, - peaks, - [peak_locations], - ) - - # must have some peaks within the given time / segment - with pytest.raises(ValueError) as e: - plot_peaks_on_probe(recording, [peaks[:-1]], [peak_locations], time_range=(0, 0.001))