from typing import List

import torch
from torch.distributions import *
import matplotlib.pyplot as plt
import seaborn as sns

from M_c import M_c
from M_x import M_x


def add_axes(ax, use_x_ticks=True, use_y_ticks=True, axis_range=None):
    """
    Replace the default frame of a plot with arrow-tipped x and y axes.

    All four spines are hidden. When ``axis_range`` is given, two black
    lines ending in an arrow marker are drawn instead: a horizontal one at
    y=0 from ``xmin`` to ``xmax - 0.05`` and a vertical one at x=0 from
    ``0`` to ``ymax - 0.05``; the view limits are then set to
    ``axis_range``. The x axis is labelled "x" at its right end.

    :param ax: the Axis object of matplotlib
    :param use_x_ticks: whether to use x ticks
    :param use_y_ticks: whether to use y ticks
    :param axis_range: [xmin, xmax, ymin, ymax]; when None, no arrows are
        drawn and the current view limits are left untouched
    :return:
    """
    # Anchor the tick-carrying spines at the origin, then hide every spine:
    # the visible axis lines are the arrows drawn below.
    for side in ("left", "bottom"):
        ax.spines[side].set_position("zero")
    for side in ("left", "right", "bottom", "top"):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")

    # The arrows need explicit bounds, so they are only drawn when a range
    # was supplied (unpacking None used to raise a TypeError here).
    if axis_range is not None:
        xmin, xmax, _ymin, ymax = axis_range
        arrow_style = dict(ls="solid", lw=1, color="black", markevery=[-1])
        # Stop slightly short of the limit so the arrow head is not clipped.
        ax.plot((xmin, xmax - 0.05), (0, 0), marker=">", **arrow_style)
        # The vertical arrow starts at the origin, not at ymin; ymin only
        # influences the view limits set below.
        ax.plot((0, 0), (0, ymax - 0.05), marker="^", **arrow_style)
        # Act on the axes that was passed in rather than on pyplot's
        # current axes, which may be a different subplot.
        ax.axis(axis_range)

    if not use_x_ticks:
        ax.set_xticks([])

    if not use_y_ticks:
        ax.set_yticks([])

    ax.set_xlabel("x", loc="right")


def plot_gaussian_mixture(
    ax,
    gaussian_mean: torch.Tensor,
    gaussian_std: torch.Tensor,
    components_weights: torch.Tensor,
    min_x: float,
    max_x: float,
    steps: int,
    color: str = "black",
    label: str = None,
):
    """
    Plot the density of a univariate Gaussian mixture on a regular grid.

    The density is evaluated without tracking gradients, so parameters that
    require grad (e.g. learned ones) can be passed directly. The evaluation
    grid is created on the device of ``gaussian_mean`` and the plotted
    values are moved to the CPU before being handed to matplotlib.

    :param ax: the Axis object of matplotlib
    :param gaussian_mean: tensor of shape (M,1)
    :param gaussian_std: tensor of shape (M,1)
    :param components_weights: tensor of shape (M)
    :param min_x: minimum x for which to plot
    :param max_x: maximum x for which to plot
    :param steps: number of points to plot between min_x and max_x
    :param color: color of line (matplotlib)
    :param label: label of the curve in the legend
    :return:
    """
    # Plotting never needs a graph; without this, tensors that require grad
    # cannot be converted to arrays by matplotlib.
    with torch.no_grad():
        X = torch.linspace(
            start=min_x, steps=steps, end=max_x, device=gaussian_mean.device
        ).unsqueeze(
            1
        )  # (N,1)

        mix = Categorical(components_weights)
        comp = Normal(gaussian_mean.squeeze(1), gaussian_std.squeeze(1))
        gmm = MixtureSameFamily(mix, comp)

        # log_prob broadcasts the (N,1) grid against the M components and
        # returns the mixture log-density with shape (N,1).
        Y = gmm.log_prob(X).exp()

    ax.plot(X.cpu(), Y.cpu(), c=color, label=label)


def plot_M_x(
    ax,
    epsilon: torch.double,
    gaussian_mean: torch.Tensor,
    gaussian_std: torch.Tensor,
    p_m_given_c: torch.Tensor,
    p_c,
    normalize: bool,
    min_x: float,
    max_x: float,
    steps: int,
    colors: List[str] = None,
    labels: List[str] = None,
):
    """
    Plot the (normalized) M_x(c) curves, one dashed line per class.

    M_x is evaluated on a regular grid of ``steps`` points between
    ``min_x`` and ``max_x``; column ``c`` of the result is drawn with
    ``colors[c]`` and ``labels[c]``.

    :param ax: the Axis object of matplotlib
    :param epsilon: the length of the interval
    :param gaussian_mean: tensor of shape (C, M,1)
    :param gaussian_std: tensor of shape (C, M,1)
    :param p_m_given_c: tensor of shape (C, M)
    :param p_c: tensor of shape (C)
    :param normalize: whether to normalize M_x or not
    :param min_x: minimum x for which to plot
    :param max_x: maximum x for which to plot
    :param steps: number of points to plot between min_x and max_x
    :param colors: list of colors, one per class (matplotlib); when None,
        every curve is drawn in black
    :param labels: labels of the curves in the legend, one per class; when
        None, the curves are left unlabelled
    :return:
    """
    X = torch.linspace(start=min_x, steps=steps, end=max_x).unsqueeze(
        1
    )  # (N,1)

    m_x = M_x(
        X,
        epsilon,
        p_c,
        p_m_given_c,
        gaussian_mean,
        gaussian_std,
        normalize=normalize,
    )

    # One curve per column of m_x.
    num_classes = m_x.shape[1]

    # Resolve the defaults only once the number of classes is known, so the
    # defaults work for any number of classes (the previous two-element
    # color default failed beyond two classes, and labels=None could not be
    # indexed at all).
    if colors is None:
        colors = ["black"] * num_classes
    if labels is None:
        labels = [None] * num_classes

    for c in range(num_classes):
        ax.plot(X, m_x[:, c], c=colors[c], ls="--", label=labels[c])


def plot_M_c(
    ax,
    epsilon: torch.double,
    gaussian_mean: torch.Tensor,
    gaussian_std: torch.Tensor,
    p_m_given_c: torch.Tensor,
    p_c,
    normalize: bool,
):
    """
    Draw the (normalized) M_c'(c) matrix as a heatmap
    :param ax: the Axis object of matplotlib
    :param epsilon: the length of the interval
    :param gaussian_mean: tensor of shape (C, M,1)
    :param gaussian_std: tensor of shape (C, M,1)
    :param p_m_given_c: tensor of shape (C, M)
    :param p_c: tensor of shape (C)
    :param normalize: forwarded to M_c; when true the color scale of the
        heatmap is additionally pinned to [0, 1]
    :return:
    """
    matrix = M_c(
        epsilon,
        p_c,
        p_m_given_c,
        gaussian_mean,
        gaussian_std,
        normalize=normalize,
    )

    # Only the normalized matrix gets a fixed color scale; otherwise the
    # color range is left to seaborn.
    color_limits = {"vmin": 0.0, "vmax": 1.0} if normalize else {}
    sns.heatmap(matrix, cmap="YlGnBu", ax=ax, **color_limits)
