ExperimentPlotter
- class neurodent.visualization.ExperimentPlotter(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)
Bases: object
A class for creating various plots from one or more experimental datasets.
It provides methods for categorical plots (box, boxen, violin, strip, swarm, bar, and point), heatmaps, difference heatmaps, and QQ plots, applying consistent data processing and styling across all of them.
Plot Ordering
The class automatically sorts data according to predefined plot orders for columns like ‘channel’, ‘genotype’, ‘sex’, ‘isday’, and ‘band’. Users can customize this ordering during initialization:
plotter = ExperimentPlotter(wars, plot_order={'channel': ['LMot', 'RMot', ...]})
The default plot orders are defined in constants.DF_SORT_ORDER.
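The following stdlib-only sketch shows how a plot_order mapping can drive sorting. It orders row dictionaries by the position of each value in its column's list. It is not neurodent's implementation. The function name sort_by_plot_order is hypothetical, and placing unlisted values last is an assumption. The real class may drop such rows or warn instead.

```python
def sort_by_plot_order(rows, plot_order):
    """Sort row dicts by the position of each value in plot_order.

    Hypothetical helper for illustration; not part of neurodent.
    rows: list of dicts, one per observation.
    plot_order: dict mapping a column name to its ordered list of values.
    """
    def sort_key(row):
        key = []
        # Columns listed earlier in plot_order take priority in the sort.
        for column, order in plot_order.items():
            value = row.get(column)
            # Listed values sort by position; unlisted (or absent) values go last.
            key.append(order.index(value) if value in order else len(order))
        return tuple(key)

    # sorted() is stable, so rows with equal keys keep their input order.
    return sorted(rows, key=sort_key)


rows = [
    {"genotype": "KO", "sex": "Male"},
    {"genotype": "WT", "sex": "Male"},
    {"genotype": "WT", "sex": "Female"},
]
order = {"genotype": ["WT", "KO"], "sex": ["Female", "Male"]}
for row in sort_by_plot_order(rows, order):
    print(row["genotype"], row["sex"])
# WT Female
# WT Male
# KO Male
```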
Validation and Warnings
During plotting, the class validates the plot order against the processed DataFrame and issues warnings for any mismatches. To validate explicitly, call validate_plot_order():
plotter.validate_plot_order(df)
Examples
Customize plot ordering during initialization:
custom_order = {
    'channel': ['LMot', 'RMot', 'LBar', 'RBar'],  # Only include specific channels
    'genotype': ['WT', 'KO'],  # Standard order
    'sex': ['Female', 'Male'],  # Custom order
}
plotter = ExperimentPlotter(wars, plot_order=custom_order)
- __init__(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)
Initialize plotter with WindowAnalysisResult object(s).
- Parameters:
wars (WindowAnalysisResult or list[WindowAnalysisResult]) – Single WindowAnalysisResult or list of WindowAnalysisResult objects.
features (list[str], optional) – List of features to extract. If None, defaults to ['all'].
exclude (list[str], optional) – List of features to exclude from extraction.
use_abbreviations (bool, optional) – Whether to use abbreviations for channel names.
plot_order (dict, optional) – Dictionary mapping column names to the order of values for plotting. If None, uses constants.DF_SORT_ORDER.
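The hypothetical resolver below shows how features and exclude might interact. Only one part is documented: features=None means ['all']. Treating 'all' as "every available feature" and applying exclude afterwards are assumptions. The name resolve_features and the feature names are placeholders.

```python
def resolve_features(features=None, exclude=None, available=()):
    """Expand features/exclude into a concrete list (hypothetical helper).

    features=None -> ['all'] follows the documented default; expanding 'all'
    to every name in `available` and then dropping `exclude` are assumptions.
    """
    requested = ["all"] if features is None else list(features)
    if "all" in requested:
        requested = list(available)
    excluded = set(exclude or [])
    return [name for name in requested if name not in excluded]


available = ["feat_a", "feat_b", "feat_c"]  # placeholder feature names
print(resolve_features(None, exclude=["feat_b"], available=available))
# ['feat_a', 'feat_c']
```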
- validate_plot_order(df: DataFrame, raise_errors: bool = False) → dict
Validate that the current plot_order contains all necessary categories for the data.
- Parameters:
df (pd.DataFrame) – DataFrame to validate against (should be the DataFrame that will be sorted).
raise_errors (bool, optional) – Whether to raise errors for validation issues. Default is False.
- Returns:
Dictionary with validation results for each column
- Return type:
dict
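Below is a rough sketch of the kind of check validate_plot_order() performs. It works on plain Python data rather than a DataFrame. Several details are assumptions for illustration, not the library's documented behavior: the helper name check_plot_order, its result keys, and warning only about values that are missing from the order.

```python
import warnings


def check_plot_order(values_by_column, plot_order, raise_errors=False):
    """Compare observed category values with a plot_order mapping.

    Hypothetical stand-in for validate_plot_order(); works on plain data.
    values_by_column: dict mapping column name -> iterable of observed values.
    Returns {column: {"missing_from_order": [...], "unused_in_order": [...]}}.
    """
    results = {}
    for column, order in plot_order.items():
        if column not in values_by_column:
            continue  # this column is not in the data, so there is nothing to check
        observed = set(values_by_column[column])
        missing = sorted(observed - set(order))           # in the data, not in the order
        unused = [v for v in order if v not in observed]  # in the order, not in the data
        results[column] = {"missing_from_order": missing, "unused_in_order": unused}
        if missing:
            message = f"{column!r}: values {missing} are not listed in plot_order"
            if raise_errors:
                raise ValueError(message)
            warnings.warn(message)
    return results


data = {"genotype": ["WT", "KO", "HET"], "sex": ["Male", "Female"]}
order = {"genotype": ["WT", "KO"], "sex": ["Female", "Male", "Unknown"]}
print(check_plot_order(data, order))  # also emits a warning about 'HET'
```

The sketch separates the two kinds of mismatch because they have different consequences. A value missing from the order has no defined position. An unused entry in the order is usually harmless.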
- pull_timeseries_dataframe(feature: str, groupby: str | list[str], channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, strict_groupby: bool = False)
Process feature data for plotting.
- Parameters:
feature (str) – The feature to get.
groupby (str or list[str]) – The variable(s) to group by.
channels (str or list[str], optional) – The channels to get. If 'all', all channels are used.
collapse_channels (bool, optional) – Whether to average the channels to one value.
average_groupby (bool, optional) – Whether to average the groupby variable(s).
strict_groupby (bool, optional) – If True, raise an exception when groupby columns contain NaN values. If False (default), only issue a warning.
- Returns:
df – A DataFrame with the feature data.
- Return type:
pd.DataFrame
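The stdlib-only sketch below illustrates the aggregation flags on a list of row dictionaries. Several pieces are assumptions: the column names (animal, channel, value), the helper aggregate_rows, and the reading of average_groupby as "one mean per group". The real method returns a pandas DataFrame.

```python
import math
import warnings
from collections import defaultdict
from statistics import mean


def _is_missing(value):
    return value is None or (isinstance(value, float) and math.isnan(value))


def aggregate_rows(rows, groupby, collapse_channels=False, average_groupby=False,
                   strict_groupby=False):
    """Illustrate collapse_channels / average_groupby / strict_groupby on row dicts.

    Hypothetical helper: each row has 'animal', 'channel', 'value' and the
    groupby column(s). Not the neurodent implementation.
    """
    groupby = [groupby] if isinstance(groupby, str) else list(groupby)

    # strict_groupby: missing groupby values raise; otherwise they only warn.
    if any(_is_missing(row.get(col)) for row in rows for col in groupby):
        message = f"NaN values found in groupby columns {groupby}"
        if strict_groupby:
            raise ValueError(message)
        warnings.warn(message)

    def average_by(items, keys):
        # Mean of 'value' over rows that share the same values for `keys`.
        buckets = defaultdict(list)
        for item in items:
            buckets[tuple(item[k] for k in keys)].append(item["value"])
        return [dict(zip(keys, key), value=mean(vals)) for key, vals in buckets.items()]

    if collapse_channels:
        # One value per animal and group: average across channels.
        rows = average_by(rows, ["animal", *groupby])
    if average_groupby:
        # One value per group (and per channel, if channels were kept).
        keys = list(groupby) if collapse_channels else [*groupby, "channel"]
        rows = average_by(rows, keys)
    return rows


rows = [
    {"animal": "A1", "genotype": "WT", "channel": "LMot", "value": 1.0},
    {"animal": "A1", "genotype": "WT", "channel": "RMot", "value": 3.0},
    {"animal": "A2", "genotype": "KO", "channel": "LMot", "value": 5.0},
    {"animal": "A2", "genotype": "KO", "channel": "RMot", "value": 7.0},
]
print(aggregate_rows(rows, "genotype", collapse_channels=True))
# [{'animal': 'A1', 'genotype': 'WT', 'value': 2.0}, {'animal': 'A2', 'genotype': 'KO', 'value': 6.0}]
```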
- plot_catplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, x: str | None = None, col: str | None = None, hue: str | None = None, kind: Literal['box', 'boxen', 'violin', 'strip', 'swarm', 'bar', 'point'] = 'box', catplot_params: dict | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, title: str | None = None, cmap: str | None = None, stat_pairs: list[tuple[str, str]] | Literal['all', 'x', 'hue'] | None = None, stat_test: str = 'Mann-Whitney', norm_test: Literal[None, 'D-Agostino', 'log-D-Agostino', 'K-S'] | None = None) → FacetGrid
Create a categorical plot of feature data (a boxplot by default; set kind to choose another plot type).
- Return type:
FacetGrid
- Parameters:
feature (str)
groupby (str | list[str])
df (DataFrame | None)
x (str | None)
col (str | None)
hue (str | None)
kind (Literal['box', 'boxen', 'violin', 'strip', 'swarm', 'bar', 'point'])
catplot_params (dict | None)
channels (str | list[str])
collapse_channels (bool)
average_groupby (bool)
title (str | None)
cmap (str | None)
stat_pairs (list[tuple[str, str]] | Literal['all', 'x', 'hue'] | None)
stat_test (str)
norm_test (Literal[None, 'D-Agostino', 'log-D-Agostino', 'K-S'] | None)
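plot_catplot defaults to stat_test='Mann-Whitney'. For reference, the sketch below computes the Mann-Whitney U statistic for two samples, using average ranks for ties. It illustrates the test named in the signature and nothing more. The text above does not document how plot_catplot computes p-values or handles multiple comparisons. mann_whitney_u is a hypothetical helper.

```python
def mann_whitney_u(a, b):
    """Return (U_a, U_b) for samples a and b, using average ranks for ties.

    Hypothetical helper illustrating the statistic; no p-value is computed.
    """
    # Pool both samples, tagging each value with its origin (0 = a, 1 = b).
    pooled = sorted([(x, 0) for x in a] + [(x, 1) for x in b])
    ranks = [0.0] * len(pooled)
    i = 0
    while i < len(pooled):
        j = i
        # Extend j to cover a run of tied values.
        while j + 1 < len(pooled) and pooled[j + 1][0] == pooled[i][0]:
            j += 1
        tied_rank = (i + j + 2) / 2  # mean of the 1-based ranks i+1 .. j+1
        for k in range(i, j + 1):
            ranks[k] = tied_rank
        i = j + 1
    rank_sum_a = sum(r for r, (_, origin) in zip(ranks, pooled) if origin == 0)
    n_a, n_b = len(a), len(b)
    u_a = rank_sum_a - n_a * (n_a + 1) / 2
    return u_a, n_a * n_b - u_a  # U_a + U_b always equals n_a * n_b


print(mann_whitney_u([1, 2], [2, 3]))  # (0.5, 3.5)
```

The test is rank-based, so it does not assume normally distributed feature values. That makes it a common default for comparing two groups.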
- plot_heatmap(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)
Create a 2D heatmap of feature data.
Parameters:
- cmap : str, default="RdBu_r"
Colormap name or matplotlib colormap object.
- norm : matplotlib.colors.Normalize, optional
Normalization object. If None, uses CenteredNorm with an auto-detected range. Common options:
  - colors.CenteredNorm(vcenter=0)  # Auto-detect range around 0
  - colors.Normalize(vmin=-1, vmax=1)  # Fixed range
  - colors.LogNorm()  # Logarithmic scale
- Parameters:
feature (str)
groupby (str | list[str])
df (DataFrame | None)
col (str | None)
row (str | None)
channels (str | list[str])
collapse_channels (bool)
average_groupby (bool)
cmap (str)
norm (Normalize | None)
height (float)
aspect (float)
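When norm is None, the heatmap is documented to use CenteredNorm with an auto-detected range. The plain-Python sketch below shows the arithmetic behind a centered normalization. The half-range defaults to the largest distance of any value from vcenter. Values then map linearly so that vcenter lands at 0.5. This follows the idea of matplotlib.colors.CenteredNorm but is a stand-in, not matplotlib's code, and centered_norm is a hypothetical name.

```python
def centered_norm(values, vcenter=0.0, halfrange=None):
    """Map values linearly so that vcenter -> 0.5 (a CenteredNorm-style sketch).

    Hypothetical helper; results are not clipped to [0, 1].
    """
    values = list(values)
    if halfrange is None:
        # Auto-detect: the largest distance of any value from the center.
        halfrange = max(abs(v - vcenter) for v in values)
    if halfrange == 0:
        return [0.5] * len(values)  # degenerate range; our own choice here
    vmin = vcenter - halfrange
    return [(v - vmin) / (2 * halfrange) for v in values]


print(centered_norm([-2, 0, 1]))  # [0.0, 0.5, 0.75]
```

The range is symmetric around vcenter, so a diverging colormap such as the default RdBu_r shows equal distances above and below the center with equal color intensity.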
- plot_heatmap_faceted(feature: str, groupby: str | list[str], facet_vars: list[str] | str, df: DataFrame | None = None, **kwargs)
- Parameters:
feature (str)
groupby (str | list[str])
facet_vars (list[str] | str)
df (DataFrame | None)
- plot_diffheatmap(feature: str, groupby: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)
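Create a 2D heatmap of differences between groups. By default (operation='subtract'), the baseline is subtracted from the other groups; with operation='divide', the other groups are divided by the baseline.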
Parameters:
- cmap : str, default="RdBu_r"
Colormap name or matplotlib colormap object.
- norm : matplotlib.colors.Normalize, optional
Normalization object. If None, uses CenteredNorm with an auto-detected range. Common options:
  - colors.CenteredNorm(vcenter=0)  # Auto-detect range around 0
  - colors.Normalize(vmin=-1, vmax=1)  # Fixed range
  - colors.LogNorm()  # Logarithmic scale
- Parameters:
feature (str)
groupby (str | list[str])
baseline_key (str | bool | tuple[str, ...])
baseline_groupby (list[str] | str | None)
operation (Literal['subtract', 'divide'])
remove_baseline (bool)
df (DataFrame | None)
col (str | None)
row (str | None)
channels (str | list[str])
collapse_channels (bool)
average_groupby (bool)
cmap (str)
norm (Normalize | None)
height (float)
aspect (float)
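The baseline arithmetic of plot_diffheatmap can be sketched with plain dictionaries. Each key is a (group, condition) pair. Every group is compared with the baseline_key entry that shares its condition, and the second tuple element stands in for the columns outside the baseline comparison. Several parts are assumptions rather than the library's implementation: that pairing, the helper name baseline_diff, reading divide as value / baseline, and the example values.

```python
def baseline_diff(values, baseline_key, operation="subtract", remove_baseline=False):
    """Compare each group with the baseline group (hypothetical helper).

    values: dict mapping (group, condition) -> number, e.g. (genotype, band).
    Each entry is compared with the baseline entry for the same condition.
    """
    baseline = {cond: v for (group, cond), v in values.items() if group == baseline_key}
    result = {}
    for (group, cond), v in values.items():
        if remove_baseline and group == baseline_key:
            continue  # drop the baseline entries themselves
        if operation == "subtract":
            result[(group, cond)] = v - baseline[cond]
        elif operation == "divide":
            result[(group, cond)] = v / baseline[cond]
        else:
            raise ValueError(f"unknown operation: {operation!r}")
    return result


# Illustrative toy numbers, not measured data.
means = {("WT", "alpha"): 2.0, ("KO", "alpha"): 3.0, ("WT", "beta"): 4.0, ("KO", "beta"): 2.0}
print(baseline_diff(means, "WT", remove_baseline=True))
# {('KO', 'alpha'): 1.0, ('KO', 'beta'): -2.0}
print(baseline_diff(means, "WT", operation="divide", remove_baseline=True))
# {('KO', 'alpha'): 1.5, ('KO', 'beta'): 0.5}
```

Without remove_baseline, the baseline entries stay in the output as 0 (subtract) or 1 (divide). Dropping them keeps those uninformative cells out of the heatmap.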
- plot_diffheatmap_faceted(feature: str, groupby: str | list[str], facet_vars: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, cmap: str = 'RdBu_r', norm: Normalize | None = None, **kwargs)
- Parameters:
feature (str)
groupby (str | list[str])
facet_vars (str | list[str])
baseline_key (str | bool | tuple[str, ...])
baseline_groupby (list[str] | str | None)
operation (Literal['subtract', 'divide'])
remove_baseline (bool)
df (DataFrame | None)
cmap (str)
norm (Normalize | None)
- plot_qqplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, log: bool = False, channels: str | list[str] = 'all', collapse_channels: bool = False, height: float = 3, aspect: float = 1, **kwargs)
Create a QQ plot of the feature data.
- Parameters:
feature (str)
groupby (str | list[str])
df (DataFrame | None)
col (str | None)
row (str | None)
log (bool)
channels (str | list[str])
collapse_channels (bool)
height (float)
aspect (float)
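A QQ plot compares sorted observations with the quantiles a normal distribution would predict. The sketch below computes those point pairs with statistics.NormalDist from the standard library. The plotting position (i - 0.5) / n is one common convention among several. Reading log=True as "log-transform the data first" is an assumption about plot_qqplot, and qq_points is a hypothetical helper.

```python
import math
from statistics import NormalDist


def qq_points(sample, log=False):
    """Return (theoretical, observed) pairs for a normal QQ plot (hypothetical helper)."""
    # Assumption: log=True means the data are log-transformed before plotting.
    data = sorted(math.log(x) for x in sample) if log else sorted(sample)
    n = len(data)
    standard_normal = NormalDist()  # mean 0, standard deviation 1
    # Plotting position (i - 0.5) / n; other conventions (e.g. Filliben's) also exist.
    theoretical = [standard_normal.inv_cdf((i - 0.5) / n) for i in range(1, n + 1)]
    return list(zip(theoretical, data))


for q, x in qq_points([3.1, 0.4, 1.8, 2.2, 1.0]):
    print(f"{q:+.3f}  {x}")
```

Points that fall close to a straight line suggest the data are approximately normal. Systematic curvature indicates skew, which a log transform sometimes removes.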