"""
Helper functions for Plotly figures.
"""
import numpy as np
import pandas as pd
import warnings
try:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
except:
warnings.warn("'plotly' not available.", RuntimeWarning)
from collections import defaultdict
from collections.abc import Iterable
from functools import partial
from .. import _helper
from .base import color_std, plotly_default_colors
def add_offset(x0, xf, offset=0.05):
    """
    Widen the axis range [x0, xf] by a fraction of its length on both sides.

    Attributes:
        - x0:     lower limit for the axis range.
        - xf:     upper limit for the axis range.
        - offset: fraction of (xf - x0) added beyond each limit.

    Returns: list [new lower limit, new upper limit].
    """
    span = xf - x0
    # -offset and 1+offset are positions in units of the range length
    # (0 == x0, 1 == xf); map them back to data coordinates.
    return [span * position + x0 for position in (-offset, 1 + offset)]
def get_common_range(fig, axes=["x", "y"], offset=[0.05, 0.05]):
    """
    Compute one padded [lower, upper] range per axis, common to all traces of the figure.

    For every trace in `fig.data` and every axis, the coordinate values
    (shifted by plus and minus the error-bar array when the trace defines one)
    contribute their minimum and maximum; positions where the coordinate is NaN
    are ignored. The global extremes are then widened with `add_offset`.

    Traces whose coordinates cannot be processed (missing, non-numeric,
    error array of a different length, only NaNs, ...) are skipped.

    Attributes:
        - fig:    figure whose `data` attribute iterates over the traces.
        - axes:   axis letters to process. Never mutated.
        - offset: relative padding for each axis, paired positionally with `axes`. Never mutated.

    Returns: dict {axis letter: [lower, upper]}.

    Raises: ValueError if an axis has no usable data in any trace.
    """
    extremes = defaultdict(list)
    for plot in fig.data:
        for ax in axes:
            error_key = f"error_{ax}"
            if hasattr(plot, error_key) and getattr(plot, error_key).array is not None:
                error = np.array([*plot[error_key]["array"]])
                additions = [error, -error]
            else:
                additions = [0]

            try:
                # Coordinates given as list/tuple do not support `+ addition`;
                # asarray leaves existing arrays untouched and converts the rest.
                values = np.asarray(plot[ax])
                valid = ~np.isnan(values)
            except Exception:
                # Best effort: this trace has no usable data for this axis.
                continue

            for addition in additions:
                try:
                    arr = (values + addition)[valid]
                except Exception:
                    # e.g. error array whose shape does not match the data.
                    continue
                if arr.size == 0:
                    # Only NaNs: nothing to contribute (min/max would raise).
                    continue
                extremes[f"{ax}-min"].append(arr.min())
                extremes[f"{ax}-max"].append(arr.max())

    ranges = {}
    for ax, off in zip(axes, offset):
        lows, highs = extremes.get(f"{ax}-min"), extremes.get(f"{ax}-max")
        if not lows:
            raise ValueError(f"No numeric data found for axis '{ax}' in any trace of the figure.")
        ranges[ax] = add_offset(min(lows), max(highs), offset=off)
    return ranges
def get_nplots(fig):
    """Return how many entries of `fig.layout` have a name containing "xaxis"."""
    xaxis_keys = [key for key in fig.layout if "xaxis" in key]
    return len(xaxis_keys)
def mod_delete_axes(fig, axes=["x", "y"]):
    """
    Build the layout entries that hide the axes of a figure.

    Attributes:
        - fig:  figure; `get_nplots(fig)` sets how many numbered axes are covered.
        - axes: axis letters to hide.

    Returns: dict with keys "{ax}axis", "{ax}axis1", ..., "{ax}axis{n}" for each
             axis letter, every one mapped to the same
             dict(visible=False, showgrid=False, zeroline=False).
    """
    hidden = dict(visible=False, showgrid=False, zeroline=False)
    layout = {}
    for ax in axes:
        # "" addresses the un-numbered axis key, 1..n the numbered ones.
        suffixes = [""] + list(range(1, get_nplots(fig) + 1))
        for suffix in suffixes:
            layout[f"{ax}axis{suffix}"] = hidden
    return layout
def get_mod_layout(key, val=None):
    """
    Create a function that builds layout entries "{ax}axis{i}_{key}" for all axes of a figure.

    Attributes:
        - key: axis property to set (suffix of every generated layout key).
        - val: if None, the returned function has signature (fig, val, axes) and
               the value is supplied at call time. Otherwise the returned function
               has signature (fig, axes) and always uses this value.

    In both cases a non-string iterable value is paired positionally with `axes`
    (one item per axis letter); any other value is applied to every axis.
    The suffix i runs over "" and 1..get_nplots(fig).
    """
    def mod_layout(fig, val, axes=["x","y"]):
        if isinstance(val, Iterable) and not isinstance(val, str):
            # One value per axis letter.
            assignments = zip(axes, val)
        else:
            # Same value for every axis letter.
            assignments = ((ax, val) for ax in axes)

        layout = {}
        for ax, v in assignments:
            for i in [""] + [*range(1, get_nplots(fig) + 1)]:
                layout["{}axis{}_{}".format(ax, i, key)] = v
        return layout

    if val is not None:
        def mod_layout_fixed_val(fig, axes=["x", "y"]):
            return mod_layout(fig, val, axes)
        return mod_layout_fixed_val

    return mod_layout
# `_helper.sequence_or_stream` with the dash styles pre-bound as first argument.
mod_dashes = partial(_helper.sequence_or_stream, ["solid", "dash", "dot"])

# Layout builders taking the value at call time: f(fig, val, axes).
mod_ticksize = get_mod_layout("tickfont_size")
mod_range = get_mod_layout("range")

# Layout builders with a fixed value: f(fig, axes).
mod_logaxes = get_mod_layout("type", "log")
mod_expfmt = get_mod_layout("exponentformat", "power")


def mod_logaxes_expfmt(fig, axes=["x", "y"]):
    """Merge the layout entries of `mod_logaxes` and `mod_expfmt` for the given axes."""
    layout = dict(mod_logaxes(fig, axes=axes))
    layout.update(mod_expfmt(fig, axes=axes))
    return layout
def mod_common_range(fig, axes=["x", "y"], **kwargs):
    """
    Build the layout entries that give every subplot the same padded range per axis.

    Attributes:
        - fig:    figure whose traces determine the ranges.
        - axes:   axis letters to process.
        - kwargs: forwarded to `get_common_range` (e.g. `offset`).

    Returns: dict of "{ax}axis{i}_range" entries, as produced by `mod_range`.
    """
    ranges = get_common_range(fig, axes=axes, **kwargs)
    # `mod_range` pairs `axes` positionally with the items of `val`. Passing the
    # dict itself would iterate over its keys ("x", "y") instead of the
    # [lower, upper] ranges, so hand over the values (same order as `axes`).
    return mod_range(fig, val=list(ranges.values()), axes=axes)
def get_subplots(cols, rows=1, horizontal_spacing=0.03, vertical_spacing=0.03, height=None, width=2500, ticksize=32, font_size=40, font_family="sans-serif",
                 hovermode=False, delete_axes=False, shared_xaxes=True, shared_yaxes=True, layout_kwargs={},
                 **make_subplots_kwargs):
    """
    Create a grid of subplots with the module's default sizes and fonts.

    Attributes:
        - cols, rows:                              grid dimensions.
        - horizontal_spacing, vertical_spacing:    passed to `make_subplots`.
        - height:                                  figure height. Defaults to 800*rows.
        - width:                                   figure width.
        - ticksize:                                tick font size for all axes.
        - font_size, font_family:                  font for the annotations; font_size is also the legend font size.
        - hovermode:                               passed to the layout.
        - delete_axes:                             if True, hide all axes, remove the margins and make the
                                                   paper and plot backgrounds transparent.
        - shared_xaxes, shared_yaxes:              passed to `make_subplots`.
        - layout_kwargs:                           extra layout entries. Never mutated.
        - make_subplots_kwargs:                    extra arguments for `make_subplots`.

    Returns: the figure.
    """
    if height is None:
        height = 800*rows

    base_layout = dict(margin=dict(l=100, r=20, b=80, t=60, pad=1), height=height, width=width)
    fig = make_subplots(
        figure=go.Figure(layout=base_layout),
        rows=rows, cols=cols,
        shared_xaxes=shared_xaxes, shared_yaxes=shared_yaxes,
        horizontal_spacing=horizontal_spacing, vertical_spacing=vertical_spacing,
        **make_subplots_kwargs
    )

    def set_annotation_font(annotation):
        return annotation.update(font={'size':font_size, "family":font_family})

    fig.for_each_annotation(set_annotation_font)

    tick_layout = mod_ticksize(fig, val=ticksize)
    fig.update_layout(**tick_layout, legend_font_size=font_size, hovermode=hovermode, **layout_kwargs)

    if delete_axes:
        hidden_axes = mod_delete_axes(fig)
        fig.update_layout(**hidden_axes, margin=dict(l=0, t=0, b=0, r=0), paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)")
    return fig
def transparent_colorscale(fig, threshold=1e-10):
    """
    Build a colorscale where values below threshold are invisible.

    The colorscale is read from fig.layout["coloraxis"]["colorscale"]; the figure
    is not modified.

    Returns: tuple of stops: a fully transparent stop at 0, the original first
             color moved to `threshold`, then the remaining original stops.
    """
    stops = fig.layout["coloraxis"]["colorscale"]
    transparent_stop = (0, 'rgba(0,0,0,0)')
    first_color = stops[0][1]
    shifted_first_stop = (threshold, first_color)
    return (transparent_stop, shifted_first_stop) + tuple(stops[1:])
def multiindex_to_label(i, depth=2):
    """
    Split a multi-level index into one list of labels per level.

    Attributes:
        - i:     index exposing `get_level_values`.
        - depth: number of levels to extract, starting from level 0.

    Returns: list with `depth` lists, the k-th one holding the values of level k.
    """
    labels = []
    for level in range(depth):
        labels.append(i.get_level_values(level).to_list())
    return labels
def set_multicategory_from_df(fig, df):
    """
    Turn both axes of the figure into multicategory axes labelled from a DataFrame.

    The x (y) data of the first trace is replaced, in place, by the first two
    levels of df.columns (df.index).

    Attributes:
        - fig: figure to modify.
        - df:  DataFrame whose columns and index expose at least two levels.

    Returns: None.
    """
    fig.update_layout(xaxis_type="multicategory", yaxis_type="multicategory")
    trace = fig.data[0]
    for coord, index in (("x", df.columns), ("y", df.index)):
        trace[coord] = multiindex_to_label(index)
    return None
def CI_plot(x, y, CI, label=None, width=0.05, ms=10, color='rgba(255, 127, 14, 0.3)', fig=None, x_title=None, y_title=None):
    """
    Box plot where the box corresponds to the CI.

    Attributes:
        - x:                x coordinate for each CI.
        - y:                value of the magnitude for the sample. Example: the mean if CI is a CI for the mean.
        - CI:               confidence interval (lower, upper) for each y.
        - label:            legend entry. If given, the axis ranges are also fixed (see below).
        - width:            half-width of each box, in units of the box position (0, 1, 2, ...).
        - ms:               size of the markers at the CI limits.
        - color:            fill color of boxes and markers.
        - fig:              figure to draw on. If None, a new one is created with `get_figure`.
        - x_title, y_title: axis titles, only used when a new figure is created.

    Returns: the figure.
    """
    if fig is None:
        fig = get_figure(xaxis_title=x_title, yaxis_title=y_title)

    for position, (interval, x_val, stat) in enumerate(zip(CI, x, y)):
        left, right = position - width, position + width

        # Markers at both CI limits (upper limit first, pointing down).
        limit_markers = go.Scatter(
            x=[x_val]*2, y=interval[::-1], showlegend=False, mode="markers",
            marker=dict(color=color, symbol=["arrow-bar-down", "arrow-bar-up"], size=ms, line=dict(color="gray", width=2))
        )
        fig.add_trace(limit_markers)

        # Box spanning the CI, and a line at the sample value.
        fig.add_shape(type="rect", xref="x", yref="y", line=dict(color="gray",width=3), fillcolor=color,
                      x0=left, y0=interval[0], x1=right, y1=interval[1])
        fig.add_shape(type="line", xref="x", yref="y", line=dict(color="gray", width=4),
                      x0=left, y0=stat, x1=right, y1=stat)

    if label is None:
        return fig

    # The y range is computed before adding the legend-only trace at (1000, 1000).
    yrange = [*get_common_range(fig, axes=["y"]).values()][0]
    legend_entry = go.Scatter(x=[1000], y=[1000], mode="markers", name=label, showlegend=True,
                              marker=dict(symbol="square", color=color, size=22), line=dict(color="gray", width=2))
    fig.add_trace(legend_entry)
    xrange = [-0.25, len(x)-0.75]
    fig.update_layout(**mod_range(fig, (xrange, yrange)))
    return fig
def permtest_plot(df, H1="", colorscale="Inferno", log=True, height=800, width=1000, font_size=40, bar_len=0.9, bar_x=0.95, bar_thickness=100):
    """
    Heatmap of P-values.

    H1 should not contain latex code. Use unicode and HTML for super/sub-indices.

    Attributes:
        - df:                            P-values to display.
        - H1:                            alternative hypothesis, shown in the colorbar title.
        - colorscale:                    continuous color scale.
        - log:                           if True, plot log10(df) with the color range fixed to [log10(0.05), 0].
                                         Otherwise plot df with no explicit color range.
        - height, width, font_size:      figure layout.
        - bar_len, bar_x, bar_thickness: length, x position and thickness of the colorbar.

    Returns: the figure.
    """
    zmin = zmax = None
    legtitle = "P-value"
    if log:
        df = np.log10(df)
        zmin, zmax = np.log10(0.05), 0
        legtitle = "log<sub>10</sub>P-value"

    fig = px.imshow(df, zmin=zmin, zmax=zmax, color_continuous_scale=colorscale)

    colorbar = dict(len=bar_len, x=bar_x, title=f"{legtitle}<br>H<sub>1</sub>: {H1}", thickness=bar_thickness)
    fig.update_layout(
        coloraxis_colorbar=colorbar,
        height=height, width=width, font_size=font_size, hovermode=False,
        margin=dict(l=0, b=0, t=0, r=0)
    )
    return fig
def violin(df, CI=None, CI_line="mean", **CI_kwargs):
    """
    Violin plot including optionally the CI.

    Attributes:
        - df:        melted DataFrame. Contains only two columns: variable name (x) and value (y).
                     The column names set the OX and OY labels.
        - CI:        confidence interval for each variable, in order of first appearance
                     of the variable in df. If None, no CI is drawn.
        - CI_line:   name of the groupby method providing the value drawn inside each CI box.
        - CI_kwargs: forwarded to `CI_plot`.

    Returns: the figure.
    """
    x, y = df.columns
    fig = get_figure(xaxis_title=x, yaxis_title=y)
    fig.add_trace(go.Violin(x=df[x], y=df[y], showlegend=False))
    if CI is not None:
        groups = df[x].unique()
        # sort=False keeps the groups in order of first appearance, i.e. the
        # order returned by `unique()`. The default (sorted keys) pairs each
        # statistic with the wrong variable whenever df is not sorted by x.
        stat = getattr(df.groupby(x, sort=False), CI_line)().values.squeeze()
        # With a single variable `squeeze` yields a 0-d array, which cannot be
        # iterated by CI_plot.
        stat = np.atleast_1d(stat)
        fig = CI_plot(groups, stat, CI, fig=fig, **CI_kwargs)
    return fig