from abc import ABC

import math

import copy
from itertools import chain

import numpy as np
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import Callback, EarlyStopping, ProgressBar
from torch.optim import Optimizer, Adam, SGD

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar

import torch.nn.functional as F


def block(n_hidden, n_layers):
    """Build the hidden part of an MLP as a flat list of modules.

    Args:
        n_hidden: Input and output width of every ``nn.Linear`` layer.
        n_layers: Number of ``Linear`` + ``LeakyReLU(0.2)`` pairs to create.

    Returns:
        A list of ``2 * n_layers`` modules alternating ``nn.Linear`` and
        ``nn.LeakyReLU``; empty when ``range(n_layers)`` is empty.
    """
    return [
        module
        for _ in range(n_layers)
        for module in (
            nn.Linear(n_hidden, n_hidden),
            nn.LeakyReLU(0.2, inplace=True),
        )
    ]


class Scale(nn.Module):
    """Module that multiplies its input by the stored factor ``k``."""

    def __init__(self, k):
        """Store the multiplier ``k`` on the module."""
        super().__init__()
        self.k = k

    def forward(self, x):
        """Return the product of ``x`` and ``self.k`` as computed by ``torch.mul``."""
        scaled = torch.mul(x, self.k)
        return scaled


class GumbleMaxBinary(nn.Module):
    """Module wrapper around ``F.gumbel_softmax`` with ``hard=True``."""

    def __init__(self, tau):
        """Store the temperature ``tau`` passed to ``F.gumbel_softmax``."""
        super().__init__()
        self.tau = tau

    def forward(self, x):
        """Sample from Gumbel-Softmax over the last dimension of ``x``.

        ``hard=True`` asks ``F.gumbel_softmax`` for a discretised (one-hot)
        sample along the last dimension.
        """
        return F.gumbel_softmax(logits=x, tau=self.tau, hard=True, dim=-1)


class Generator(nn.Module):
    """Generate every variable of a directed graph, one network per non-root variable.

    ``graph`` maps each variable index to the list of its inputs. A
    non-negative entry ``i`` refers to observed variable ``i``; a negative
    entry ``-i`` refers to the ``i``-th latent block, i.e. columns
    ``(i - 1) * latent_dim : i * latent_dim`` of ``z``. Variables whose input
    list is empty get no network and are copied from ``data`` instead.
    """

    def __init__(self, latent_dim, graph, var_dims, n_hidden, n_layers, output_x: int = 1,
                 binary_keys=None, upper=None, lower=None):
        """Create one ``nn.Sequential`` per variable that has inputs.

        Args:
            latent_dim: Number of ``z`` columns belonging to one latent block.
            graph: Mapping ``variable index -> list of input indices`` (see the
                class docstring). Variables are evaluated in iteration order,
                so inputs must appear before the variables that use them.
            var_dims: Width of each variable. It is indexed with an integer
                array, so a NumPy array is expected.
            n_hidden: Width of the hidden layers.
            n_layers: Number of hidden layers; ``0`` yields a single linear map.
            output_x: Multiplier applied to the output width of non-binary
                variables.
            binary_keys: Optional container of variable indices whose network
                ends with ``Sigmoid`` followed by ``GumbleMaxBinary(0.1)``.
            upper: Optional mapping ``variable index -> upper bound tensor``.
            lower: Optional mapping ``variable index -> lower bound tensor``.
                Outputs are clamped only when both mappings are given and the
                variable index is present in ``lower``.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.graph = graph
        self.var_dims = var_dims

        self.upper = upper
        self.lower = lower

        self.model_dict = {}

        for key, value in graph.items():
            inputs = np.array(value).astype(int)
            if len(inputs) == 0:
                # Root variable: read from `data` at forward time, no network.
                continue
            observed = inputs[inputs >= 0]
            latent = abs(inputs[inputs < 0])
            input_dim = latent_dim * len(latent) + np.sum(var_dims[observed])

            # Without hidden layers the output head reads the raw inputs directly.
            # The head is created before the trunk so that parameters are
            # initialised in the same order as before this refactoring.
            head_in = input_dim if n_layers == 0 else n_hidden
            if binary_keys is None or key not in binary_keys:
                head = [nn.Linear(head_in, output_x * var_dims[key])]
            else:
                head = [nn.Linear(head_in, var_dims[key]),
                        nn.Sigmoid(),
                        GumbleMaxBinary(0.1)]

            if n_layers == 0:
                trunk = []
            else:
                trunk = [nn.Linear(input_dim, n_hidden), nn.LeakyReLU(0.2, inplace=True),
                         *block(n_hidden, n_layers - 1)]
            self.model_dict[key] = nn.Sequential(*trunk, *head)

        # `model_dict` is a plain dict, so the networks are registered as
        # sub-modules through this ModuleList.
        self.models = nn.ModuleList(list(self.model_dict.values()))

    def _expand_do_value(self, x, do_key, z):
        """Turn the intervention value ``x`` into a ``(batch, dim)`` tensor.

        An ``x`` whose shape is exactly ``(batch, dim)`` is used per sample.
        This is checked first so that a batch size equal to the variable width
        is not mistaken for a single vector to broadcast. Otherwise an ``x``
        whose first dimension equals the variable width is repeated for every
        row of ``z``, and an ``x`` whose first dimension equals the batch size
        is used per sample.

        Raises:
            ValueError: If the first dimension of ``x`` matches neither the
                variable width nor the batch size.
        """
        batch_size, width = z.shape[0], self.var_dims[do_key]
        if x.dim() == 2 and x.shape[0] == batch_size and x.shape[1] == width:
            return x.reshape(-1, x.shape[1])
        if x.shape[0] == width:
            return x.reshape((width, -1)).repeat((1, batch_size)).t()
        if x.shape[0] == batch_size:
            return x.reshape(-1, x.shape[1])
        raise ValueError(f'wrong do-var dim. z: {z.shape}, x: {x.shape}')

    def _helper_forward(self, z, x=None, do_key=None, data=None):
        """Evaluate all variables in graph order and concatenate them.

        Args:
            z: Latent tensor; latent block ``i`` occupies columns
                ``(i - 1) * latent_dim : i * latent_dim``.
            x: Value assigned to variable ``do_key`` (only read when
                ``do_key`` is given).
            do_key: Index of the variable to overwrite with ``x`` instead of
                computing it, or ``None``.
            data: Tensor from which variables without inputs are sliced,
                using column offsets derived from ``var_dims``.

        Returns:
            The variables ``0 .. len(var_dims) - 1`` concatenated along dim 1.
        """
        var = {}
        for key, value in self.graph.items():
            if (do_key is not None) and (key == do_key):
                var[key] = self._expand_do_value(x, do_key, z)
                continue

            inputs = np.array(value).astype(int)
            if len(inputs) == 0:
                start = np.sum(self.var_dims[: key])
                end = np.sum(self.var_dims[: (key + 1)])
                var[key] = data[:, start:end]
                continue

            latent = tuple(z[:, (i - 1) * self.latent_dim:i * self.latent_dim] for i in abs(inputs[inputs < 0]))
            # Observed inputs are looked up in `var`, so they must already have
            # been produced by an earlier iteration of this loop.
            observed = tuple(var[i] for i in inputs[inputs >= 0])
            var[key] = self.model_dict[key](torch.cat(latent + observed, dim=1))
            if self.lower is not None and self.upper is not None and key in self.lower:
                # The clamp is written in place under no_grad, so it is not
                # recorded by autograd (original note: more stable results).
                with torch.no_grad():
                    lower, upper = self.lower[key].type_as(var[key]), self.upper[key].type_as(var[key])
                    var[key].copy_(var[key].data.clamp(min=lower, max=upper))
        observed = tuple(var[i] for i in range(len(self.var_dims)))
        return torch.cat(observed, dim=1)

    def forward(self, z, data):
        """Generate all variables from ``z``, reading root variables from ``data``."""
        return self._helper_forward(z, data=data)

    def do(self, z, x, do_key, data):
        """Generate all variables with variable ``do_key`` replaced by ``x``."""
        return self._helper_forward(z, x=x, do_key=do_key, data=data)


class Discriminator(nn.Module):
    """MLP that maps an input row to a single output value."""

    def __init__(self, input_dim, n_hidden, n_layers):
        """Build the network.

        Args:
            input_dim: Number of input features.
            n_hidden: Width of the hidden layers.
            n_layers: Number of hidden layers; ``0`` yields one linear layer
                from ``input_dim`` to ``1``.
        """
        super().__init__()
        if n_layers == 0:
            stack = [nn.Linear(input_dim, 1)]
        else:
            # Input layer, then (n_layers - 1) hidden pairs, then the output layer.
            stack = [nn.Linear(input_dim, n_hidden), nn.LeakyReLU(0.2, inplace=True)]
            stack.extend(block(n_hidden, n_layers - 1))
            stack.append(nn.Linear(n_hidden, 1))
        self.model = nn.Sequential(*stack)

    def forward(self, X):
        """Return the network output for ``X``."""
        scores = self.model(X)
        return scores


class MetricsCallback(Callback):
    """PyTorch Lightning metric callback.

    Each ``on_epoch_end`` call appends a snapshot of
    ``trainer.callback_metrics`` to ``self.metrics``.
    """

    def __init__(self):
        super().__init__()
        # One dict per recorded epoch: metric name -> NumPy value.
        self.metrics = []

    def on_epoch_end(self, trainer, pl_module):
        """Copy the current callback metrics, convert them to NumPy and store them."""
        snapshot = copy.deepcopy(trainer.callback_metrics)
        self.metrics.append(
            {name: value.cpu().numpy() for name, value in snapshot.items()}
        )


class LagrangeStart(EarlyStopping):
    """Early-stopping check that ends the module's pre-training phase.

    When the inherited stopping criterion fires, this callback does not stop
    the trainer itself; it clears ``pl_module.pre_train`` and sets
    ``pl_module.tol`` from the logged distances.
    """

    def on_validation_end(self, trainer, pl_module) -> None:
        """Do nothing; the check runs in ``on_train_epoch_end`` only."""
        return None

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Leave pre-training once the monitored metric meets the stopping criterion."""
        if pl_module.pre_train:
            triggered, _ = self._run_early_stopping_check(trainer)
            if triggered:
                pl_module.pre_train = False
                logged = trainer.callback_metrics
                dist_min = logged['dist_min']
                dist_max = logged['dist_max']
                current_worst = max(dist_max, dist_min)
                best_worst = max(pl_module.best_dist_min, pl_module.best_dist_max)
                # The tolerance is the smaller of the two distances, scaled by tol_coeff.
                pl_module.tol = min(best_worst, current_worst) * pl_module.tol_coeff

    def _run_early_stopping_check(self, trainer: "pl.Trainer") -> tuple:
        """Evaluate the stopping criterion on the monitored metric.

        Returns:
            ``(should_stop, current)``, where ``current`` is the monitored
            value, or ``(False, None)`` when ``fast_dev_run`` is set or the
            metric fails ``_validate_condition_metric``.
        """
        logged = trainer.callback_metrics

        # Skip the check with fast_dev_run, or when the metric is not usable.
        check_enabled = not trainer.fast_dev_run and self._validate_condition_metric(logged)
        if not check_enabled:
            return False, None

        current = logged.get(self.monitor)
        should_stop, _reason = self._evaluate_stopping_criteria(current)

        if should_stop and self.verbose:
            self._log_info(trainer, "Lagrange multiplier starts")

        return should_stop, current


class LitProgressBar(TQDMProgressBar):
    """TQDM progress bar whose validation bar is created with ``disable=True``."""

    def init_validation_tqdm(self):
        """Return a disabled ``Tqdm`` bar for the validation loop."""
        # The main progress bar doesn't exist in `trainer.validate()`
        main_bar_present = self.trainer.state.fn != "validate"
        from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
        import sys
        bar_options = dict(
            desc="Validating",
            position=(2 * self.process_position + main_bar_present),
            disable=True,
            leave=not main_bar_present,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return Tqdm(**bar_options)


class SimpleRegression(pl.LightningModule, ABC):
    """Two-layer MLP trained with MSE from the ``do_var`` columns to the ``target_var`` columns."""

    def __init__(self, hidden_dim, lr, do_var, target_var, var_dims):
        """Build the network and precompute the column ranges of both variables.

        Args:
            hidden_dim: Width of both linear layers' outputs.
            lr: Learning rate passed to SGD.
            do_var: Index of the input variable in ``var_dims``.
            target_var: Index of the target variable in ``var_dims``.
            var_dims: Width of each variable, in column order.
        """
        super().__init__()

        # NOTE(review): the last layer emits hidden_dim features, while the
        # loss target is the target_var slice -- confirm this is intended.
        layers = [
            nn.Linear(var_dims[do_var], hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.model = nn.Sequential(*layers)

        def column_span(index):
            # First and one-past-last column of variable `index`.
            return np.sum(var_dims[: index]), np.sum(var_dims[: (index + 1)])

        self.do_start, self.do_end = column_span(do_var)
        self.target_start, self.target_end = column_span(target_var)
        self.lr = lr

        self.loss = nn.MSELoss()

    def forward(self, x):
        """Return the network output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE between the prediction and the target columns."""
        (rows,) = batch
        do_columns = rows[:, self.do_start:self.do_end]
        target_columns = rows[:, self.target_start:self.target_end]
        loss = self.loss(self(do_columns), target_columns)
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return {'loss': loss}

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters."""
        optimizer = SGD(self.parameters(), lr=self.lr)
        return optimizer


class VAEIv(pl.LightningModule, ABC):
    """Variational auto-encoder over three variables ``z``, ``x`` and ``y``.

    The column offsets assume ``z`` is variable 0, ``x`` is variable 1 and
    ``y`` is variable 2 of ``var_dims``. ``x`` is decoded from the latent
    sample and ``z``; ``y`` is decoded from the latent sample and the
    reconstructed ``x``.
    """

    def __init__(self, var_dims, x_dim=1, z_dim=1, y_dim=1, latent_dim=32, lr=1e-3):
        """Build the encoder, both decoders and the posterior heads.

        Args:
            var_dims: Width of each variable, in column order.
            x_dim: Width of ``x`` as seen by the networks.
            z_dim: Width of ``z`` as seen by the networks.
            y_dim: Width of ``y`` as seen by the networks.
            latent_dim: Size of the latent sample.
            lr: Learning rate passed to Adam.
        """
        super().__init__()

        self.save_hyperparameters()
        self.lr = lr

        def column_span(index):
            # First and one-past-last column of variable `index`.
            return (np.sum(var_dims[:index]), np.sum(var_dims[:index + 1]))

        self.x_index = column_span(1)
        self.z_index = column_span(0)
        self.y_index = column_span(2)

        def decoder(in_features, out_features):
            # Shared decoder shape: in -> 8 -> 4 -> out with LeakyReLU in between.
            return nn.Sequential(
                nn.Linear(in_features, 8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(8, 4),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(4, out_features),
            )

        # The encoder consumes a full data row (x, z and y together).
        half_latent = latent_dim // 2
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + z_dim + y_dim, half_latent),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(half_latent, latent_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.decoder_x = decoder(latent_dim + z_dim, x_dim)
        self.decoder_y = decoder(latent_dim + x_dim, y_dim)

        # Heads producing the mean and log-variance of the posterior.
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

        # Learned log standard deviation of the Gaussian likelihood.
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def gaussian_likelihood(self, x_hat, logscale, x):
        """Return the log-probability of ``x`` under ``Normal(x_hat, exp(logscale))``.

        The per-element log-probabilities are summed over the last dimension.
        """
        likelihood = torch.distributions.Normal(x_hat, torch.exp(logscale))
        return likelihood.log_prob(x).sum(dim=-1)

    def kl_divergence(self, z, mu, std):
        """Return a single-sample Monte Carlo estimate of KL(q || p).

        ``q`` is ``Normal(mu, std)``, ``p`` is the standard normal, and ``z``
        is the sample at which both log-densities are evaluated. The result
        is summed over the last dimension.
        """
        prior = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        posterior = torch.distributions.Normal(mu, std)
        log_ratio = posterior.log_prob(z) - prior.log_prob(z)
        return log_ratio.sum(-1)

    def training_step(self, batch, batch_idx):
        """Compute, log and return ``mean(kl - recon_loss)`` for one batch."""
        (data,) = batch

        # Posterior parameters from the encoded row.
        hidden = self.encoder(data)
        mu = self.fc_mu(hidden)
        log_var = self.fc_var(hidden)

        # Reparameterised sample from the posterior.
        std = torch.exp(log_var / 2)
        u = torch.distributions.Normal(mu, std).rsample()

        # Decode x from (u, z) and then y from (u, reconstructed x).
        z_obs = data[:, self.z_index[0]: self.z_index[1]]
        x_hat = self.decoder_x(torch.cat((u, z_obs), dim=1))
        y_hat = self.decoder_y(torch.cat((u, x_hat), dim=1))

        # Reconstruction term: summed log-likelihoods of x and y.
        x_obs = data[:, self.x_index[0]: self.x_index[1]]
        y_obs = data[:, self.y_index[0]: self.y_index[1]]
        x_loss = self.gaussian_likelihood(x_hat, self.log_scale, x_obs)
        y_loss = self.gaussian_likelihood(y_hat, self.log_scale, y_obs)
        recon_loss = x_loss + y_loss

        kl = self.kl_divergence(u, mu, std)

        # Value that is minimised: KL minus reconstruction log-likelihood.
        elbo = (kl - recon_loss).mean()

        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean()
        }, prog_bar=True, on_epoch=True, on_step=False)

        return elbo