ExperimentPlotter

class neurodent.visualization.ExperimentPlotter(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)

Bases: object

A class for creating plots from one or more experimental datasets.

This class provides methods for creating different types of plots (categorical plots such as box and violin plots, heatmaps, and QQ plots) from experimental data with consistent data processing and styling.

Plot Ordering

The class automatically sorts data according to predefined plot orders for columns like ‘channel’, ‘genotype’, ‘sex’, ‘isday’, and ‘band’. Users can customize this ordering during initialization:

plotter = ExperimentPlotter(wars, plot_order={'channel': ['LMot', 'RMot', ...]})

The default plot orders are defined in constants.DF_SORT_ORDER.
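To illustrate what sorting by a plot order means, here is a minimal pure-Python sketch. It assumes that columns are compared in the order their keys appear in the mapping and that a value missing from its order list sorts after all listed values. The helper sort_by_plot_order is hypothetical and is not neurodent's implementation.

```python
# Hypothetical helper for illustration only; not part of neurodent.
def sort_by_plot_order(rows, plot_order):
    """Sort a list of dict rows by the category orders in plot_order.

    Columns are compared in the order their keys appear in plot_order.
    A value that is not in its column's order list sorts after all listed values.
    """
    def key(row):
        ranks = []
        for column, order in plot_order.items():
            value = row.get(column)
            # Position in the order list, or len(order) for unlisted values
            ranks.append(order.index(value) if value in order else len(order))
        return tuple(ranks)

    # sorted() is stable, so rows with equal ranks keep their input order
    return sorted(rows, key=key)


rows = [
    {"genotype": "KO", "channel": "RMot"},
    {"genotype": "WT", "channel": "RMot"},
    {"genotype": "WT", "channel": "LMot"},
]
order = {"genotype": ["WT", "KO"], "channel": ["LMot", "RMot"]}
for row in sort_by_plot_order(rows, order):
    print(row["genotype"], row["channel"])
# WT LMot
# WT RMot
# KO RMot
```

In pandas, a comparable effect can be obtained by converting each column to an ordered pd.Categorical before calling sort_values.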

Validation and Warnings

The class automatically validates the plot order against the processed DataFrame during plotting and issues a warning for any mismatch. To run this validation explicitly, call validate_plot_order():

plotter.validate_plot_order(df)

Examples

Customize plot ordering during initialization:

custom_order = {
    'channel': ['LMot', 'RMot', 'LBar', 'RBar'],  # Only include specific channels
    'genotype': ['WT', 'KO'],  # Standard order
    'sex': ['Female', 'Male']  # Custom order
}
plotter = ExperimentPlotter(wars, plot_order=custom_order)

__init__(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)

Initialize plotter with WindowAnalysisResult object(s).

Parameters:
  • wars (WindowAnalysisResult or list[WindowAnalysisResult]) – Single WindowAnalysisResult or list of WindowAnalysisResult objects

  • features (list[str], optional) – List of features to extract. If None, defaults to [‘all’]

  • exclude (list[str], optional) – List of features to exclude from extraction

  • use_abbreviations (bool, optional) – Whether to use abbreviations for channel names. Default is True.

  • plot_order (dict, optional) – Dictionary mapping column names to the order of values for plotting. If None, uses constants.DF_SORT_ORDER.
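As an illustration of the documented defaults for features and exclude, the hypothetical helper below resolves which features would be extracted. The feature names, the expansion of ‘all’ to every available feature, and applying exclude after that expansion are assumptions made for this sketch, not details taken from neurodent's source.

```python
# Hypothetical helper illustrating the documented defaults of `features`
# and `exclude`; not part of neurodent. Feature names below are made up.
def resolve_features(available, features=None, exclude=None):
    """Return the list of features that would be extracted."""
    # Documented default: features=None behaves like ['all']
    requested = ["all"] if features is None else list(features)
    # Assumption: 'all' expands to every available feature
    selected = list(available) if "all" in requested else requested
    excluded = set(exclude or [])
    return [name for name in selected if name not in excluded]


available = ["rms", "psdtotal", "cohere"]  # made-up feature names
print(resolve_features(available))                      # ['rms', 'psdtotal', 'cohere']
print(resolve_features(available, exclude=["cohere"]))  # ['rms', 'psdtotal']
print(resolve_features(available, features=["rms"]))    # ['rms']
```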

validate_plot_order(df: DataFrame, raise_errors: bool = False) -> dict

Validate that the current plot_order contains all necessary categories for the data.

Parameters:
  • df (pd.DataFrame) – DataFrame to validate against (should be the DataFrame that will be sorted).

  • raise_errors (bool, optional) – Whether to raise errors for validation issues. Default is False.

Returns:

Dictionary with validation results for each column

Return type:

dict
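The kind of check validate_plot_order() performs can be sketched in plain Python. This version is hypothetical: the helper name check_plot_order, its input (a mapping from column name to observed values instead of a DataFrame), and the "missing"/"unused" keys of the returned dictionary are assumptions chosen for illustration.

```python
import warnings


# Hypothetical re-implementation for illustration; not neurodent's code.
def check_plot_order(data, plot_order, raise_errors=False):
    """Compare observed category values with plot_order, column by column.

    data maps a column name to the values observed in that column.
    """
    results = {}
    for column, order in plot_order.items():
        if column not in data:
            continue  # column absent from the data: nothing to validate
        observed = set(data[column])
        missing = sorted(observed - set(order))           # in the data, not in plot_order
        unused = [v for v in order if v not in observed]  # in plot_order, not in the data
        results[column] = {"missing": missing, "unused": unused}
        if missing:
            message = f"{column!r}: values {missing} are not in plot_order"
            if raise_errors:
                raise ValueError(message)
            warnings.warn(message)
    return results


# Emits a warning about 'Het'; the report lists it under "missing"
report = check_plot_order({"genotype": ["WT", "KO", "Het"]}, {"genotype": ["WT", "KO"]})
print(report)  # {'genotype': {'missing': ['Het'], 'unused': []}}
```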

pull_timeseries_dataframe(feature: str, groupby: str | list[str], channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, strict_groupby: bool = False)

Extract the data for one feature into a DataFrame for plotting.

Parameters:
  • feature (str) – The feature to extract.

  • groupby (str or list[str]) – The variable(s) to group by.

  • channels (str or list[str], optional) – The channels to include. If ‘all’ (the default), all channels are used.

  • collapse_channels (bool, optional) – Whether to average across channels into a single value.

  • average_groupby (bool, optional) – Whether to average the groupby variable(s).

  • strict_groupby (bool, optional) – If True, raise an exception when groupby columns contain NaN values. If False (default), only issue a warning.

Returns:

df – A DataFrame with the feature data.

Return type:

pd.DataFrame
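To make collapse_channels and average_groupby concrete, here is a pandas sketch on a toy long-format table. The column layout (one row per animal, channel, and time window) and the exact aggregation are assumptions; pull_timeseries_dataframe may organize and aggregate its data differently.

```python
import pandas as pd

# Toy long-format data: one row per animal x channel x time window.
# The column layout is an assumption, not neurodent's actual schema.
df = pd.DataFrame({
    "animal": ["A1"] * 4 + ["A2"] * 4,
    "genotype": ["WT"] * 4 + ["KO"] * 4,
    "channel": ["LMot", "RMot"] * 4,
    "window": [0, 0, 1, 1, 0, 0, 1, 1],
    "rms": [1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
})
groupby = ["animal", "genotype"]

# Analogue of collapse_channels=True: average across channels,
# keeping one row per groupby combination and time window.
collapsed = df.groupby(groupby + ["window"], as_index=False)["rms"].mean()

# Analogue of average_groupby=True: average over time windows,
# keeping one row per groupby combination and channel.
averaged = df.groupby(groupby + ["channel"], as_index=False)["rms"].mean()

print(collapsed)
print(averaged)
```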

plot_catplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, x: str | None = None, col: str | None = None, hue: str | None = None, kind: Literal['box', 'boxen', 'violin', 'strip', 'swarm', 'bar', 'point'] = 'box', catplot_params: dict | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, title: str | None = None, cmap: str | None = None, stat_pairs: list[tuple[str, str]] | Literal['all', 'x', 'hue'] | None = None, stat_test: str = 'Mann-Whitney', norm_test: Literal[None, 'D-Agostino', 'log-D-Agostino', 'K-S'] | None = None) -> FacetGrid

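Create a categorical plot of feature data. The plot type is selected with kind (‘box’, ‘boxen’, ‘violin’, ‘strip’, ‘swarm’, ‘bar’, or ‘point’); the default is a box plot.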

Return type:

FacetGrid

Parameters:
  • feature (str)

  • groupby (str | list[str])

  • df (DataFrame | None)

  • x (str | None)

  • col (str | None)

  • hue (str | None)

  • kind (Literal['box', 'boxen', 'violin', 'strip', 'swarm', 'bar', 'point'])

  • catplot_params (dict | None)

  • channels (str | list[str])

  • collapse_channels (bool)

  • average_groupby (bool)

  • title (str | None)

  • cmap (str | None)

  • stat_pairs (list[tuple[str, str]] | Literal['all', 'x', 'hue'] | None)

  • stat_test (str)

  • norm_test (Literal[None, 'D-Agostino', 'log-D-Agostino', 'K-S'] | None)
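stat_pairs accepts either explicit pairs (list[tuple[str, str]]) or a shorthand. The sketch below shows how a list of category levels can be expanded into explicit pairs with itertools.combinations. That the shorthand means "every pair of levels on one categorical axis" is an assumption for illustration, and expand_pairs is a hypothetical helper, not part of neurodent.

```python
from itertools import combinations


# Hypothetical helper: expands category levels into every unordered pair,
# preserving the order in which the levels are given.
def expand_pairs(levels):
    """Return all unordered pairs of levels as a list of 2-tuples."""
    return list(combinations(levels, 2))


print(expand_pairs(["WT", "Het", "KO"]))
# [('WT', 'Het'), ('WT', 'KO'), ('Het', 'KO')]
```

With three genotype levels this yields three comparisons; each pair would then be tested with stat_test, which defaults to ‘Mann-Whitney’.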

plot_heatmap(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)

Create a 2D heatmap of feature data.

Parameters:
  • feature (str)

  • groupby (str | list[str])

  • df (DataFrame | None)

  • col (str | None)

  • row (str | None)

  • channels (str | list[str])

  • collapse_channels (bool)

  • average_groupby (bool)

  • cmap (str, default ‘RdBu_r’) – Colormap name or matplotlib colormap object.

  • norm (matplotlib.colors.Normalize | None) – Normalization object. If None, a CenteredNorm with an auto-detected range is used. Common options:
      - colors.CenteredNorm(vcenter=0): auto-detect the range around 0
      - colors.Normalize(vmin=-1, vmax=1): fixed range
      - colors.LogNorm(): logarithmic scale

  • height (float)

  • aspect (float)
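When norm is None, the heatmap uses a CenteredNorm, so the center value maps to the midpoint of a diverging colormap such as RdBu_r. The pure-Python function below sketches that mapping, assuming the auto-detected half-range is the largest absolute deviation from vcenter (this is how matplotlib's CenteredNorm autoscales). centered_norm is a hypothetical name, not a matplotlib or neurodent function.

```python
# Hypothetical pure-Python sketch of CenteredNorm-style normalization.
def centered_norm(values, vcenter=0.0):
    """Map values to [0, 1] so that vcenter lands exactly at 0.5."""
    # Half-range: largest absolute deviation from the center
    halfrange = max(abs(v - vcenter) for v in values)
    if halfrange == 0:
        return [0.5 for _ in values]  # every value equals vcenter
    vmin = vcenter - halfrange  # symmetric limits around vcenter
    return [(v - vmin) / (2 * halfrange) for v in values]


print(centered_norm([-1.0, 0.0, 3.0]))  # [0.333..., 0.5, 1.0]
```

The symmetric range keeps zero at the neutral middle color even when the data are skewed, so positive and negative values remain directly comparable, which is especially useful for difference plots.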

plot_heatmap_faceted(feature: str, groupby: str | list[str], facet_vars: list[str] | str, df: DataFrame | None = None, **kwargs)

Parameters:
  • feature (str)

  • groupby (str | list[str])

  • facet_vars (list[str] | str)

  • df (DataFrame | None)

plot_diffheatmap(feature: str, groupby: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)

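Create a 2D heatmap of differences between groups. Each group is compared with the baseline group, either by subtracting the baseline (operation=‘subtract’, the default) or by dividing by it (operation=‘divide’).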

Parameters:
  • feature (str)

  • groupby (str | list[str])

  • baseline_key (str | bool | tuple[str, ...])

  • baseline_groupby (list[str] | str | None)

  • operation (Literal['subtract', 'divide'])

  • remove_baseline (bool)

  • df (DataFrame | None)

  • col (str | None)

  • row (str | None)

  • channels (str | list[str])

  • collapse_channels (bool)

  • average_groupby (bool)

  • cmap (str, default ‘RdBu_r’) – Colormap name or matplotlib colormap object.

  • norm (matplotlib.colors.Normalize | None) – Normalization object. If None, a CenteredNorm with an auto-detected range is used. Common options:
      - colors.CenteredNorm(vcenter=0): auto-detect the range around 0
      - colors.Normalize(vmin=-1, vmax=1): fixed range
      - colors.LogNorm(): logarithmic scale

  • height (float)

  • aspect (float)
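The baseline comparison can be sketched with a dictionary of per-group values. diff_from_baseline is hypothetical and assumes one scalar per group; a tuple baseline_key (one of the accepted types) would correspond to a tuple dictionary key for multi-column groups. The real method works on DataFrames and may differ in detail.

```python
# Hypothetical sketch of a baseline comparison; not neurodent's implementation.
def diff_from_baseline(values, baseline_key, operation="subtract", remove_baseline=False):
    """Compare each group's value with the baseline group's value."""
    if operation not in ("subtract", "divide"):
        raise ValueError(f"unknown operation: {operation!r}")
    base = values[baseline_key]
    result = {}
    for key, value in values.items():
        if remove_baseline and key == baseline_key:
            continue  # drop the baseline group itself from the output
        result[key] = value - base if operation == "subtract" else value / base
    return result


means = {"WT": 2.0, "Het": 3.0, "KO": 5.0}  # made-up group means
print(diff_from_baseline(means, "WT"))  # {'WT': 0.0, 'Het': 1.0, 'KO': 3.0}
print(diff_from_baseline(means, "WT", "divide", remove_baseline=True))  # {'Het': 1.5, 'KO': 2.5}
```

In this sketch the baseline group always becomes 0 (subtract) or 1 (divide), which carries no information; remove_baseline=True drops it from the output.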

plot_diffheatmap_faceted(feature: str, groupby: str | list[str], facet_vars: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, cmap: str = 'RdBu_r', norm: Normalize | None = None, **kwargs)

Parameters:
  • feature (str)

  • groupby (str | list[str])

  • facet_vars (str | list[str])

  • baseline_key (str | bool | tuple[str, ...])

  • baseline_groupby (list[str] | str | None)

  • operation (Literal['subtract', 'divide'])

  • remove_baseline (bool)

  • df (DataFrame | None)

  • cmap (str)

  • norm (Normalize | None)

plot_qqplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, log: bool = False, channels: str | list[str] = 'all', collapse_channels: bool = False, height: float = 3, aspect: float = 1, **kwargs)

Create a QQ (quantile-quantile) plot of the feature data.

Parameters:
  • feature (str)

  • groupby (str | list[str])

  • df (DataFrame | None)

  • col (str | None)

  • row (str | None)

  • log (bool)

  • channels (str | list[str])

  • collapse_channels (bool)

  • height (float)

  • aspect (float)
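A QQ plot pairs the sorted data with quantiles of a reference distribution. The sketch below computes those points against a standard normal using statistics.NormalDist from the standard library. Two assumptions apply: the plotting positions (i - 0.5) / n, which are one common convention among several, and that log=True means the data are log-transformed before the comparison (which requires positive values). qq_points is a hypothetical helper, not part of neurodent.

```python
import math
from statistics import NormalDist


# Hypothetical helper: QQ points of a sample against a standard normal.
def qq_points(sample, log=False):
    """Return (theoretical, observed) quantile pairs."""
    # Assumption: log=True log-transforms the data first (values must be > 0)
    data = sorted(math.log(v) for v in sample) if log else sorted(sample)
    n = len(data)
    dist = NormalDist()  # standard normal: mean 0, standard deviation 1
    # Plotting positions (i - 0.5) / n for i = 1..n
    theoretical = [dist.inv_cdf((i - 0.5) / n) for i in range(1, n + 1)]
    return list(zip(theoretical, data))


for t, o in qq_points([3.0, 1.0, 2.0]):
    print(f"{t:+.3f} {o}")
```

Points that fall close to a straight line suggest the (optionally log-transformed) data are approximately normal.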

Parameters: