Skip to content

Spike Sorting

MountainSortAnalyzer

Source code in pythoneeg/core/analyze_sort.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class MountainSortAnalyzer:
    @staticmethod
    def sort_recording(
        recording: "si.BaseRecording", plot_probe=False, multiprocess_mode: Literal["dask", "serial"] = "serial"
    ) -> tuple[list["si.BaseSorting"], list["si.BaseRecording"]]:
        """Sort a recording using MountainSort.

        Args:
            recording (si.BaseRecording): The recording to sort.
            plot_probe (bool, optional): Whether to plot the probe. Defaults to False.
            multiprocess_mode (Literal["dask", "serial"], optional): Whether to use dask or serial for multiprocessing. Defaults to "serial".

        Returns:
            list[si.SortingAnalyzer]: A list of independent sorting analyzers, one for each channel.
        """
        if not MOUNTAINSORT_AVAILABLE:
            raise ImportError(
                "MountainSort5 is not available. Spike sorting functionality requires mountainsort5. "
                "Install it with: pip install mountainsort5"
            )

        logging.debug(f"Sorting recording info: {recording}")
        logging.debug(f"Sorting recording channel names: {recording.get_channel_ids()}")

        if si is None or spre is None:
            raise ImportError("spikeinterface is required for sorting")
        rec = recording.clone()
        probe = MountainSortAnalyzer._get_dummy_probe(rec)
        rec = rec.set_probe(probe)

        if plot_probe:
            _, ax2 = plt.subplots(1, 1)
            pi_plotting.plot_probe(probe, ax=ax2, with_device_index=True, with_contact_id=True)
            plt.show()

        # Get recordings for sorting and waveforms
        sort_rec = MountainSortAnalyzer._get_recording_for_sorting(rec)
        wave_rec = MountainSortAnalyzer._get_recording_for_waveforms(rec)

        # Split recording into separate channels
        sort_recs = MountainSortAnalyzer._split_recording(sort_rec)
        wave_recs = MountainSortAnalyzer._split_recording(wave_rec)

        # Run sorting
        match multiprocess_mode:
            case "dask":
                if dask is None:
                    raise ImportError("dask is required for multiprocess_mode='dask'")
                cached_recs = [dask.delayed(MountainSortAnalyzer._cache_recording)(sort_rec) for sort_rec in sort_recs]
                sortings = [dask.delayed(MountainSortAnalyzer._run_sorting)(cached_rec) for cached_rec in cached_recs]
            case "serial":
                cached_recs = [MountainSortAnalyzer._cache_recording(sort_rec) for sort_rec in sort_recs]
                sortings = [MountainSortAnalyzer._run_sorting(cached_rec) for cached_rec in cached_recs]

        return sortings, wave_recs

    @staticmethod
    def _get_dummy_probe(recording: si.BaseRecording) -> pi.Probe:
        linprobe = pi.generate_linear_probe(recording.get_num_channels(), ypitch=40)
        linprobe.set_device_channel_indices(np.arange(recording.get_num_channels()))
        linprobe.set_contact_ids(recording.get_channel_ids())
        return linprobe

    @staticmethod
    def _get_recording_for_sorting(recording: si.BaseRecording) -> si.BaseRecording:
        return MountainSortAnalyzer._apply_preprocessing(recording, constants.SORTING_PARAMS)

    @staticmethod
    def _get_recording_for_waveforms(recording: si.BaseRecording) -> si.BaseRecording:
        return MountainSortAnalyzer._apply_preprocessing(recording, constants.WAVEFORM_PARAMS)

    @staticmethod
    def _apply_preprocessing(recording: si.BaseRecording, params: dict) -> si.BaseRecording:
        rec = recording.clone()

        if params["notch_freq"]:
            rec = spre.notch_filter(rec, freq=params["notch_freq"], q=100)
        if params["common_ref"]:
            rec = spre.common_reference(rec)
        if params["scale"]:
            rec = spre.scale(rec, gain=params["scale"])  # Scaling for whitening to work properly
        if params["whiten"]:
            rec = spre.whiten(rec)

        if params["freq_min"] and not params["freq_max"]:
            rec = spre.highpass_filter(rec, freq_min=params["freq_min"], ftype="bessel")
        elif params["freq_min"] and params["freq_max"]:
            rec = spre.bandpass_filter(rec, freq_min=params["freq_min"], freq_max=params["freq_max"], ftype="bessel")
        elif not params["freq_min"] and params["freq_max"]:
            rec = spre.bandpass_filter(
                rec, freq_min=0.1, freq_max=params["freq_max"], ftype="bessel"
            )  # Spike Interface doesn't have a lowpass filter

        return rec

    @staticmethod
    def _split_recording(recording: si.BaseRecording) -> list[si.BaseRecording]:
        rec_preps = []
        for channel_id in recording.get_channel_ids():
            rec_preps.append(recording.clone().select_channels([channel_id]))
        return rec_preps

    @staticmethod
    def _cache_recording(recording: si.BaseRecording) -> si.BaseRecording:
        temp_dir = get_temp_directory() / os.urandom(24).hex()
        # dask.distributed.print(f"Caching recording to {temp_dir}")
        os.makedirs(temp_dir)
        cached_rec = create_cached_recording(recording.clone(), folder=temp_dir, chunk_duration="60s")
        cached_rec = spre.astype(cached_rec, dtype=constants.GLOBAL_DTYPE)
        return cached_rec

    @staticmethod
    def _run_sorting(recording: si.BaseRecording) -> si.BaseSorting:
        # Confusingly, the snippet_T1 and snippet_T2 parameters in MS are in samples, not seconds
        snippet_T1 = constants.SCHEME2_SORTING_PARAMS["snippet_T1"]
        snippet_T2 = constants.SCHEME2_SORTING_PARAMS["snippet_T2"]
        snippet_T1_samples = round(recording.get_sampling_frequency() * snippet_T1)
        snippet_T2_samples = round(recording.get_sampling_frequency() * snippet_T2)

        sort_params = Scheme2SortingParameters(
            phase1_detect_channel_radius=constants.SCHEME2_SORTING_PARAMS["phase1_detect_channel_radius"],
            detect_channel_radius=constants.SCHEME2_SORTING_PARAMS["detect_channel_radius"],
            snippet_T1=snippet_T1_samples,
            snippet_T2=snippet_T2_samples,
        )

        # dask.distributed.print(f"recording.dtype: {recording.dtype}")
        with _HiddenPrints():
            sorting = sorting_scheme2(recording=recording, sorting_parameters=sort_params)

        return sorting

sort_recording(recording, plot_probe=False, multiprocess_mode='serial') staticmethod

Sort a recording using MountainSort.

Parameters:

Name Type Description Default
recording BaseRecording

The recording to sort.

required
plot_probe bool

Whether to plot the probe. Defaults to False.

False
multiprocess_mode Literal['dask', 'serial']

Whether to use dask or serial for multiprocessing. Defaults to "serial".

'serial'

Returns:

Type Description
tuple[list[BaseSorting], list[BaseRecording]]

tuple[list[si.BaseSorting], list[si.BaseRecording]]: One sorting per channel and the matching list of per-channel recordings preprocessed for waveform extraction.

Source code in pythoneeg/core/analyze_sort.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@staticmethod
def sort_recording(
    recording: "si.BaseRecording", plot_probe=False, multiprocess_mode: Literal["dask", "serial"] = "serial"
) -> tuple[list["si.BaseSorting"], list["si.BaseRecording"]]:
    """Sort a recording using MountainSort.

    Args:
        recording (si.BaseRecording): The recording to sort.
        plot_probe (bool, optional): Whether to plot the probe. Defaults to False.
        multiprocess_mode (Literal["dask", "serial"], optional): Whether to use dask or serial for multiprocessing. Defaults to "serial".

    Returns:
        list[si.SortingAnalyzer]: A list of independent sorting analyzers, one for each channel.
    """
    if not MOUNTAINSORT_AVAILABLE:
        raise ImportError(
            "MountainSort5 is not available. Spike sorting functionality requires mountainsort5. "
            "Install it with: pip install mountainsort5"
        )

    logging.debug(f"Sorting recording info: {recording}")
    logging.debug(f"Sorting recording channel names: {recording.get_channel_ids()}")

    if si is None or spre is None:
        raise ImportError("spikeinterface is required for sorting")
    rec = recording.clone()
    probe = MountainSortAnalyzer._get_dummy_probe(rec)
    rec = rec.set_probe(probe)

    if plot_probe:
        _, ax2 = plt.subplots(1, 1)
        pi_plotting.plot_probe(probe, ax=ax2, with_device_index=True, with_contact_id=True)
        plt.show()

    # Get recordings for sorting and waveforms
    sort_rec = MountainSortAnalyzer._get_recording_for_sorting(rec)
    wave_rec = MountainSortAnalyzer._get_recording_for_waveforms(rec)

    # Split recording into separate channels
    sort_recs = MountainSortAnalyzer._split_recording(sort_rec)
    wave_recs = MountainSortAnalyzer._split_recording(wave_rec)

    # Run sorting
    match multiprocess_mode:
        case "dask":
            if dask is None:
                raise ImportError("dask is required for multiprocess_mode='dask'")
            cached_recs = [dask.delayed(MountainSortAnalyzer._cache_recording)(sort_rec) for sort_rec in sort_recs]
            sortings = [dask.delayed(MountainSortAnalyzer._run_sorting)(cached_rec) for cached_rec in cached_recs]
        case "serial":
            cached_recs = [MountainSortAnalyzer._cache_recording(sort_rec) for sort_rec in sort_recs]
            sortings = [MountainSortAnalyzer._run_sorting(cached_rec) for cached_rec in cached_recs]

    return sortings, wave_recs