"""
Helper funcs for plotly figures
"""
import numpy as np
import pandas as pd
import warnings
try:
    import plotly.express as px
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
except:
    warnings.warn("'plotly' not available.", RuntimeWarning)
from collections import defaultdict
from collections.abc import Iterable
from functools import partial
from .. import _helper
from .base import color_std, plotly_default_colors
def add_offset(x0, xf, offset=0.05):
    """Pad the axis interval ``[x0, xf]`` by ``offset`` times its span on each side.

    x0 (xf) is the lower (upper) limit of the axis range. Returns ``[low, high]`` with
    ``low = x0 - offset * (xf - x0)`` and ``high = xf + offset * (xf - x0)``.
    """
    span = xf - x0
    low = span * -offset + x0
    high = span * (1 + offset) + x0
    return [low, high]
    
def _axis_extents(trace, ax):
    """Return the ``(min, max)`` pairs covered by ``trace`` along axis ``ax``, error bars included.

    Returns an empty list when the trace has no usable numeric data for that axis
    (property missing, non-numeric values such as categories, mismatched error array, ...).
    """
    try:
        values = np.asarray(trace[ax])
        if values.dtype == object:
            # Tuples containing ``None`` gaps come through as object arrays; None -> NaN.
            values = values.astype(float)
        error = getattr(trace, f"error_{ax}", None)
        if error is not None and error.array is not None:
            err = np.asarray(error.array, dtype=float)
            candidates = [values + err, values - err]
        else:
            candidates = [values]
        extents = []
        for arr in candidates:
            kept = arr[np.isfinite(arr)]  # NaN/inf must not leak into the axis range
            if kept.size:
                extents.append((kept.min(), kept.max()))
        return extents
    except (KeyError, ValueError, TypeError):
        # Deliberate best-effort: traces without numeric data on this axis are ignored.
        return []


def get_common_range(fig, axes=("x", "y"), offset=(0.05, 0.05)):
    """Return the data range shared by all traces of ``fig``, padded on both sides.

    Parameters:
        - fig:    figure whose ``fig.data`` traces are scanned. Trace data may be stored as
                  numpy arrays or tuples (Plotly converts lists to tuples); ``None`` gaps,
                  NaN and inf are ignored. Symmetric ``error_<ax>.array`` bars are included.
        - axes:   axis letters to compute, e.g. ``("x", "y")`` or ``["y"]``.
        - offset: relative padding per axis (fraction of the data span, see ``add_offset``),
                  zipped with ``axes``; a single number is applied to every axis.

    Returns:
        dict mapping each axis letter to ``[low, high]``.

    Raises:
        ValueError: if an axis has no finite numeric data in any trace.
    """
    axes = list(axes)
    if not isinstance(offset, Iterable):
        offset = [offset] * len(axes)
    lows, highs = defaultdict(list), defaultdict(list)
    for trace in fig.data:
        for ax in axes:
            for low, high in _axis_extents(trace, ax):
                lows[ax].append(low)
                highs[ax].append(high)
    ranges = {}
    for ax, off in zip(axes, offset):
        if not lows[ax]:
            raise ValueError(f"No finite numeric data found for axis {ax!r}.")
        ranges[ax] = add_offset(min(lows[ax]), max(highs[ax]), offset=off)
    return ranges
def get_nplots(fig):
    """Return the number of x-axes declared in ``fig.layout``.

    Counts the layout property names containing ``"xaxis"`` (``xaxis``, ``xaxis2``, ...);
    the layout helpers in this module use this count as the number of subplots.
    """
    xaxis_names = [name for name in fig.layout if "xaxis" in name]
    return len(xaxis_names)
def mod_delete_axes(fig, axes=["x", "y"]):
    """Return ``update_layout`` kwargs hiding every axis of ``fig``.

    For each letter in ``axes`` and each suffix ``"", 1, ..., get_nplots(fig)`` the key
    ``"<ax>axis<suffix>"`` maps to ``visible=False, showgrid=False, zeroline=False``.
    All keys share the same settings dict.
    """
    hidden = dict(visible=False, showgrid=False, zeroline=False)
    suffixes = ["", *range(1, get_nplots(fig) + 1)]
    layout_updates = {}
    for ax in axes:
        for suffix in suffixes:
            layout_updates[f"{ax}axis{suffix}"] = hidden
    return layout_updates
def get_mod_layout(key, val=None):
    """Build a layout modifier for the axis property ``key`` (e.g. ``"range"``, ``"type"``).

    The returned function takes a figure and returns ``update_layout`` kwargs named
    ``"<ax>axis<i>_<key>"`` for every axis letter in ``axes`` and every suffix ``i`` in
    ``"", 1, ..., get_nplots(fig)``.

    ``val`` may be:
        - a scalar or a string: applied to every axis;
        - a mapping ``{axis: value}`` (e.g. the output of ``get_common_range``): looked up
          per axis; axes missing from the mapping are left out;
        - any other iterable: zipped with ``axes`` in order (e.g. ``(x_range, y_range)``).

    If ``val`` is None the modifier has the signature ``mod(fig, val, axes=("x", "y"))``;
    otherwise ``val`` is fixed and the signature is ``mod(fig, axes=("x", "y"))``.
    """
    from collections.abc import Mapping

    def mod_layout(fig, val, axes=("x", "y")):
        suffixes = ["", *range(1, get_nplots(fig) + 1)]
        if isinstance(val, Mapping):
            # Zipping a dict would pair the axes with its *keys*; look values up by axis instead.
            pairs = [(ax, val[ax]) for ax in axes if ax in val]
        elif isinstance(val, Iterable) and not isinstance(val, str):
            pairs = list(zip(axes, val))
        else:
            pairs = [(ax, val) for ax in axes]
        return {f"{ax}axis{i}_{key}": v for ax, v in pairs for i in suffixes}

    if val is None:
        return mod_layout

    def mod_layout_fixed_val(fig, axes=("x", "y")):
        return mod_layout(fig, val, axes)

    return mod_layout_fixed_val
# Dash styles handed to the project helper ``_helper.sequence_or_stream``.
mod_dashes = partial(_helper.sequence_or_stream, ["solid", "dash", "dot"])
# Layout modifiers built with ``get_mod_layout``; each returns ``update_layout`` kwargs.
mod_ticksize = get_mod_layout("tickfont_size")              # call as mod_ticksize(fig, val)
mod_logaxes = get_mod_layout("type", "log")                 # log-scale axes
mod_expfmt = get_mod_layout("exponentformat", "power")      # exponents shown as powers
mod_range = get_mod_layout("range")                         # call as mod_range(fig, val)


def mod_logaxes_expfmt(fig, axes=["x", "y"]):
    """Return the combined kwargs of ``mod_logaxes`` and ``mod_expfmt`` for ``axes``."""
    layout_updates = mod_logaxes(fig, axes=axes)
    layout_updates.update(mod_expfmt(fig, axes=axes))
    return layout_updates
def mod_common_range(fig, axes=("x", "y"), **kwargs):
    """Return ``update_layout`` kwargs giving every subplot the common data range of ``fig``.

    ``kwargs`` (e.g. ``offset``) are forwarded to ``get_common_range``.
    """
    axes = list(axes)
    ranges = get_common_range(fig, axes=axes, **kwargs)
    # ``ranges`` is a dict ``{axis: [low, high]}``. Handing the dict itself to ``mod_range``
    # zips the axes with its *keys* (producing e.g. ``xaxis_range="x"``), so pass a list
    # of ranges aligned with the axes instead.
    ranged_axes = [ax for ax in axes if ax in ranges]
    return mod_range(fig, val=[ranges[ax] for ax in ranged_axes], axes=ranged_axes)
def get_subplots(cols, rows=1, horizontal_spacing=0.03, vertical_spacing=0.03, height=None, width=2500, ticksize=32, font_size=40, font_family="sans-serif",
                 hovermode=False, delete_axes=False, shared_xaxes=True, shared_yaxes=True, layout_kwargs=None,
                 **make_subplots_kwargs):
    """Create a ``rows`` x ``cols`` subplot grid styled with large fonts and tight margins.

    Parameters:
        - cols, rows:           grid shape passed to ``make_subplots``.
        - horizontal_spacing, vertical_spacing, shared_xaxes, shared_yaxes:
                                forwarded to ``make_subplots``.
        - height:               figure height; defaults to ``800 * rows``.
        - width:                figure width.
        - ticksize:             tick-label font size for every x/y axis.
        - font_size:            font size of the legend and of the annotations present after
                                ``make_subplots`` (e.g. subplot titles).
        - font_family:          font family of those annotations.
        - hovermode:            layout ``hovermode`` (False by default).
        - delete_axes:          hide all axes, zero the margins and make the backgrounds transparent.
        - layout_kwargs:        extra ``update_layout`` kwargs; they override the defaults set here.
        - make_subplots_kwargs: any other ``make_subplots`` argument.

    Returns:
        The new figure.
    """
    height = 800 * rows if height is None else height
    base_layout = dict(margin=dict(l=100, r=20, b=80, t=60, pad=1), height=height, width=width)
    fig = make_subplots(figure=go.Figure(layout=base_layout),
                        shared_yaxes=shared_yaxes, shared_xaxes=shared_xaxes,
                        horizontal_spacing=horizontal_spacing, vertical_spacing=vertical_spacing,
                        rows=rows, cols=cols, **make_subplots_kwargs)
    fig.for_each_annotation(lambda a: a.update(font={'size': font_size, "family": font_family}))
    # Merge into a single dict so keys given in ``layout_kwargs`` (e.g. ``hovermode``) override
    # the defaults instead of raising "got multiple values for keyword argument".
    layout_updates = {**mod_ticksize(fig, val=ticksize), "legend_font_size": font_size, "hovermode": hovermode}
    layout_updates.update(layout_kwargs or {})
    fig.update_layout(**layout_updates)
    if delete_axes:
        fig.update_layout(**mod_delete_axes(fig), margin=dict(l=0, t=0, b=0, r=0),
                          paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)")
    return fig
def transparent_colorscale(fig, threshold=1e-10):
    """Values below threshold are invisible.

    Builds a new colorscale from ``fig.layout.coloraxis.colorscale``:
        - a fully transparent stop at position 0;
        - then the original first color, moved to position ``threshold``;
        - then the remaining original stops.

    ``threshold`` is a position on the colorscale (same scale as the stop positions), not a data
    value. ``fig`` is not modified; the new colorscale is returned as a tuple.
    """
    stops = fig.layout["coloraxis"]["colorscale"]
    first_color = stops[0][1]
    prefix = ((0, 'rgba(0,0,0,0)'), (threshold, first_color))
    return prefix + tuple(stops[1:])
def multiindex_to_label(i, depth=2):
    """Split a pandas ``MultiIndex`` into one label list per level.

    Returns ``[labels of level 0, ..., labels of level depth-1]``, the nested-list form
    used for Plotly ``multicategory`` axes by ``set_multicategory_from_df``.
    """
    labels = []
    for level in range(depth):
        labels.append(i.get_level_values(level).to_list())
    return labels
def set_multicategory_from_df(fig, df):
    """Turn both axes of ``fig`` into multicategory axes labelled from ``df``.

    The first trace gets ``x`` from the levels of ``df.columns`` and ``y`` from the levels of
    ``df.index`` (first two levels, via ``multiindex_to_label``). ``fig`` is modified in
    place; nothing is returned.
    """
    fig.update_layout(xaxis_type="multicategory", yaxis_type="multicategory")
    column_labels = multiindex_to_label(df.columns)
    index_labels = multiindex_to_label(df.index)
    trace = fig.data[0]
    trace["x"] = column_labels
    trace["y"] = index_labels
    return None
def CI_plot(x, y, CI, label=None, width=0.05, ms=10, color='rgba(255, 127, 14, 0.3)', fig=None, x_title=None, y_title=None):
    """
    Box plot where the box corresponds to the CI.

    For the k-th group this draws a filled rectangle spanning the CI, a horizontal line at the
    statistic ``y[k]`` and arrow markers at both CI ends.

    Attributes:
        - x:        x coordinate for the CI. Used for the end-point markers; the rectangle and the
                    statistic line are placed at the group position ``k = 0, 1, ...`` instead, so
                    the two line up presumably only on a categorical x axis (category k at k).
        - y:        value of the magnitude for the sample. Example: the mean if CI is a CI for the mean.
        - CI:       Confidence interval for y: one ``(low, high)`` pair per x value.
        - label:    if given, adds a legend entry with this name and fixes the x and y axis ranges.
        - width:    half-width of each rectangle, in x-axis units.
        - ms:       size of the end-point markers.
        - color:    fill color of the rectangles, the markers and the legend entry.
        - fig:      figure to draw on; if None a new one is built with ``get_figure``.
        - x_title, y_title: axis titles, only used when a new figure is built.

    Returns:
        The figure with the CIs added.
    """
    if fig is None:
        # NOTE(review): ``get_figure`` is neither defined nor imported in the visible part of
        # this module — confirm where it is provided.
        fig = get_figure(xaxis_title=x_title, yaxis_title=y_title)
    x = list(x)  # materialise so ``len(x)`` below also works for one-shot iterables
    for i, (ci, x_val, ci_stat) in enumerate(zip(CI, x, y)):
        fig.add_trace(go.Scatter(x=[x_val]*2, y=ci[::-1], showlegend=False, mode="markers",
                                 marker=dict(color=color, symbol=["arrow-bar-down", "arrow-bar-up"], size=ms, line=dict(color="gray", width=2))
                                ))
        fig.add_shape(type="rect", xref="x", yref="y", line=dict(color="gray", width=3), fillcolor=color, x0=i-width, y0=ci[0], x1=i+width, y1=ci[1])
        fig.add_shape(type="line", xref="x", yref="y", line=dict(color="gray", width=4), x0=i-width, y0=ci_stat, x1=i+width, y1=ci_stat)
    if label is not None:
        # Computed before the dummy legend trace is added, so it only reflects the real data.
        yrange = [*get_common_range(fig, axes=["y"]).values()][0]
        # Off-range dummy trace that only provides the legend entry. The gray outline belongs in
        # ``marker.line`` (as for the CI markers): a trace-level ``line`` does nothing with
        # ``mode="markers"``.
        fig.add_trace(go.Scatter(x=[1000], y=[1000], mode="markers", name=label, showlegend=True,
                                 marker=dict(symbol="square", color=color, size=22, line=dict(color="gray", width=2))))
        fig.update_layout(**mod_range(fig, ([-0.25, len(x)-0.75], yrange)))
    return fig
def permtest_plot(df, H1="", colorscale="Inferno", log=True, height=800, width=1000, font_size=40, bar_len=0.9, bar_x=0.95, bar_thickness=100):
    """Heatmap (``px.imshow``) of the P-values in ``df``.

    H1 should not contain latex code. Use unicode and HTML for super/sub-indices.

    With ``log=True`` the values are shown as log10(P) and the color range is fixed to
    ``[log10(0.05), 0]``; otherwise raw P-values are shown with an automatic color range.
    The colorbar title shows the scale and the alternative hypothesis ``H1``.
    """
    if not log:
        zmin = zmax = None
        legtitle = "P-value"
    else:
        df = np.log10(df)
        zmin, zmax = np.log10(0.05), 0
        legtitle = "log<sub>10</sub>P-value"
    colorbar = dict(len=bar_len, x=bar_x, title=f"{legtitle}<br>H<sub>1</sub>: {H1}", thickness=bar_thickness)
    fig = px.imshow(df, zmin=zmin, zmax=zmax, color_continuous_scale=colorscale)
    fig.update_layout(coloraxis_colorbar=colorbar,
                      height=height, width=width, font_size=font_size, hovermode=False,
                      margin=dict(l=0, b=0, t=0, r=0))
    return fig
    
def violin(df, CI=None, CI_line="mean", **CI_kwargs):
    """
    Violin plot including optionally the CI.

    Attributes:
        - df:        melted DataFrame. Contains only two columns: variable name (x) and value (y).
                     The column names set the OX and OY labels.
        - CI:        optional confidence intervals, one ``(low, high)`` pair per category, in
                     the order in which the categories first appear in ``df`` (``df[x].unique()``).
        - CI_line:   name of the groupby aggregation (e.g. "mean", "median") drawn as the
                     horizontal line inside each CI box.
        - CI_kwargs: forwarded to ``CI_plot``.

    Returns:
        The figure.
    """
    x, y = df.columns
    fig = get_figure(xaxis_title=x, yaxis_title=y)
    fig.add_trace(go.Violin(x=df[x], y=df[y], showlegend=False))
    if CI is not None:
        categories = df[x].unique()
        # ``groupby`` sorts its keys, whereas the categories (and CI_plot's boxes) follow the
        # order of first appearance: reindex so each statistic matches its category. Selecting
        # the ``y`` column keeps the result 1-D even when there is a single category.
        stats = getattr(df.groupby(x)[y], CI_line)().reindex(categories).to_numpy()
        fig = CI_plot(categories, stats, CI, fig=fig, **CI_kwargs)
    return fig