Source code for arviz_plots.plots.pava_residual_plot

"""Posterior predictive check for residuals using PAV-adjusted calibration."""
import warnings
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.helper_stats import isotonic_fit

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, set_wrap_layout
from arviz_plots.visuals import (
    fill_between_y,
    labelled_title,
    labelled_x,
    labelled_y,
    line_x,
    scatter_xy,
)


[docs] def plot_ppc_pava_residuals( dt, x_var, data_type="binary", ci_prob=None, var_names=None, filter_vars=None, # pylint: disable=unused-argument group="posterior_predictive", coords=None, # pylint: disable=unused-argument sample_dims=None, plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "lines", "markers", "reference_line", "credible_interval", "xlabel", "ylabel", "title", ], Sequence[str], ] = None, visuals: Mapping[ Literal[ "lines", "markers", "reference_line", "credible_interval", "xlabel", "ylabel", "title", ], Mapping[str, Any] | bool, ] = None, **pc_kwargs, ): """PAV-adjusted calibration residual plot. Uses the pool adjacent violators (PAV) algorithm for isotonic regression and computes residuals as the difference between the calibrated event probabilities (CEP) and the predicted probabilities. A horizontal line at zero corresponds to perfect calibration. Details are discussed in [1]_ and [2]_. Parameters ---------- dt : DataTree Input data x_var : array-like, series, DataArray, or str Variable to use for x-axis. If a string is given, it should be the name of a variable in the `constant_data` group. data_type : str Defaults to "binary". Other options are "categorical" and "ordinal". If "categorical", the plot will show the "one-vs-others" calibration and generate one plot per category. If "ordinal", the plot will display cumulative conditional event probabilities and generate (number of categories - 1) plots. ci_prob : float, optional Probability for the credible interval. Defaults to ``rcParams["stats.ci_prob"]``. var_names : str or list of str, optional One or more variables to be plotted. Currently only one variable is supported. Prefix the variables by ~ when you want to exclude them from the plot. filter_vars : {None, "like", "regex"}, optional, default=None If None (default), interpret var_names as the real variables names. If "like", interpret var_names as substrings of the real variables names. If "regex", interpret var_names as regular expressions on the real variables names. group : str, optional The group from which to get the unique values. Defaults to "posterior_predictive". It could also be "prior_predictive". Notice that this plots always use the "observed_data" so use with extra care if you are using "prior_predictive". coords : dict, optional Coordinates to plot. CURRENTLY NOT IMPLEMENTED sample_dims : str or sequence of hashable, optional Dimensions to reduce unless mapped to an aesthetic. Defaults to ``rcParams["data.sample_dims"]`` plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh", "plotly"}, optional labeller : labeller, optional aes_by_visuals : mapping of {str : sequence of str}, optional Mapping of visuals to aesthetics that should use their mapping in `plot_collection` when plotted. Valid keys are the same as for `visuals`. visuals : mapping of {str : mapping or bool}, optional Valid keys are: * markers -> passed to :func:`~arviz_plots.visuals.scatter_xy` * reference_line -> passed to :func:`~arviz_plots.visuals.line_x` * credible_interval -> passed to :func:`~arviz_plots.visuals.fill_between_y` * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` markers defaults to True for residual plots. Pass False to disable markers. **pc_kwargs Passed to :class:`arviz_plots.PlotCollection.grid` Returns ------- PlotCollection Examples -------- Plot the PAVA residual plot for the zeros and non-zeros in a negative bimomial model of the roaches dataset. .. plot:: :context: close-figs >>> from arviz_plots import plot_ppc_pava_residuals, style >>> style.use("arviz-variat") >>> from arviz_base import load_arviz_data >>> dt = load_arviz_data('roaches_nb') >>> plot_ppc_pava_residuals(dt, >>> var_names="y_pos", >>> x_var="roach count") .. minigallery:: plot_ppc_pava_residuals References ---------- .. [1] Säilynoja et al. *Recommendations for visual predictive checks in Bayesian workflow*. (2025) arXiv preprint https://arxiv.org/abs/2503.01509 .. [2] Dimitriadis et al *Stable reliability diagrams for probabilistic classifiers*. PNAS, 118(8) (2021). https://doi.org/10.1073/pnas.2016191118 """ if ci_prob is None: ci_prob = rcParams["stats.ci_prob"] if sample_dims is None: sample_dims = rcParams["data.sample_dims"] if isinstance(sample_dims, str): sample_dims = [sample_dims] sample_dims = list(sample_dims) if visuals is None: visuals = {} else: visuals = visuals.copy() if backend is None: if plot_collection is None: backend = rcParams["plot.backend"] else: backend = plot_collection.backend if labeller is None: labeller = BaseLabeller() visuals.setdefault("markers", True) if group == "prior_predictive": warnings.warn( "\n`plot_ppc_pava_residuals` always use the `observed_data` group." "\nBe cautious when using it for prior predictive checks.", UserWarning, stacklevel=2, ) if isinstance(x_var, str): x_val_name = x_var x_var = dt.constant_data[x_var].values elif hasattr(x_var, "values"): x_val_name = x_var.name x_var = x_var.values else: x_val_name = "x" ds_residuals = isotonic_fit( dt, var_names, group, ci_prob, data_type, residuals=True, x_var=x_var ) plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") if plot_collection is None: pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs.setdefault("cols", ["__variable__"]) pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, ds_residuals) plot_collection = PlotCollection.wrap( ds_residuals, backend=backend, **pc_kwargs, ) if aes_by_visuals is None: aes_by_visuals = {} else: aes_by_visuals = aes_by_visuals.copy() ## reference line at zero reference_ls_kwargs = get_visual_kwargs(visuals, "reference_line") if reference_ls_kwargs is not False: _, _, reference_ls_ignore = filter_aes( plot_collection, aes_by_visuals, "reference_line", sample_dims ) reference_ls_kwargs.setdefault("color", "B2") reference_ls_kwargs.setdefault("linestyle", "C1") reference_ls_kwargs.setdefault("y", 0) plot_collection.map( line_x, "reference_line", data=ds_residuals.sel(plot_axis="x"), ignore_aes=reference_ls_ignore, **reference_ls_kwargs, ) ## credible interval ci_kwargs = get_visual_kwargs(visuals, "credible_interval") _, _, ci_ignore = filter_aes(plot_collection, aes_by_visuals, "credible_interval", sample_dims) if ci_kwargs is not False: ci_kwargs.setdefault("color", "C0") ci_kwargs.setdefault("alpha", 0.25) plot_collection.map( fill_between_y, "credible_interval", data=ds_residuals, x=ds_residuals.sel(plot_axis="x"), y_bottom=ds_residuals.sel(plot_axis="y_bottom"), y_top=ds_residuals.sel(plot_axis="y_top"), ignore_aes=ci_ignore, **ci_kwargs, ) ## markers residual_ms_kwargs = get_visual_kwargs(visuals, "markers") if residual_ms_kwargs is not False: _, _, residual_ms_ignore = filter_aes( plot_collection, aes_by_visuals, "markers", sample_dims ) residual_ms_kwargs.setdefault("color", "C0") residual_ms_kwargs.setdefault("marker", "C6") plot_collection.map( scatter_xy, "markers", data=ds_residuals, ignore_aes=residual_ms_ignore, **residual_ms_kwargs, ) # set xlabel _, xlabels_aes, xlabels_ignore = filter_aes( plot_collection, aes_by_visuals, "xlabel", sample_dims ) xlabel_kwargs = get_visual_kwargs(visuals, "xlabel") if xlabel_kwargs is not False: if "color" not in xlabels_aes: xlabel_kwargs.setdefault("color", "B1") xlabel_kwargs.setdefault("text", x_val_name) plot_collection.map( labelled_x, "xlabel", ignore_aes=xlabels_ignore, subset_info=True, **xlabel_kwargs, ) # set ylabel _, ylabels_aes, ylabels_ignore = filter_aes( plot_collection, aes_by_visuals, "ylabel", sample_dims ) ylabel_kwargs = get_visual_kwargs(visuals, "ylabel") if ylabel_kwargs is not False: if "color" not in ylabels_aes: ylabel_kwargs.setdefault("color", "B1") ylabel_kwargs.setdefault("text", "CEP residual") plot_collection.map( labelled_y, "ylabel", ignore_aes=ylabels_ignore, subset_info=True, **ylabel_kwargs, ) # title title_kwargs = get_visual_kwargs(visuals, "title") _, _, title_ignore = filter_aes(plot_collection, aes_by_visuals, "title", sample_dims) if title_kwargs is not False: plot_collection.map( labelled_title, "title", ignore_aes=title_ignore, subset_info=True, labeller=labeller, **title_kwargs, ) return plot_collection