ExperimentPlotter
- class neurodent.visualization.ExperimentPlotter(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)
Bases: object
A class for creating various plots from multiple experimental datasets.
This class provides methods for creating different types of plots (box, violin, strip, and other categorical plots, as well as heatmaps and QQ plots) from experimental data with consistent data processing and styling.
- Plot Ordering:
The class automatically sorts data according to predefined plot orders for columns like ‘channel’, ‘genotype’, ‘sex’, ‘isday’, and ‘band’. Users can customize this ordering during initialization:
plotter = ExperimentPlotter(wars, plot_order={'channel': ['LMot', 'RMot', ...]})
The default plot orders are defined in constants.DF_SORT_ORDER.
- Validation and Warnings:
The class automatically validates the plot order against the processed DataFrame during plotting and issues warnings for any mismatches. Use validate_plot_order() to validate explicitly:
plotter.validate_plot_order(df)
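The ordering behavior described above can be pictured as a rank-based sort. The sketch below is not neurodent's implementation: it assumes a long-format pandas DataFrame and uses a hypothetical helper, `apply_plot_order`, that sorts rows by each column's position in a `plot_order` dict and places values missing from the order after all listed values.

```python
import pandas as pd


def apply_plot_order(df: pd.DataFrame, plot_order: dict) -> pd.DataFrame:
    """Hypothetical helper (not part of neurodent): sort rows by plot_order positions."""
    columns = [c for c in plot_order if c in df.columns]
    if not columns:
        return df.reset_index(drop=True)

    def rank(series: pd.Series) -> pd.Series:
        # sort_values passes each sort column as a Series; its name selects the order list
        order = plot_order[series.name]
        position = {value: i for i, value in enumerate(order)}
        # Values not listed in the order are placed after all listed values
        return series.map(lambda v: position.get(v, len(order)))

    return df.sort_values(by=columns, key=rank).reset_index(drop=True)


df = pd.DataFrame({
    "genotype": ["KO", "WT", "KO", "WT"],
    "sex": ["Male", "Male", "Female", "Female"],
    "value": [1, 2, 3, 4],
})
ordered = apply_plot_order(df, {"genotype": ["WT", "KO"], "sex": ["Female", "Male"]})
print(ordered["value"].tolist())  # → [4, 2, 3, 1]
```

Sorting on integer ranks rather than converting columns to pandas Categoricals keeps unlisted values intact instead of turning them into NaN.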
Example
Customize plot ordering during initialization:
custom_order = {
    'channel': ['LMot', 'RMot', 'LBar', 'RBar'],  # Only include specific channels
    'genotype': ['WT', 'KO'],  # Standard order
    'sex': ['Female', 'Male'],  # Custom order
}
plotter = ExperimentPlotter(wars, plot_order=custom_order)
- Parameters:
  - wars (WindowAnalysisResult | list[WindowAnalysisResult]) – A single WindowAnalysisResult or a list of WindowAnalysisResult objects.
  - features (list[str], optional) – List of features to extract. If None, defaults to ['all'].
  - exclude (list[str], optional) – List of features to exclude from extraction.
  - use_abbreviations (bool, optional) – Whether to use abbreviations for channel names.
  - plot_order (dict, optional) – Dictionary mapping column names to the order of values for plotting. If None, uses constants.DF_SORT_ORDER.
- Variables:
  - results (list[WindowAnalysisResult]) – List of analysis results being plotted.
  - channel_names (list[list[str]]) – List of channel names for each result.
  - channel_to_idx (list[dict]) – Mapping of channel names to indices for each result.
  - all_channel_names (list[str]) – Union of all channel names across results.
  - df_wars (list[pd.DataFrame]) – List of DataFrames for each result.
  - concat_df_wars (pd.DataFrame) – Concatenated DataFrame of all results.
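The documented defaults for features and exclude, and the relationship between channel_names, channel_to_idx, and all_channel_names, can be sketched in plain Python. Both helpers below are hypothetical and only illustrate the described behavior; that 'all' expands to every available feature, and that the channel union keeps first-seen order, are assumptions of this sketch.

```python
from __future__ import annotations


def resolve_features(features: list[str] | None, exclude: list[str] | None,
                     available: list[str]) -> list[str]:
    """Hypothetical helper: apply the documented features/exclude defaults."""
    requested = ["all"] if features is None else list(features)
    if "all" in requested:
        requested = list(available)  # assumption: 'all' expands to every available feature
    excluded = set(exclude or [])
    return [f for f in requested if f not in excluded]


def build_channel_maps(channel_names: list[list[str]]) -> tuple[list[dict[str, int]], list[str]]:
    """Hypothetical helper: derive channel_to_idx and all_channel_names from channel_names."""
    # One {channel name: index} mapping per result
    channel_to_idx = [{name: i for i, name in enumerate(names)} for names in channel_names]
    # Union across results, kept in first-seen order (the real ordering may differ)
    all_channel_names = list(dict.fromkeys(name for names in channel_names for name in names))
    return channel_to_idx, all_channel_names


print(resolve_features(None, ["feature_b"], ["feature_a", "feature_b", "feature_c"]))
channel_to_idx, all_channel_names = build_channel_maps([["LMot", "RMot"], ["RMot", "LBar"]])
print(channel_to_idx, all_channel_names)
```

The feature names here are placeholders, not neurodent feature identifiers.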
- __init__(wars: WindowAnalysisResult | list[WindowAnalysisResult], features: list[str] | None = None, exclude: list[str] | None = None, use_abbreviations: bool = True, plot_order: dict | None = None)
- validate_plot_order(df: DataFrame, raise_errors: bool = False) → dict
Validate that the current plot_order contains all necessary categories for the data.
Parameters
- df : pd.DataFrame
DataFrame to validate against (should be the DataFrame that will be sorted).
- raise_errors : bool, optional
Whether to raise errors for validation issues. Default is False.
Returns
- dict
Dictionary with validation results for each column.
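The kind of check validate_plot_order performs can be sketched as a set comparison per column. The helper below is hypothetical: its name, the "missing"/"unused" result keys, and the choice to warn only about values absent from the order are assumptions, not the structure of the dictionary neurodent returns.

```python
import warnings

import pandas as pd


def check_plot_order(df: pd.DataFrame, plot_order: dict, raise_errors: bool = False) -> dict:
    """Hypothetical helper: compare each plot_order list with the values present in df."""
    results = {}
    for column, order in plot_order.items():
        if column not in df.columns:
            continue  # columns absent from the data are skipped in this sketch
        present = set(df[column].dropna().unique())
        # Values in the data that the order does not list (they cannot be placed)
        missing = sorted(present - set(order), key=str)
        # Values the order lists that never occur in the data (harmless but noisy)
        unused = [value for value in order if value not in present]
        results[column] = {"missing": missing, "unused": unused}
        if missing:
            message = f"plot_order[{column!r}] is missing values found in the data: {missing}"
            if raise_errors:
                raise ValueError(message)
            warnings.warn(message)
    return results


df = pd.DataFrame({"channel": ["LMot", "RMot", "X"]})
report = check_plot_order(df, {"channel": ["LMot", "RMot", "LBar"]})  # warns about "X"
print(report)
```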
- pull_timeseries_dataframe(feature: str, groupby: str | list[str], channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, strict_groupby: bool = False)
Process feature data for plotting.
Parameters
- feature : str
The feature to get.
- groupby : str or list[str]
The variable(s) to group by.
- channels : str or list[str], optional
The channels to get. If ‘all’, all channels are used.
- collapse_channels : bool, optional
Whether to average the channels into a single value.
- average_groupby : bool, optional
Whether to average the feature within each group defined by the groupby variable(s).
- strict_groupby : bool, optional
If True, raise an exception when groupby columns contain NaN values. If False (default), only issue a warning.
Returns
- df : pd.DataFrame
A DataFrame with the feature data.
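The interaction of collapse_channels, average_groupby, and strict_groupby can be sketched with pandas group-bys. This is a hypothetical stand-in, not neurodent's code: it assumes a long-format DataFrame with a "channel" column and one value column named after the feature, and the "animal" and "value" columns in the usage example are illustrative.

```python
from __future__ import annotations

import warnings

import pandas as pd


def summarize_feature(df: pd.DataFrame, feature: str, groupby: str | list[str],
                      collapse_channels: bool = False, average_groupby: bool = False,
                      strict_groupby: bool = False) -> pd.DataFrame:
    """Hypothetical helper: long-format stand-in for the documented flags."""
    groupby = [groupby] if isinstance(groupby, str) else list(groupby)
    # strict_groupby: NaN in a grouping column is an error instead of a warning
    if df[groupby].isna().any().any():
        message = f"NaN values found in groupby columns {groupby}"
        if strict_groupby:
            raise ValueError(message)
        warnings.warn(message)
    out = df
    if collapse_channels:
        # Average across channels, keeping every other identifying column
        keys = [c for c in out.columns if c not in ("channel", feature)]
        out = out.groupby(keys, as_index=False, dropna=False)[feature].mean()
    if average_groupby:
        # One mean per groupby level (and per channel, unless channels were collapsed)
        keys = groupby if collapse_channels else groupby + ["channel"]
        out = out.groupby(keys, as_index=False, dropna=False)[feature].mean()
    return out


df = pd.DataFrame({
    "animal": ["A1", "A1", "A2", "A2"],
    "genotype": ["WT", "WT", "KO", "KO"],
    "channel": ["LMot", "RMot", "LMot", "RMot"],
    "value": [1.0, 3.0, 2.0, 4.0],
})
print(summarize_feature(df, "value", "genotype", collapse_channels=True))
```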
- plot_catplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, x: str | None = None, col: str | None = None, hue: str | None = None, kind: Literal['box', 'boxen', 'violin', 'strip', 'swarm', 'bar', 'point'] = 'box', catplot_params: dict | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, title: str | None = None, cmap: str | None = None, stat_pairs: list[tuple[str, str]] | Literal['all', 'x', 'hue'] | None = None, stat_test: str = 'Mann-Whitney', norm_test: Literal[None, 'D-Agostino', 'log-D-Agostino', 'K-S'] | None = None) → FacetGrid
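Create a categorical plot of feature data. The plot type is set by kind (box, boxen, violin, strip, swarm, bar, or point) and defaults to a box plot.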
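stat_pairs accepts explicit pairs of group levels or one of the strings 'all', 'x', and 'hue', and each pair is compared with stat_test (Mann-Whitney by default). The exact meaning of the string options is not documented here; the sketch below assumes 'all' means every pairwise combination of the plotted levels and leaves 'x' and 'hue' unmodeled. `expand_stat_pairs` is a hypothetical helper, not part of neurodent.

```python
from __future__ import annotations

from itertools import combinations


def expand_stat_pairs(levels: list[str], stat_pairs) -> list[tuple[str, str]]:
    """Hypothetical helper: turn a stat_pairs argument into explicit comparison pairs."""
    if stat_pairs is None:
        return []  # no statistical annotations
    if stat_pairs == "all":
        # Assumption: 'all' means every pairwise comparison between levels
        return list(combinations(levels, 2))
    if isinstance(stat_pairs, str):
        # 'x' and 'hue' depend on the plot layout and are not modeled here
        raise NotImplementedError(f"stat_pairs={stat_pairs!r} is not covered by this sketch")
    known = set(levels)
    for a, b in stat_pairs:
        if a not in known or b not in known:
            raise ValueError(f"pair {(a, b)} refers to a level not in {levels}")
    return [tuple(pair) for pair in stat_pairs]


print(expand_stat_pairs(["WT", "Het", "KO"], "all"))
```

With k levels, 'all' yields k(k-1)/2 comparisons, so the number of tests (and any multiple-comparison burden) grows quickly with the number of groups.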
- plot_heatmap(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)
Create a heatmap of a 2D feature.
Parameters:
- cmap : str, default="RdBu_r"
Colormap name or matplotlib colormap object.
- norm : matplotlib.colors.Normalize, optional
Normalization object. If None, a CenteredNorm with an auto-detected range is used. Common options:
  - colors.CenteredNorm(vcenter=0)  # Auto-detect range around 0
  - colors.Normalize(vmin=-1, vmax=1)  # Fixed range
  - colors.LogNorm()  # Logarithmic scale
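A centered norm suits the default diverging RdBu_r colormap because the range is chosen symmetrically around vcenter, so vcenter always lands on the colormap midpoint and the sign of a value determines which side of the colormap it uses. The pure-Python sketch below reproduces that idea for illustration only; in practice you would pass matplotlib.colors.CenteredNorm(vcenter=0) as norm. `centered_normalize` is hypothetical, and its handling of an all-equal input is an assumption.

```python
from __future__ import annotations


def centered_normalize(values: list[float], vcenter: float = 0.0,
                       halfrange: float | None = None) -> list[float]:
    """Hypothetical helper: map values to [0, 1] with vcenter at 0.5."""
    if halfrange is None:
        # Auto-detected range: the largest distance from the center
        halfrange = max(abs(v - vcenter) for v in values)
    if halfrange == 0:
        return [0.5 for _ in values]  # degenerate case: everything sits at the center
    # vcenter - halfrange -> 0.0, vcenter -> 0.5, vcenter + halfrange -> 1.0
    return [(v - vcenter) / (2 * halfrange) + 0.5 for v in values]


print(centered_normalize([-2.0, 0.0, 1.0]))  # → [0.0, 0.5, 0.75]
```

Note that 1.0 maps to 0.75, not 1.0: the range is set by the largest absolute deviation (here 2.0), so a smaller positive value uses only part of the positive side of the colormap.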
- plot_heatmap_faceted(feature: str, groupby: str | list[str], facet_vars: list[str] | str, df: DataFrame | None = None, **kwargs)
- plot_diffheatmap(feature: str, groupby: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, col: str | None = None, row: str | None = None, channels: str | list[str] = 'all', collapse_channels: bool = False, average_groupby: bool = False, cmap: str = 'RdBu_r', norm: Normalize | None = None, height: float = 3, aspect: float = 1)
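Create a 2D feature plot of differences between groups relative to a baseline. With operation='subtract' (the default), the baseline is subtracted from the other groups; with operation='divide', the other groups are divided by the baseline.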
Parameters:
- cmap : str, default="RdBu_r"
Colormap name or matplotlib colormap object.
- norm : matplotlib.colors.Normalize, optional
Normalization object. If None, a CenteredNorm with an auto-detected range is used. Common options:
  - colors.CenteredNorm(vcenter=0)  # Auto-detect range around 0
  - colors.Normalize(vmin=-1, vmax=1)  # Fixed range
  - colors.LogNorm()  # Logarithmic scale
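The baseline comparison can be sketched with a pandas pivot. This is a hypothetical helper, not neurodent's implementation: it assumes a long-format DataFrame with one scalar value per row and channel, treats the plotted 2D array as a group-by-channel table of means, and uses group_col as a stand-in for baseline_groupby. The baseline_key, operation, and remove_baseline arguments mirror the documented parameters.

```python
import pandas as pd


def baseline_difference(df: pd.DataFrame, feature: str, group_col: str, baseline_key: str,
                        operation: str = "subtract", remove_baseline: bool = False) -> pd.DataFrame:
    """Hypothetical helper: compare each group's per-channel mean with a baseline group."""
    # Rows are groups, columns are channels: a 2D array a heatmap could show
    means = df.groupby([group_col, "channel"])[feature].mean().unstack("channel")
    baseline = means.loc[baseline_key]  # Series indexed by channel
    if operation == "subtract":
        result = means - baseline  # aligned on the channel columns
    elif operation == "divide":
        result = means / baseline
    else:
        raise ValueError(f"unknown operation {operation!r}")
    if remove_baseline:
        result = result.drop(index=baseline_key)  # the baseline row is all 0 (or all 1)
    return result


df = pd.DataFrame({
    "genotype": ["WT", "WT", "KO", "KO"],
    "channel": ["LMot", "RMot", "LMot", "RMot"],
    "value": [1.0, 2.0, 3.0, 6.0],
})
print(baseline_difference(df, "value", "genotype", "WT"))
```

After subtraction the baseline row carries no information (all zeros), which is why dropping it with remove_baseline can make the remaining contrasts easier to read; for 'divide', a centered norm with vcenter=1 would be the analogous choice.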
- plot_diffheatmap_faceted(feature: str, groupby: str | list[str], facet_vars: str | list[str], baseline_key: str | bool | tuple[str, ...], baseline_groupby: list[str] | str | None = None, operation: Literal['subtract', 'divide'] = 'subtract', remove_baseline: bool = False, df: DataFrame | None = None, cmap: str = 'RdBu_r', norm: Normalize | None = None, **kwargs)
- plot_qqplot(feature: str, groupby: str | list[str], df: DataFrame | None = None, col: str | None = None, row: str | None = None, log: bool = False, channels: str | list[str] = 'all', collapse_channels: bool = False, height: float = 3, aspect: float = 1, **kwargs)
Create a QQ plot of the feature data.
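A QQ plot pairs the sorted sample values with the quantiles a reference distribution predicts at the same ranks; points close to a straight line suggest the data are close to that distribution. The stdlib-only sketch below assumes a normal reference and that log=True log-transforms the feature values before comparison (a plausible reading of the log parameter, not confirmed by this page). `qq_points` is hypothetical, and the (i - 0.5) / n plotting positions are one common convention among several.

```python
from __future__ import annotations

import math
from statistics import NormalDist


def qq_points(values: list[float], log: bool = False) -> tuple[list[float], list[float]]:
    """Hypothetical helper: (theoretical normal quantiles, sorted sample values)."""
    # Assumption: log=True log-transforms the data before comparing with a normal
    data = sorted(math.log(v) for v in values) if log else sorted(values)
    n = len(data)
    standard_normal = NormalDist()  # mean 0, standard deviation 1
    # Plotting positions (i - 0.5) / n for i = 1..n
    theoretical = [standard_normal.inv_cdf((i - 0.5) / n) for i in range(1, n + 1)]
    return theoretical, data


theoretical, sample = qq_points([3.0, 1.0, 2.0])
print(theoretical, sample)
```

Plotting sample against theoretical gives the QQ plot; a log-normal feature would look curved with log=False and roughly straight with log=True.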