From 60715187ebcf0c1665df0f55ca4ae23979cff1d5 Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Thu, 4 Jul 2024 15:36:33 +0200 Subject: [PATCH 1/7] Add UnitSpatialDistributionsWidget class --- src/spikeinterface/widgets/unit_spatial.py | 135 +++++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 14 +-- 2 files changed, 140 insertions(+), 9 deletions(-) create mode 100644 src/spikeinterface/widgets/unit_spatial.py diff --git a/src/spikeinterface/widgets/unit_spatial.py b/src/spikeinterface/widgets/unit_spatial.py new file mode 100644 index 0000000000..936e77339b --- /dev/null +++ b/src/spikeinterface/widgets/unit_spatial.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import numpy as np +from probeinterface import Probe +from probeinterface.plotting import get_auto_lims +from seaborn import color_palette +from warnings import warn +from .base import BaseWidget, to_attr + +class UnitSpatialDistributionsWidget(BaseWidget): + """ + Placeholder documentation to be changed. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer object + depth_axis : int, default: 1 + The dimension of unit_locations that is depth + """ + + def __init__( + self, + sorting_analyzer, probe=None, + depth_axis=1, bins=None, + cmap="viridis", kde=False, + depth_hist=True, groups=None, + backend=None, **backend_kwargs + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + + self.check_extensions(sorting_analyzer, "unit_locations") + ulc = sorting_analyzer.get_extension("unit_locations") + unit_locations = ulc.get_data(outputs="numpy") + x, y = unit_locations[:, 0], unit_locations[:, 1] + + if type(probe) is Probe: + if sorting_analyzer.recording.has_probe(): + # TODO: throw warning saying that sorting_analyzer has a probe and it will be overwritten + pass + elif sorting_analyzer.recording.has_probe(): + probe = sorting_analyzer.get_probe() + else: + # TODO: throw error or warning, no probe available + pass + + xrange, yrange, _ = get_auto_lims(probe, margin=0) + if bins is None: + bins = ( + np.round(np.diff(xrange).squeeze() / 75).astype(int), + np.round(np.diff(yrange).squeeze() / 75).astype(int) + ) + + if type(cmap) is str: + cmap = color_palette(cmap, as_cmap=True) + + plot_data = dict( + probe=probe, + x=x, + y=y, + depth_axis=depth_axis, + xrange=xrange, + yrange=yrange, + bins=bins, + kde=kde, + cmap=cmap, + depth_hist=depth_hist, + groups=groups + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.patches as patches + import matplotlib.path as path + from seaborn import kdeplot, histplot + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + + custom_shape = path.Path(dp.probe.probe_planar_contour) + patch = patches.PathPatch(custom_shape, facecolor="none", edgecolor="none") + ax.add_patch(patch) + + if dp.kde is not True: + hist, xedges, yedges = np.histogram2d(dp.x, dp.y, bins=dp.bins, range=[dp.xrange, dp.yrange]) + pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap=dp.cmap) + else: + data = dict(x=dp.x, y=dp.y) + bg = ax.add_patch( + patches.Rectangle( + [dp.xrange[0], dp.yrange[0]], + np.diff(dp.xrange).squeeze(), + np.diff(dp.yrange).squeeze(), + facecolor=dp.cmap.colors[0], + fill=True + ) + ) + bg.set_clip_path(patch) + kdeplot( + data, x='x', y='y', + cmap=dp.cmap, levels=100, thresh=0, fill=True, + ax=ax, bw_adjust=0.1, clip=[dp.xrange, dp.yrange] + ) + pcm = ax.collections[0] + ax.set_xlabel(None) + ax.set_ylabel(None) + + pcm.set_clip_path(patch) + + xlim, ylim, _ = get_auto_lims(dp.probe, margin=10) + ax.set_xlim(*xlim) + ax.set_ylim(*ylim) + ax.spines['top'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.set_xticks([]) + ax.set_xlabel('') + ax.set_ylabel('Depth (um)') + + if dp.depth_hist is True: + bbox = ax.get_window_extent() + hist_height = 1.5 * bbox.width + + ax_hist = ax.inset_axes([1, 0, hist_height / bbox.width, 1]) + data = dict(y=dp.y) + data['group'] = np.ones(dp.y.size) if dp.groups is None else dp.groups + palette = color_palette('bright', n_colors=1 if dp.groups is None else np.unique(dp.groups).size) + histplot(data=data, y='y', hue='group', bins=dp.bins[1], binrange=dp.yrange, palette=palette, ax=ax_hist, legend=False) + ax_hist.axis('off') + ax_hist.set_ylim(*ylim) \ No newline at end of file diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8163271ec4..8aebe31dd2 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,10 +10,9 @@ from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget -from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget +from .motion import MotionWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget -from .peaks_on_probe import PeaksOnProbeWidget from .potential_merges import PotentialMergesWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -28,6 +27,7 @@ from .unit_locations import UnitLocationsWidget from .unit_presence import UnitPresenceWidget from .unit_probe_map import UnitProbeMapWidget +from .unit_spatial import UnitSpatialDistributionsWidget from .unit_summary import UnitSummaryWidget from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget @@ -44,15 +44,12 @@ ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, - DriftRasterMapWidget, ISIDistributionWidget, MotionWidget, - MotionInfoWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget, MultiCompGraphWidget, PeakActivityMapWidget, - PeaksOnProbeWidget, PotentialMergesWidget, ProbeMapWidget, QualityMetricsWidget, @@ -67,6 +64,7 @@ UnitLocationsWidget, UnitPresenceWidget, UnitProbeMapWidget, + UnitSpatialDistributionsWidget, UnitSummaryWidget, UnitTemplatesWidget, UnitWaveformDensityMapWidget, @@ -119,15 +117,12 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget -plot_drift_raster_map = DriftRasterMapWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget -plot_motion_info = MotionInfoWidget plot_multicomparison_agreement = MultiCompGlobalAgreementWidget plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget -plot_peaks_on_probe = PeaksOnProbeWidget plot_potential_merges = PotentialMergesWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget @@ -142,6 +137,7 @@ plot_unit_locations = UnitLocationsWidget plot_unit_presence = UnitPresenceWidget plot_unit_probe_map = UnitProbeMapWidget +plot_unit_spatial_distribution = UnitSpatialDistributionsWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget @@ -156,4 +152,4 @@ def plot_timeseries(*args, **kwargs): warnings.warn("plot_timeseries() is now plot_traces()") - return plot_traces(*args, **kwargs) + return plot_traces(*args, **kwargs) \ No newline at end of file From cfda64986d62d8253ec0e551dbff3be78650415b Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Thu, 4 Jul 2024 15:54:25 +0200 Subject: [PATCH 2/7] Add KDE keyword arguments --- src/spikeinterface/widgets/unit_spatial.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/widgets/unit_spatial.py b/src/spikeinterface/widgets/unit_spatial.py index 936e77339b..746aaf0a8e 100644 --- a/src/spikeinterface/widgets/unit_spatial.py +++ b/src/spikeinterface/widgets/unit_spatial.py @@ -25,6 +25,7 @@ def __init__( depth_axis=1, bins=None, cmap="viridis", kde=False, depth_hist=True, groups=None, + kde_kws=None, backend=None, **backend_kwargs ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) @@ -50,6 +51,7 @@ def __init__( np.round(np.diff(xrange).squeeze() / 75).astype(int), np.round(np.diff(yrange).squeeze() / 75).astype(int) ) + # TODO: change behaviour, if bins is not defined, bin only along the depth axis if type(cmap) is str: cmap = color_palette(cmap, as_cmap=True) @@ -65,7 +67,8 @@ def __init__( kde=kde, cmap=cmap, depth_hist=depth_hist, - groups=groups + groups=groups, + kde_kws=kde_kws ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -90,6 +93,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): hist, xedges, yedges = np.histogram2d(dp.x, dp.y, bins=dp.bins, range=[dp.xrange, dp.yrange]) pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap=dp.cmap) else: + kde_kws = dict(levels=100, thresh=0, fill=True, bw_adjust=0.1) + if dp.kde_kws is not None: + kde_kws.update(dp.kde_kws) data = dict(x=dp.x, y=dp.y) bg = ax.add_patch( patches.Rectangle( @@ -103,8 +109,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): bg.set_clip_path(patch) kdeplot( data, x='x', y='y', - cmap=dp.cmap, levels=100, thresh=0, fill=True, - ax=ax, bw_adjust=0.1, clip=[dp.xrange, dp.yrange] + clip=[dp.xrange, dp.yrange], + cmap=dp.cmap, ax=ax, + **kde_kws ) pcm = ax.collections[0] ax.set_xlabel(None) From 64f978812c01af3dcdf6c414f159660f1742d1a0 Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Tue, 9 Jul 2024 16:43:29 +0200 Subject: [PATCH 3/7] Revert wrong changes in the widgets list --- src/spikeinterface/widgets/widget_list.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8aebe31dd2..abd9b700e5 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,9 +10,10 @@ from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget -from .motion import MotionWidget +from .motion import DriftRasterMapWidget, MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget +from .peaks_on_probe import PeaksOnProbeWidget from .potential_merges import PotentialMergesWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget @@ -44,12 +45,15 @@ ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, + DriftRasterMapWidget, ISIDistributionWidget, MotionWidget, + MotionInfoWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget, MultiCompGraphWidget, PeakActivityMapWidget, + PeaksOnProbeWidget, PotentialMergesWidget, ProbeMapWidget, QualityMetricsWidget, @@ -117,12 +121,15 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_comparison_collision_by_similarity = ComparisonCollisionBySimilarityWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_drift_raster_map = DriftRasterMapWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_motion_info = MotionInfoWidget plot_multicomparison_agreement = MultiCompGlobalAgreementWidget plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget plot_peak_activity = PeakActivityMapWidget +plot_peaks_on_probe = PeaksOnProbeWidget plot_potential_merges = PotentialMergesWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget From 574955adaf638f89a36f254cd526ecae2253f7be Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Tue, 9 Jul 2024 16:53:23 +0200 Subject: [PATCH 4/7] Run pre-commit locally --- src/spikeinterface/widgets/unit_spatial.py | 66 +++++++++++++--------- src/spikeinterface/widgets/widget_list.py | 2 +- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/widgets/unit_spatial.py b/src/spikeinterface/widgets/unit_spatial.py index 746aaf0a8e..06f253f07f 100644 --- a/src/spikeinterface/widgets/unit_spatial.py +++ b/src/spikeinterface/widgets/unit_spatial.py @@ -7,6 +7,7 @@ from warnings import warn from .base import BaseWidget, to_attr + class UnitSpatialDistributionsWidget(BaseWidget): """ Placeholder documentation to be changed. @@ -18,15 +19,20 @@ class UnitSpatialDistributionsWidget(BaseWidget): depth_axis : int, default: 1 The dimension of unit_locations that is depth """ - + def __init__( - self, - sorting_analyzer, probe=None, - depth_axis=1, bins=None, - cmap="viridis", kde=False, - depth_hist=True, groups=None, - kde_kws=None, - backend=None, **backend_kwargs + self, + sorting_analyzer, + probe=None, + depth_axis=1, + bins=None, + cmap="viridis", + kde=False, + depth_hist=True, + groups=None, + kde_kws=None, + backend=None, + **backend_kwargs, ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) @@ -44,12 +50,12 @@ def __init__( else: # TODO: throw error or warning, no probe available pass - + xrange, yrange, _ = get_auto_lims(probe, margin=0) if bins is None: bins = ( np.round(np.diff(xrange).squeeze() / 75).astype(int), - np.round(np.diff(yrange).squeeze() / 75).astype(int) + np.round(np.diff(yrange).squeeze() / 75).astype(int), ) # TODO: change behaviour, if bins is not defined, bin only along the depth axis @@ -68,7 +74,7 @@ def __init__( cmap=cmap, depth_hist=depth_hist, groups=groups, - kde_kws=kde_kws + kde_kws=kde_kws, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -103,16 +109,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): np.diff(dp.xrange).squeeze(), np.diff(dp.yrange).squeeze(), facecolor=dp.cmap.colors[0], - fill=True + fill=True, ) ) bg.set_clip_path(patch) - kdeplot( - data, x='x', y='y', - clip=[dp.xrange, dp.yrange], - cmap=dp.cmap, ax=ax, - **kde_kws - ) + kdeplot(data, x="x", y="y", clip=[dp.xrange, dp.yrange], cmap=dp.cmap, ax=ax, **kde_kws) pcm = ax.collections[0] ax.set_xlabel(None) ax.set_ylabel(None) @@ -122,12 +123,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): xlim, ylim, _ = get_auto_lims(dp.probe, margin=10) ax.set_xlim(*xlim) ax.set_ylim(*ylim) - ax.spines['top'].set_visible(False) - ax.spines['bottom'].set_visible(False) - ax.spines['right'].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["right"].set_visible(False) ax.set_xticks([]) - ax.set_xlabel('') - ax.set_ylabel('Depth (um)') + ax.set_xlabel("") + ax.set_ylabel("Depth (um)") if dp.depth_hist is True: bbox = ax.get_window_extent() @@ -135,8 +136,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax_hist = ax.inset_axes([1, 0, hist_height / bbox.width, 1]) data = dict(y=dp.y) - data['group'] = np.ones(dp.y.size) if dp.groups is None else dp.groups - palette = color_palette('bright', n_colors=1 if dp.groups is None else np.unique(dp.groups).size) - histplot(data=data, y='y', hue='group', bins=dp.bins[1], binrange=dp.yrange, palette=palette, ax=ax_hist, legend=False) - ax_hist.axis('off') - ax_hist.set_ylim(*ylim) \ No newline at end of file + data["group"] = np.ones(dp.y.size) if dp.groups is None else dp.groups + palette = color_palette("bright", n_colors=1 if dp.groups is None else np.unique(dp.groups).size) + histplot( + data=data, + y="y", + hue="group", + bins=dp.bins[1], + binrange=dp.yrange, + palette=palette, + ax=ax_hist, + legend=False, + ) + ax_hist.axis("off") + ax_hist.set_ylim(*ylim) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index abd9b700e5..ca4159cabb 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -159,4 +159,4 @@ def plot_timeseries(*args, **kwargs): warnings.warn("plot_timeseries() is now plot_traces()") - return plot_traces(*args, **kwargs) \ No newline at end of file + return plot_traces(*args, **kwargs) From 9e5330a42fd95b93ab799227e576328d6655bceb Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Wed, 10 Jul 2024 10:23:37 +0200 Subject: [PATCH 5/7] Add seaborn to widgets dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 2ba53328e7..4b2fe2232c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ widgets = [ "matplotlib", "ipympl", "ipywidgets", + "seaborn>=0.13.0", "sortingview>=0.12.0", ] From ee76d37450f9fef902d469cbb65cc3de749b561f Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Wed, 10 Jul 2024 10:25:51 +0200 Subject: [PATCH 6/7] Move plotting dependencies inside plot_matplotlib function --- src/spikeinterface/widgets/unit_spatial.py | 42 ++++++++++------------ 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/widgets/unit_spatial.py b/src/spikeinterface/widgets/unit_spatial.py index 06f253f07f..09f4f1afba 100644 --- a/src/spikeinterface/widgets/unit_spatial.py +++ b/src/spikeinterface/widgets/unit_spatial.py @@ -2,8 +2,6 @@ import numpy as np from probeinterface import Probe -from probeinterface.plotting import get_auto_lims -from seaborn import color_palette from warnings import warn from .base import BaseWidget, to_attr @@ -51,24 +49,19 @@ def __init__( # TODO: throw error or warning, no probe available pass - xrange, yrange, _ = get_auto_lims(probe, margin=0) - if bins is None: - bins = ( - np.round(np.diff(xrange).squeeze() / 75).astype(int), - np.round(np.diff(yrange).squeeze() / 75).astype(int), - ) - # TODO: change behaviour, if bins is not defined, bin only along the depth axis - - if type(cmap) is str: - cmap = color_palette(cmap, as_cmap=True) + # xrange, yrange, _ = get_auto_lims(probe, margin=0) + # if bins is None: + # bins = ( + # np.round(np.diff(xrange).squeeze() / 75).astype(int), + # np.round(np.diff(yrange).squeeze() / 75).astype(int), + # ) + # # TODO: change behaviour, if bins is not defined, bin only along the depth axis plot_data = dict( probe=probe, x=x, y=y, depth_axis=depth_axis, - xrange=xrange, - yrange=yrange, bins=bins, kde=kde, cmap=cmap, @@ -82,10 +75,13 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.patches as patches import matplotlib.path as path - from seaborn import kdeplot, histplot + from probeinterface.plotting import get_auto_lims + from seaborn import color_palette, kdeplot, histplot from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) + xrange, yrange, _ = get_auto_lims(dp.probe, margin=0) + cmap = color_palette(dp.cmap, as_cmap=True) if type(dp.cmap) is str else dp.cmap self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) @@ -96,8 +92,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.add_patch(patch) if dp.kde is not True: - hist, xedges, yedges = np.histogram2d(dp.x, dp.y, bins=dp.bins, range=[dp.xrange, dp.yrange]) - pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap=dp.cmap) + hist, xedges, yedges = np.histogram2d(dp.x, dp.y, bins=dp.bins, range=[xrange, yrange]) + pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap=cmap) else: kde_kws = dict(levels=100, thresh=0, fill=True, bw_adjust=0.1) if dp.kde_kws is not None: @@ -105,15 +101,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): data = dict(x=dp.x, y=dp.y) bg = ax.add_patch( patches.Rectangle( - [dp.xrange[0], dp.yrange[0]], - np.diff(dp.xrange).squeeze(), - np.diff(dp.yrange).squeeze(), - facecolor=dp.cmap.colors[0], + [xrange[0], yrange[0]], + np.diff(xrange).squeeze(), + np.diff(yrange).squeeze(), + facecolor=cmap.colors[0], fill=True, ) ) bg.set_clip_path(patch) - kdeplot(data, x="x", y="y", clip=[dp.xrange, dp.yrange], cmap=dp.cmap, ax=ax, **kde_kws) + kdeplot(data, x="x", y="y", clip=[xrange, yrange], cmap=cmap, ax=ax, **kde_kws) pcm = ax.collections[0] ax.set_xlabel(None) ax.set_ylabel(None) @@ -143,7 +139,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): y="y", hue="group", bins=dp.bins[1], - binrange=dp.yrange, + binrange=yrange, palette=palette, ax=ax_hist, legend=False, From e679568e68436ac9c6eddb08b0273c71bee83b02 Mon Sep 17 00:00:00 2001 From: FrancescoNegri Date: Wed, 10 Jul 2024 10:26:54 +0200 Subject: [PATCH 7/7] Add warning and error messages on Probe object --- src/spikeinterface/widgets/unit_spatial.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/unit_spatial.py b/src/spikeinterface/widgets/unit_spatial.py index 09f4f1afba..a7c7e4f1c1 100644 --- a/src/spikeinterface/widgets/unit_spatial.py +++ b/src/spikeinterface/widgets/unit_spatial.py @@ -41,13 +41,15 @@ def __init__( if type(probe) is Probe: if sorting_analyzer.recording.has_probe(): - # TODO: throw warning saying that sorting_analyzer has a probe and it will be overwritten - pass + warn( + "There is a Probe attached to this recording, but the probe argument is not None: the attached Probe will be ignored." + ) elif sorting_analyzer.recording.has_probe(): probe = sorting_analyzer.get_probe() else: - # TODO: throw error or warning, no probe available - pass + raise ValueError( + "There is no Probe attached to this recording. Use set_probe(...) to attach one or pass it to the function via the probe argument." + ) # xrange, yrange, _ = get_auto_lims(probe, margin=0) # if bins is None: