ExperimentPlotter

class neurodent.visualization.ExperimentPlotter(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)

Bases: object

A class for creating plots from one or more experimental datasets.

This class provides methods for creating different types of plots (boxplot, violin plot, scatter plot, etc.) from experimental data with consistent data processing and styling.

Plot Ordering:

The class automatically sorts data according to predefined plot orders for columns like ‘channel’, ‘genotype’, ‘sex’, ‘isday’, and ‘band’. Users can customize this ordering during initialization:

plotter = ExperimentPlotter(wars, plot_order={'channel': ['LMot', 'RMot', ...]})

The default plot orders are defined in constants.DF_SORT_ORDER.
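As a rough illustration of what sorting by a plot-order dictionary means, the sketch below orders plain Python records by the position of each value in its list. It does not use neurodent: `plot_order` here is a hypothetical stand-in for constants.DF_SORT_ORDER, and `order_key` is an illustrative helper rather than part of the library.

```python
# Hypothetical stand-in for constants.DF_SORT_ORDER; the real defaults may differ.
plot_order = {
    "genotype": ["WT", "KO"],
    "sex": ["Female", "Male"],
}

rows = [
    {"genotype": "KO", "sex": "Male"},
    {"genotype": "WT", "sex": "Male"},
    {"genotype": "KO", "sex": "Female"},
    {"genotype": "WT", "sex": "Female"},
]


def order_key(row, columns):
    """Rank each value by its position in plot_order; unknown values sort last."""
    key = []
    for column in columns:
        order = plot_order.get(column, [])
        value = row[column]
        key.append(order.index(value) if value in order else len(order))
    return tuple(key)


# Sort by genotype first, then by sex within each genotype.
sorted_rows = sorted(rows, key=lambda row: order_key(row, ["genotype", "sex"]))
for row in sorted_rows:
    print(row["genotype"], row["sex"])
```

Values that do not appear in the order list are placed last in this sketch; how the library treats them is not stated here.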

Validation and Warnings:

The class automatically validates the plot order against the processed DataFrame during plotting and issues warnings for any mismatches. Use validate_plot_order() to validate explicitly:

plotter.validate_plot_order(df)
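To make the idea of a plot-order mismatch concrete, here is a minimal sketch of one way such a check could work. It is not the library's implementation: `check_plot_order` and the keys of the returned dictionary are hypothetical, and the data is a plain dictionary of column values rather than a DataFrame.

```python
import warnings


def check_plot_order(data, plot_order, raise_errors=False):
    """Compare the values present in each column with the configured order.

    Hypothetical helper: the report structure is an assumption, not the
    format returned by ExperimentPlotter.validate_plot_order().
    """
    report = {}
    for column, order in plot_order.items():
        if column not in data:
            continue  # nothing to validate for columns absent from the data
        present = set(data[column])
        missing_from_order = sorted(present - set(order))
        unused_in_order = [value for value in order if value not in present]
        report[column] = {
            "missing_from_order": missing_from_order,
            "unused_in_order": unused_in_order,
        }
        if missing_from_order:
            message = f"{column!r}: values {missing_from_order} are not in plot_order"
            if raise_errors:
                raise ValueError(message)
            warnings.warn(message)
    return report


data = {"genotype": ["WT", "KO", "Het"]}
report = check_plot_order(data, {"genotype": ["WT", "KO"]})  # warns about 'Het'
print(report)
```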

Example

Customize plot ordering during initialization:

custom_order = {
    'channel': ['LMot', 'RMot', 'LBar', 'RBar'],  # Only include specific channels
    'genotype': ['WT', 'KO'],  # Standard order
    'sex': ['Female', 'Male']  # Custom order
}
plotter = ExperimentPlotter(wars, plot_order=custom_order)
Parameters:
  • wars (WindowAnalysisResult | list[WindowAnalysisResult]) – Single WindowAnalysisResult or list of WindowAnalysisResult objects

  • features (list[str], optional) – List of features to extract. If None, defaults to [‘all’]

  • exclude (list[str], optional) – List of features to exclude from extraction

  • use_abbreviations (bool, optional) – Whether to use abbreviations for channel names

  • plot_order (dict, optional) – Dictionary mapping column names to the order of values for plotting. If None, uses constants.DF_SORT_ORDER.

Variables:
  • results (list[WindowAnalysisResult]) – List of analysis results being plotted.

  • channel_names (list[list[str]]) – List of channel names for each result.

  • channel_to_idx (list[dict]) – Mapping of channel names to indices for each result.

  • all_channel_names (list[str]) – Union of all channel names across results.

  • df_wars (list[pd.DataFrame]) – List of dataframes for each result.

  • concat_df_wars (pd.DataFrame) – Concatenated dataframe of all results.

__init__(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)

validate_plot_order(df: DataFrame, raise_errors: bool = False) → dict

Validate that the current plot_order contains all necessary categories for the data.

Parameters

df : pd.DataFrame

DataFrame to validate against (should be the DataFrame that will be sorted).

raise_errors : bool, optional

Whether to raise errors for validation issues. Default is False.

Returns

dict

Dictionary with validation results for each column.

pull_timeseries_dataframe(feature: str, groupby: str | list[str], channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, strict_groupby: bool = False)

Process feature data for plotting.

Parameters

feature : str

The feature to get.

groupby : str or list[str]

The variable(s) to group by.

channels : str or list[str], optional

The channels to get. If ‘all’, all channels are used.

collapse_channels : bool, optional

Whether to average the channels to one value.

average_groupby : bool, optional

Whether to average the groupby variable(s).

strict_groupby : bool, optional

If True, raise an exception when groupby columns contain NaN values. If False (default), only issue a warning.

Returns

df : pd.DataFrame

A DataFrame with the feature data.
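The strict_groupby behaviour described above (raise on missing group values, otherwise warn) can be sketched in plain Python as follows. This is an illustration under assumptions, not the library's code: `check_groupby` and `is_missing` are hypothetical helpers, rows are plain dictionaries, and the exception type is chosen for the example.

```python
import math
import warnings


def is_missing(value):
    """Treat None and float NaN as missing."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def check_groupby(rows, groupby, strict_groupby=False):
    """Return the groupby columns that contain missing values.

    Raises ValueError when strict_groupby is True, otherwise only warns.
    """
    columns = [groupby] if isinstance(groupby, str) else list(groupby)
    bad = [c for c in columns if any(is_missing(row.get(c)) for row in rows)]
    if bad:
        message = f"groupby columns contain missing values: {bad}"
        if strict_groupby:
            raise ValueError(message)
        warnings.warn(message)
    return bad


rows = [
    {"genotype": "WT", "sex": "Male"},
    {"genotype": float("nan"), "sex": "Female"},
]
flagged = check_groupby(rows, ["genotype", "sex"])  # warns, does not raise
print(flagged)
```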

plot_catplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, x: str | None = None, col: str | None = None, hue: str | None = None, kind: Literal['box', 'boxen', 'violin', 'strip', 'swarm', 'bar', 'point'] = 'box', catplot_params: dict | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, title: str | None = None, cmap: str | None = None, stat_pairs: list[tuple[str, str]] | Literal['all', 'x', 'hue'] | None = None, stat_test: str = 'Mann-Whitney', norm_test: Literal[None, 'D-Agostino', 'log-D-Agostino', 'K-S'] | None = None) → FacetGrid

Create a categorical plot of feature data. The plot type is set by kind and defaults to a box plot.

plot_heatmap(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)

Create a 2D feature plot.

Parameters

cmap : str, default="RdBu_r"

Colormap name or matplotlib colormap object.

norm : matplotlib.colors.Normalize, optional

Normalization object. If None, will use CenteredNorm with auto-detected range. Common options:

  • colors.CenteredNorm(vcenter=0)  # Auto-detect range around 0

  • colors.Normalize(vmin=-1, vmax=1)  # Fixed range

  • colors.LogNorm()  # Logarithmic scale
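To show what a centered normalization with an auto-detected range does numerically, the sketch below reproduces the arithmetic in plain Python: the range is made symmetric around the center so that the center value maps to the middle of the colormap. `centered_normalize` is a hypothetical helper written for illustration; it is not matplotlib's or neurodent's code.

```python
def centered_normalize(values, vcenter=0.0, halfrange=None):
    """Map values to [0, 1] so that vcenter lands exactly on 0.5.

    If halfrange is None, it is detected from the data as the largest
    absolute deviation from vcenter.
    """
    if halfrange is None:
        halfrange = max(abs(v - vcenter) for v in values)
    if halfrange == 0:
        # All values equal the center; place them in the middle of the range.
        return [0.5 for _ in values]
    vmin = vcenter - halfrange
    vmax = vcenter + halfrange
    return [(v - vmin) / (vmax - vmin) for v in values]


scaled = centered_normalize([-1.0, 0.0, 3.0])
print(scaled)  # the center value 0.0 maps to 0.5
```

With a diverging colormap such as RdBu_r, this symmetry keeps zero at the neutral middle color even when the data are skewed to one side.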

plot_heatmap_faceted(feature: str, groupby: str | list[str], facet_vars: list[str] | str, df: DataFrame | None = None, **kwargs)

plot_diffheatmap(feature: str, groupby: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)

Create a 2D feature plot of differences between groups. With the default operation='subtract', the baseline is subtracted from the other groups; operation='divide' is also accepted.

Parameters

cmap : str, default="RdBu_r"

Colormap name or matplotlib colormap object.

norm : matplotlib.colors.Normalize, optional

Normalization object. If None, will use CenteredNorm with auto-detected range. Common options:

  • colors.CenteredNorm(vcenter=0)  # Auto-detect range around 0

  • colors.Normalize(vmin=-1, vmax=1)  # Fixed range

  • colors.LogNorm()  # Logarithmic scale
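The baseline comparison that plot_diffheatmap performs can be pictured with a small plain-Python sketch. Everything here is an assumption made for illustration: `apply_baseline` is a hypothetical helper, groups are lists of per-channel values, 'divide' is taken to mean an element-wise ratio to the baseline, and remove_baseline is taken to mean dropping the baseline group from the result.

```python
def apply_baseline(groups, baseline_key, operation="subtract", remove_baseline=False):
    """Compare every group with the baseline group, element by element."""
    baseline = groups[baseline_key]
    if operation == "subtract":
        def combine(value, base):
            return value - base
    elif operation == "divide":
        def combine(value, base):
            return value / base
    else:
        raise ValueError(f"unsupported operation: {operation!r}")

    result = {
        key: [combine(v, b) for v, b in zip(values, baseline)]
        for key, values in groups.items()
    }
    if remove_baseline:
        # Assumed meaning: omit the (trivial) baseline-vs-baseline entry.
        result.pop(baseline_key)
    return result


groups = {"WT": [2.0, 4.0], "KO": [3.0, 2.0]}
diff = apply_baseline(groups, "WT")
ratio = apply_baseline(groups, "WT", operation="divide", remove_baseline=True)
print(diff)
print(ratio)
```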

plot_diffheatmap_faceted(feature: str, groupby: str | list[str], facet_vars: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, cmap: str = 'RdBu_r', norm: Normalize | None = None, **kwargs)

plot_qqplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, log: bool = False, channels: str | list[str] = 'all', collapse_channels: bool = False, height: float = 3, aspect: float = 1, **kwargs)

Create a QQ plot of the feature data.
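For readers unfamiliar with QQ plots, the sketch below computes the points of a normal QQ plot in plain Python: sorted, standardized sample values paired with the matching quantiles of a standard normal distribution. It is a generic illustration, not the method's implementation. `qq_points` is a hypothetical helper, the plotting positions (i + 0.5) / n are one common choice among several, and the `log` flag is assumed to mean a natural-log transform applied before comparison.

```python
import math
from statistics import NormalDist, mean, stdev


def qq_points(sample, log=False):
    """Return (theoretical, observed) quantile pairs for a normal QQ plot."""
    data = [math.log(v) for v in sample] if log else list(sample)
    data.sort()
    n = len(data)
    mu, sigma = mean(data), stdev(data)
    observed = [(v - mu) / sigma for v in data]
    standard_normal = NormalDist()
    theoretical = [standard_normal.inv_cdf((i + 0.5) / n) for i in range(n)]
    return list(zip(theoretical, observed))


points = qq_points([1.0, 2.0, 3.0, 4.0, 5.0])
for theoretical, observed in points:
    print(f"{theoretical:+.3f} {observed:+.3f}")
```

Points that fall close to the diagonal suggest the sample is approximately normal; systematic curvature suggests skew or heavy tails.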