from enum import Enum
# [docs]  <- Sphinx "viewcode" link marker left over from HTML extraction; not part of the module.
class FeatureType(Enum):
    """Closed set of data layouts a feature can have.

    Every feature in the system is assigned exactly one :class:`FeatureType`.
    The type records how the feature's data is stored and therefore how
    shared operations (averaging, filtering, channel remapping, plotting)
    should treat it.

    Looking a feature up with :func:`classify_feature` and branching on this
    enum removes the need for scattered ``match`` / ``if`` blocks that
    hard-code feature name strings.

    Canonical extracted shapes (after stacking W windows via ``extract_*``):

    .. code-block:: text

        Type            Extracted shape   Axis semantics
        ──────────────  ────────────────  ─────────────────────────────────
        LINEAR          (W, C)            windows, channels
        LINEAR_2D       (W, C, K)         windows, channels, components
        BAND            (W, C, B)         windows, channels, bands
        SIMPLE_MATRIX   (W, C, C)         windows, ch_row, ch_col
        BANDED_MATRIX   (W, C, C, B)      windows, ch_row, ch_col, bands
        HIST            (W, C, F)         windows, channels, freq_bins

    **Convention**: ``W`` is always axis 0. Channel axes (``C``) always come
    next (axis 1, and axis 2 for matrix types). Feature-specific semantic
    dimensions (bands ``B``, components ``K``, frequency bins ``F``) are last.

    This means:

    - **Channel collapse** is always ``axis=1`` (or ``axis=(1, 2)`` for
      matrices).
    - ``vals[:, ch_idx]`` gives one channel's data for every feature type.
    - Matches per-cell storage order: LINEAR stores ``(C,)``, matrices store
      ``(C, C)``, so stacking windows on axis 0 produces the canonical shape.

    Example::

        from neurodent.constants import FeatureType, classify_feature

        ftype = classify_feature("rms")
        assert ftype is FeatureType.LINEAR
        assert ftype.extracted_shape == "W, C"
        assert ftype.channel_axes == (1,)
        assert not ftype.is_dict_stored
    """

    LINEAR = "linear"
    """One scalar per channel; stored as a list of floats (one per channel)."""

    LINEAR_2D = "linear_2d"
    """Several components per channel; stored as a 2-D array ``(n_channels, n_components)``."""

    BAND = "band"
    """Dict keyed by frequency-band name whose values are per-channel arrays."""

    BANDED_MATRIX = "banded_matrix"
    """Dict keyed by frequency-band name whose values are 2-D (channel × channel) matrices."""

    SIMPLE_MATRIX = "simple_matrix"
    """Single 2-D (channel × channel) matrix with no frequency-band dimension."""

    HIST = "hist"
    """Histogram / spectral data stored as a ``(frequencies, values)`` tuple."""

    # -- shape metadata ------------------------------------------------------
    # FEATURE_SHAPES is a module-level registry defined after this class; the
    # name is only resolved when a property is read, not at class creation.

    @property
    def extracted_shape(self) -> str:
        """Symbolic post-extraction shape string (e.g. ``'W, C, B'``)."""
        shape_info = FEATURE_SHAPES[self]
        return shape_info["extracted_shape"]

    @property
    def channel_axes(self) -> tuple[int, ...]:
        """Indices of the axes that are channel dimensions."""
        shape_info = FEATURE_SHAPES[self]
        return shape_info["channel_axes"]

    @property
    def semantic_axes(self) -> dict[str, int]:
        """Semantic dimension name mapped to its axis index."""
        shape_info = FEATURE_SHAPES[self]
        return shape_info["semantic_axes"]

    # -- convenience predicates ----------------------------------------------

    @property
    def is_linear(self) -> bool:
        """``True`` for per-channel features (scalar or multi-component)."""
        return self is FeatureType.LINEAR or self is FeatureType.LINEAR_2D

    @property
    def is_matrix(self) -> bool:
        """``True`` for features stored as channel × channel matrices."""
        return self is FeatureType.BANDED_MATRIX or self is FeatureType.SIMPLE_MATRIX

    @property
    def is_dict_stored(self) -> bool:
        """``True`` for features stored as dicts keyed by frequency band."""
        return self is FeatureType.BAND or self is FeatureType.BANDED_MATRIX
# ---------------------------------------------------------------------------
# Canonical shape registry.
# Placed below FeatureType because the enum members are used as the keys.
# ---------------------------------------------------------------------------
FEATURE_SHAPES: dict[FeatureType, dict] = {
    # Per-channel scalar.
    FeatureType.LINEAR: {
        "extracted_shape": "W, C",
        "cell_shape": "C",
        "channel_axes": (1,),
        "semantic_axes": {},
        "description": "Scalar per channel (e.g. rms, ampvar)",
    },
    # Per-channel vector of K components.
    FeatureType.LINEAR_2D: {
        "extracted_shape": "W, C, K",
        "cell_shape": "C, K",
        "channel_axes": (1,),
        "semantic_axes": {"components": 2},
        "description": "K components per channel (e.g. psdslope)",
    },
    # Per-channel value for each frequency band.
    FeatureType.BAND: {
        "extracted_shape": "W, C, B",
        "cell_shape": "{band: (C,)}",
        "channel_axes": (1,),
        "semantic_axes": {"bands": 2},
        "description": "One value per band per channel (e.g. psdband, logpsdband)",
    },
    # Channel-pair matrix without a band dimension.
    FeatureType.SIMPLE_MATRIX: {
        "extracted_shape": "W, C, C",
        "cell_shape": "C, C",
        "channel_axes": (1, 2),
        "semantic_axes": {},
        "description": "Channel x channel matrix (e.g. pcorr, zpcorr)",
    },
    # Channel-pair matrix for each frequency band.
    FeatureType.BANDED_MATRIX: {
        "extracted_shape": "W, C, C, B",
        "cell_shape": "{band: (C, C)}",
        "channel_axes": (1, 2),
        "semantic_axes": {"bands": 3},
        "description": "Channel x channel matrix per band (e.g. cohere, zcohere)",
    },
    # Per-channel spectrum over F frequency bins.
    FeatureType.HIST: {
        "extracted_shape": "W, C, F",
        "cell_shape": "(F,), (C, F)",
        "channel_axes": (1,),
        "semantic_axes": {"freq_bins": 2},
        "description": "Spectral histogram per channel (e.g. psd)",
    },
}
"""Canonical shape metadata, one entry per :class:`FeatureType` member.

Every entry carries the same five keys:

- ``extracted_shape``: Symbolic shape after stacking W windows.
- ``cell_shape``: Shape of a single cell in the WAR DataFrame.
- ``channel_axes``: Tuple of axis indices that are channel dimensions.
- ``semantic_axes``: Dict mapping semantic dimension names to axis indices.
- ``description``: Human-readable description.

**Shape convention**: ``(W, C, [C₂], ...semantic)``.
Windows (W) is always axis 0. Channel axes (C) always follow immediately.
Semantic dimensions (bands, components, freq bins) are trailing.
"""
LINEAR_FEATURES = [
    "rms",
    "ampvar",
    "psdtotal",
    "nspike",
    "logrms",
    "logampvar",
    "logpsdtotal",
    "lognspike",
]
"""List of linear (scalar) feature names computed per channel."""

LINEAR_2D_FEATURES = ["psdslope"]
"""List of multi-component linear features (stored as 2-D arrays per channel)."""

BAND_FEATURES = ["psdband", "psdfrac"] + ["logpsdband", "logpsdfrac"]
"""List of frequency-band feature names (one value per band)."""

BANDED_MATRIX_FEATURES = ["cohere", "zcohere", "imcoh", "zimcoh"]
"""Matrix features with frequency band dimension (stored as dict with band keys)."""

SIMPLE_MATRIX_FEATURES = ["pcorr", "zpcorr"]
"""Matrix features without frequency bands (single 2D correlation matrix)."""

# Built from the two sub-lists instead of repeating their literals, so a
# feature added to either sub-list is automatically a matrix feature too.
# ``+`` creates a new list, so mutating MATRIX_FEATURES never touches them.
MATRIX_FEATURES = BANDED_MATRIX_FEATURES + SIMPLE_MATRIX_FEATURES
"""List of connectivity/matrix feature names (channel pairs): banded first, then simple."""

HIST_FEATURES = ["psd"]
"""List of histogram/spectral feature names."""

FEATURES = LINEAR_FEATURES + LINEAR_2D_FEATURES + BAND_FEATURES + MATRIX_FEATURES + HIST_FEATURES
"""Complete list of all available features."""

# Substring test on purpose: it drops both "nspike" and "lognspike".
WAR_FEATURES = [f for f in FEATURES if "nspike" not in f]
"""Features available in WindowAnalysisResult (excludes spike-related)."""
# ---------------------------------------------------------------------------
# Central lookup table: feature name -> FeatureType
# ---------------------------------------------------------------------------
# Built in one pass over (type, names) pairs. The pairs are listed in the
# order the entries should be inserted into the dict.
FEATURE_TYPES: dict[str, FeatureType] = {
    feature_name: feature_type
    for feature_type, feature_names in (
        (FeatureType.LINEAR, LINEAR_FEATURES),
        (FeatureType.LINEAR_2D, LINEAR_2D_FEATURES),
        (FeatureType.BAND, BAND_FEATURES),
        (FeatureType.BANDED_MATRIX, BANDED_MATRIX_FEATURES),
        (FeatureType.SIMPLE_MATRIX, SIMPLE_MATRIX_FEATURES),
        (FeatureType.HIST, HIST_FEATURES),
    )
    for feature_name in feature_names
}
"""Lookup from each known feature name to its :class:`FeatureType`."""
# [docs]  <- Sphinx "viewcode" link marker left over from HTML extraction; not part of the module.
def classify_feature(feature_name: str) -> FeatureType:
    """Look up the :class:`FeatureType` registered for a feature name.

    Args:
        feature_name: Name of the feature (e.g. ``'rms'``, ``'psdband'``,
            ``'cohere'``).

    Returns:
        The :class:`FeatureType` member stored for *feature_name* in
        ``FEATURE_TYPES``.

    Raises:
        ValueError: If *feature_name* is not a key of ``FEATURE_TYPES``. The
            message lists every known feature name in sorted order.
    """
    try:
        feature_type = FEATURE_TYPES[feature_name]
    except KeyError:
        known_features = sorted(FEATURE_TYPES.keys())
        # ``from None`` hides the internal KeyError so callers see only the
        # ValueError in the traceback.
        raise ValueError(
            f"Unknown feature '{feature_name}'. "
            f"Known features: {known_features}"
        ) from None
    return feature_type
FREQ_BANDS = {
    "delta": (1, 4),
    "theta": (4, 8),
    "alpha": (8, 13),
    "beta": (13, 25),
    "gamma": (25, 40),
}
"""Dictionary mapping frequency band names to (min_hz, max_hz) tuples.

Delta band adjusted to 1-4 Hz (changed from 0.1-4 Hz) for reliable coherence
estimation with short epochs and to avoid insufficient cycles warnings.
"""

# Iterating a dict yields its keys in insertion order; no need for .items().
BAND_NAMES = list(FREQ_BANDS)
"""Ordered list of frequency band names: ['delta', 'theta', 'alpha', 'beta', 'gamma']."""

FREQ_MINS = [min_hz for min_hz, _ in FREQ_BANDS.values()]
"""List of minimum frequencies for each band."""

FREQ_MAXS = [max_hz for _, max_hz in FREQ_BANDS.values()]
"""List of maximum frequencies for each band."""

# Derived from the band edges (previously a hard-coded (1, 40)) so the total
# range cannot go stale when a band boundary in FREQ_BANDS is changed.
FREQ_BAND_TOTAL = (min(FREQ_MINS), max(FREQ_MAXS))
"""Total frequency range covered by all bands (min, max) in Hz."""