Source code for arviz_plots.backend.plotly.legend

"""Plotly legend generation."""
from functools import partial

import numpy as np
import xarray as xr
from arviz_base import xarray_sel_iter
from plotly.graph_objects import Bar, Scatter

from .core import expand_aesthetic_aliases


@expand_aesthetic_aliases
def dealiase_line_kwargs(**kwargs):
    """Rename arviz common-interface line properties to their plotly names.

    ``linewidth`` becomes ``width`` and ``linestyle`` becomes ``dash``; every
    other keyword is returned under its original name with its value untouched.
    Calls go through ``expand_aesthetic_aliases`` (imported from ``.core``),
    which presumably normalizes aliased keyword names first -- confirm there.
    """
    renamed = {}
    for name, setting in kwargs.items():
        if name == "linewidth":
            name = "width"
        elif name == "linestyle":
            name = "dash"
        renamed[name] = setting
    return renamed


def _data_match(first, second):
    """Return True if two plotly data arrays hold the same values.

    Numeric data is compared with ``np.allclose``. Data numpy cannot compare
    numerically (``None`` placeholders, strings, ...) or whose shapes cannot be
    broadcast falls back to exact element-wise comparison via ``np.array_equal``
    instead of raising.
    """
    try:
        return bool(np.allclose(first, second))
    except (TypeError, ValueError):
        return bool(np.array_equal(first, second))


def _trace_matcher(trace, target_viz):
    """See if a plotly trace matches a target trace `target_viz`.

    Two traces match when:

    * their ``mode`` and ``line`` attributes are equal; a missing attribute
      counts as ``"na"``, e.g. ``Bar`` traces have no ``mode``;
    * their ``marker`` and ``text`` are equal;
    * their ``x`` data have the same shape;
    * their ``x`` and ``y`` data hold the same values (see ``_data_match``).

    Trace data may be numpy arrays, tuples (plotly stores list input such as
    the ``x=[None]`` legend placeholders as tuples) or ``None``. All of these
    are handled without raising.
    """
    if getattr(trace, "mode", "na") != getattr(target_viz, "mode", "na"):
        return False
    if getattr(trace, "line", "na") != getattr(target_viz, "line", "na"):
        return False
    if trace.marker != target_viz.marker:
        return False
    # ``==`` on array-valued text returns an array whose truth value is ambiguous;
    # np.array_equal handles None, scalars and arrays alike.
    if not np.array_equal(trace.text, target_viz.text):
        return False
    # np.shape also works for tuples and None, which have no ``.shape`` attribute
    if np.shape(trace.x) != np.shape(target_viz.x):
        return False
    return _data_match(trace.y, target_viz.y) and _data_match(trace.x, target_viz.x)


# Sub-properties of a plotly scatter trace's ``line`` attribute. Legend entry
# kwargs with one of these names are passed as ``line_<name>`` so that plotly's
# underscore notation routes them into ``line``.
LINE_SUBKEYS = ["backoff", "backoffsrc", "color", "dash", "shape", "simplify", "smoothing", "width"]


def legend(
    plot_collection,
    kwarg_list,
    label_list,
    title=None,
    visual_type="line",
    visual_kwargs=None,
    legend_dim=None,
    update_visuals=True,
    **kwargs,
):
    """Generate a legend with plotly.

    Parameters
    ----------
    plot_collection : PlotCollection
        The PlotCollection for which a legend should be generated.
    kwarg_list : list of dict
        Style dictionaries, one per legend entry. Keys are passed through
        ``dealiase_line_kwargs``. Keys listed in ``LINE_SUBKEYS`` are forwarded
        as ``line_<key>``.
    label_list : list
        Labels, one per legend entry (converted with ``str``).
    title : str, optional
        Title of the legend. When None, a title given through
        ``legend_title_text`` or the legend dict is kept.
    visual_type : str, default "line"
        Type of visual used for legend entries. Only "line" is supported.
    visual_kwargs : dict, optional
        Additional kwargs passed to every legend entry trace
        (``figure.add_scatter``). Per-entry kwargs from `kwarg_list` take
        precedence. ``line_color`` defaults to ``"black"``.
    legend_dim : str or sequence of str, optional
        Dimension(s) the legend refers to. When `update_visuals` is True, a
        group of ``plot_collection.viz`` is only processed if its dataset has
        all of these dimensions (``"__variable__"`` is always accepted). A
        single string is one dimension; None applies no restriction.
    update_visuals : bool, default True
        If True, traces already in the figure that match a stored Scatter or
        Bar visual are updated with the ``legendgroup`` aesthetic kwargs from
        ``plot_collection.get_aes_kwargs``.
    **kwargs : dict
        Layout configuration, handled as follows:

        * a key equal to this legend's id (``legend``, ``legend2``, ...) may
          hold a dict of base settings for it;
        * ``legend_<prop>`` keys are moved into those settings as ``<prop>``;
        * all remaining keys go to ``figure.update_layout``.

    Returns
    -------
    The plotly layout object of the legend just configured
    (``figure.layout[<legend id>]``). The legend is added to the figure in place.

    Raises
    ------
    NotImplementedError
        If `visual_type` is not "line".
    """
    figure = plot_collection.get_viz("figure")

    # NOTE: Legend IDs in Plotly must be 'legend', 'legend2', 'legend3', etc.
    # assumes each previously created legend is one data variable of viz["legend"]
    if "legend" in plot_collection.viz.children:
        legend_number = len(plot_collection.viz["legend"].data_vars) + 1
        legend_id = f"legend{legend_number}"
    else:
        legend_number = 1
        legend_id = "legend"

    # Gather the settings for this legend. Start from a ``<legend_id>`` dict if
    # given, then move every ``legend_<prop>`` kwarg into it as ``<prop>``. Keys of
    # other legends (``legend2=...``, ``legend2_y=...``) stay for update_layout.
    prefix = "legend_"
    legend_kwargs = kwargs.pop(legend_id, {}).copy()
    for key in [key for key in kwargs if key.startswith(prefix)]:
        legend_kwargs[key[len(prefix) :]] = kwargs.pop(key)

    # The first three legends default to top, bottom and middle positions. Later
    # legends keep plotly's default instead of failing with a KeyError. Explicit
    # user settings are never overridden.
    default_y = {1: 1, 2: 0, 3: 0.5}.get(legend_number)
    if default_y is not None and "y" not in legend_kwargs:
        legend_kwargs["y"] = default_y
    if title is not None or "title_text" not in legend_kwargs:
        legend_kwargs["title_text"] = title
    kwargs[legend_id] = legend_kwargs

    if visual_kwargs is None:
        visual_kwargs = {}
    else:
        visual_kwargs = visual_kwargs.copy()
    if visual_type == "line":
        visual_fun = figure.add_scatter
        kwarg_list = [dealiase_line_kwargs(**kws) for kws in kwarg_list]
        mode = "lines"
        visual_kwargs.setdefault("line_color", "black")
    else:
        raise NotImplementedError("Only line type legends supported for now")

    if update_visuals:
        # Iterating a bare string would test its characters, not the dimension.
        if legend_dim is None:
            legend_dims = ()
        elif isinstance(legend_dim, str):
            legend_dims = (legend_dim,)
        else:
            legend_dims = tuple(legend_dim)
        for group, viz_data in plot_collection.viz.children.items():
            if group in {"plot", "row_index", "col_index"}:
                continue
            viz_ds = viz_data.dataset
            # skip groups not defined over every legend dimension
            if any((d not in viz_ds.dims) and (d != "__variable__") for d in legend_dims):
                continue
            for var_name, sel, _ in xarray_sel_iter(viz_ds, skip_dims={}):
                target_viz = viz_ds[var_name].sel(sel).item()
                if target_viz is None:
                    continue
                # NOTE(review): scanning of this group stops at the first stored
                # visual that is not a Scatter/Bar trace -- presumably such groups
                # hold only non-trace objects; confirm.
                if not isinstance(target_viz, (Scatter, Bar)):
                    break
                target_plot = plot_collection.get_target(var_name, sel)
                if isinstance(target_plot, xr.DataArray):
                    # ravel so that multi-dimensional target arrays yield plots, not rows
                    target_plots = np.ravel(target_plot.data)
                else:
                    target_plots = [target_plot]
                trace_matcher = partial(_trace_matcher, target_viz=target_viz)
                for element in target_plots:
                    element.update_traces(
                        selector=trace_matcher,
                        **plot_collection.get_aes_kwargs(["legendgroup"], var_name, sel),
                    )

    for entry_kws, label in zip(kwarg_list, label_list):
        # plotly accepts both ``line={key: value}`` and ``line_key=value``; the flat
        # form lets per-entry kwargs override or extend single line properties.
        entry_kws = {
            f"line_{key}" if key in LINE_SUBKEYS else key: value for key, value in entry_kws.items()
        }
        # legend-only trace: its single point is None, so it draws nothing
        visual_fun(
            x=[None],
            y=[None],
            name=str(label),
            mode=mode,
            showlegend=True,
            legend=legend_id,
            **{**visual_kwargs, **entry_kws},
        )

    figure.update_layout(showlegend=True, **kwargs)
    return figure.layout[legend_id]