39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
class MountainSortAnalyzer:
    """Static helpers that preprocess a recording and run MountainSort5 on each channel separately.

    The class holds no instance state; every method is a ``@staticmethod`` and
    ``sort_recording`` is the public entry point.
    """

    # Execution strategies accepted by ``sort_recording``.
    _MULTIPROCESS_MODES = ("dask", "serial")

    @staticmethod
    def sort_recording(
        recording: "si.BaseRecording",
        plot_probe=False,
        multiprocess_mode: Literal["dask", "serial"] = "serial",
    ) -> tuple[list["si.BaseSorting"], list["si.BaseRecording"]]:
        """Sort a recording using MountainSort, one channel at a time.

        The recording is cloned, given a dummy linear probe, preprocessed twice
        (once with ``constants.SORTING_PARAMS``, once with
        ``constants.WAVEFORM_PARAMS``) and each result is split into
        single-channel recordings. Every sorting-preprocessed channel is cached
        to a temporary folder and sorted independently.

        Args:
            recording (si.BaseRecording): The recording to sort.
            plot_probe (bool, optional): Whether to plot the probe. Defaults to False.
            multiprocess_mode (Literal["dask", "serial"], optional): Whether to use dask or serial
                for multiprocessing. Defaults to "serial".

        Returns:
            tuple[list[si.BaseSorting], list[si.BaseRecording]]: ``(sortings, wave_recs)``.
                ``sortings`` has one entry per channel; in ``"dask"`` mode the entries are
                ``dask.delayed`` objects rather than finished sortings. ``wave_recs`` holds the
                waveform-preprocessed single-channel recordings in the same channel order.

        Raises:
            ValueError: If ``multiprocess_mode`` is not ``"dask"`` or ``"serial"``.
            ImportError: If mountainsort5 or spikeinterface is unavailable, or if
                ``multiprocess_mode="dask"`` is requested without dask.
        """
        # Validate the mode up front so a typo fails before any preprocessing is done
        # (previously it surfaced as an UnboundLocalError at the final return).
        if multiprocess_mode not in MountainSortAnalyzer._MULTIPROCESS_MODES:
            raise ValueError(
                f"Invalid multiprocess_mode {multiprocess_mode!r}; "
                f"expected one of {MountainSortAnalyzer._MULTIPROCESS_MODES}"
            )

        if not MOUNTAINSORT_AVAILABLE:
            raise ImportError(
                "MountainSort5 is not available. Spike sorting functionality requires mountainsort5. "
                "Install it with: pip install mountainsort5"
            )

        logging.debug("Sorting recording info: %s", recording)
        logging.debug("Sorting recording channel names: %s", recording.get_channel_ids())

        if si is None or spre is None:
            raise ImportError("spikeinterface is required for sorting")

        # Work on a clone so the caller's recording is not modified.
        rec = recording.clone()
        probe = MountainSortAnalyzer._get_dummy_probe(rec)
        rec = rec.set_probe(probe)

        if plot_probe:
            _, ax2 = plt.subplots(1, 1)
            pi_plotting.plot_probe(probe, ax=ax2, with_device_index=True, with_contact_id=True)
            plt.show()

        # Get recordings for sorting and waveforms
        sort_rec = MountainSortAnalyzer._get_recording_for_sorting(rec)
        wave_rec = MountainSortAnalyzer._get_recording_for_waveforms(rec)

        # Split recording into separate channels
        sort_recs = MountainSortAnalyzer._split_recording(sort_rec)
        wave_recs = MountainSortAnalyzer._split_recording(wave_rec)

        # Run sorting
        if multiprocess_mode == "dask":
            if dask is None:
                raise ImportError("dask is required for multiprocess_mode='dask'")
            # Only the task graph is built here; nothing is computed in this function.
            cached_recs = [dask.delayed(MountainSortAnalyzer._cache_recording)(channel_rec) for channel_rec in sort_recs]
            sortings = [dask.delayed(MountainSortAnalyzer._run_sorting)(cached_rec) for cached_rec in cached_recs]
        else:  # "serial"
            cached_recs = [MountainSortAnalyzer._cache_recording(channel_rec) for channel_rec in sort_recs]
            sortings = [MountainSortAnalyzer._run_sorting(cached_rec) for cached_rec in cached_recs]

        return sortings, wave_recs

    @staticmethod
    def _get_dummy_probe(recording: "si.BaseRecording") -> "pi.Probe":
        """Build a linear probe with one contact per channel of ``recording``.

        Device channel indices are ``0..n-1`` and the contact ids are the
        recording's channel ids.
        """
        n_channels = recording.get_num_channels()
        linprobe = pi.generate_linear_probe(n_channels, ypitch=40)
        linprobe.set_device_channel_indices(np.arange(n_channels))
        linprobe.set_contact_ids(recording.get_channel_ids())
        return linprobe

    @staticmethod
    def _get_recording_for_sorting(recording: "si.BaseRecording") -> "si.BaseRecording":
        """Return ``recording`` preprocessed with ``constants.SORTING_PARAMS``."""
        return MountainSortAnalyzer._apply_preprocessing(recording, constants.SORTING_PARAMS)

    @staticmethod
    def _get_recording_for_waveforms(recording: "si.BaseRecording") -> "si.BaseRecording":
        """Return ``recording`` preprocessed with ``constants.WAVEFORM_PARAMS``."""
        return MountainSortAnalyzer._apply_preprocessing(recording, constants.WAVEFORM_PARAMS)

    @staticmethod
    def _apply_preprocessing(recording: "si.BaseRecording", params: dict) -> "si.BaseRecording":
        """Apply the preprocessing steps enabled in ``params`` to a clone of ``recording``.

        ``params`` must contain the keys ``notch_freq``, ``common_ref``,
        ``scale``, ``whiten``, ``freq_min`` and ``freq_max``. A falsy value
        disables the corresponding step. Steps run in this order: notch filter,
        common reference, scaling, whitening, then frequency filtering.
        """
        rec = recording.clone()

        if params["notch_freq"]:
            rec = spre.notch_filter(rec, freq=params["notch_freq"], q=100)
        if params["common_ref"]:
            rec = spre.common_reference(rec)
        if params["scale"]:
            rec = spre.scale(rec, gain=params["scale"])  # Scaling for whitening to work properly
        if params["whiten"]:
            rec = spre.whiten(rec)

        freq_min = params["freq_min"]
        freq_max = params["freq_max"]
        if freq_min and freq_max:
            rec = spre.bandpass_filter(rec, freq_min=freq_min, freq_max=freq_max, ftype="bessel")
        elif freq_min:
            rec = spre.highpass_filter(rec, freq_min=freq_min, ftype="bessel")
        elif freq_max:
            # Spike Interface doesn't have a lowpass filter, so emulate one with a
            # band-pass whose lower edge is 0.1.
            rec = spre.bandpass_filter(rec, freq_min=0.1, freq_max=freq_max, ftype="bessel")

        return rec

    @staticmethod
    def _split_recording(recording: "si.BaseRecording") -> list["si.BaseRecording"]:
        """Return one single-channel recording per channel id, in channel-id order."""
        return [recording.clone().select_channels([channel_id]) for channel_id in recording.get_channel_ids()]

    @staticmethod
    def _cache_recording(recording: "si.BaseRecording") -> "si.BaseRecording":
        """Cache ``recording`` in a new randomly named folder and cast it to ``constants.GLOBAL_DTYPE``.

        The folder is created under ``get_temp_directory()``; this method does
        not remove it afterwards.
        """
        # A random 48-hex-character name keeps per-channel cache folders from colliding.
        temp_dir = get_temp_directory() / os.urandom(24).hex()
        os.makedirs(temp_dir)
        cached_rec = create_cached_recording(recording.clone(), folder=temp_dir, chunk_duration="60s")
        return spre.astype(cached_rec, dtype=constants.GLOBAL_DTYPE)

    @staticmethod
    def _run_sorting(recording: "si.BaseRecording") -> "si.BaseSorting":
        """Run MountainSort5 scheme 2 on ``recording`` using ``constants.SCHEME2_SORTING_PARAMS``."""
        scheme2_params = constants.SCHEME2_SORTING_PARAMS
        sampling_frequency = recording.get_sampling_frequency()

        # Confusingly, the snippet_T1 and snippet_T2 parameters in MS are in samples, not seconds,
        # so the configured values are multiplied by the sampling frequency and rounded.
        sort_params = Scheme2SortingParameters(
            phase1_detect_channel_radius=scheme2_params["phase1_detect_channel_radius"],
            detect_channel_radius=scheme2_params["detect_channel_radius"],
            snippet_T1=round(sampling_frequency * scheme2_params["snippet_T1"]),
            snippet_T2=round(sampling_frequency * scheme2_params["snippet_T2"]),
        )

        with _HiddenPrints():
            sorting = sorting_scheme2(recording=recording, sorting_parameters=sort_params)
        return sorting
|