from abc import ABC
from argparse import ArgumentParser
from typing import Optional, Union

from pytorch_lightning import LightningModule
from torch.optim import Adam, Optimizer
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from common import Generator

from geomloss import SamplesLoss

from itertools import chain


class SinkhornGN(LightningModule, ABC):
    """Pair of generators trained under a Sinkhorn-distance constraint.

    Two ``Generator`` networks (``generator_min`` and ``generator_max``) are
    fitted to the data with a geomloss Sinkhorn ``SamplesLoss``, each weighted
    by its own non-negative Lagrange multiplier.  Once ``pre_train`` is set to
    ``False`` (this class never flips it itself -- presumably a callback or
    the training script does), a third optimizer additionally minimises
    ``mean(param_fn(generator_min)) - mean(param_fn(generator_max))``, i.e. it
    pushes the two generators towards the smallest / largest value of the
    quantity computed by ``param_fn``.

    Written against the multi-optimizer automatic-optimization API of
    PyTorch Lightning (``optimizer_idx`` in ``training_step``,
    ``training_epoch_end``).
    """

    def __init__(self, param_fn, graph, var_dims, n_latent, latent_dim, n_hidden, n_layers,
                 lr: float = 0.0002, lagrange_lr: float = 0.01, binary_keys: Optional[list] = None,
                 cost_func=None, tol_coeff: float = 1.15, diameter: Optional[float] = None,
                 upper_bounds=None, lower_bounds=None, ade_d1_value=None, ade_d0_value=None,
                 do_var=None, target_var=None, noise=None,
                 **kwargs):
        """Build both generators, the Sinkhorn loss and the Lagrange multipliers.

        Args:
            param_fn: Callable invoked as ``param_fn(z, generator, device, data=batch)``;
                its mean is the quantity that is minimised / maximised.
            graph, var_dims, n_hidden, n_layers, upper_bounds, lower_bounds, binary_keys:
                Forwarded to ``Generator`` (as ``graph``, ``var_dims``, ``n_hidden``,
                ``n_layers``, ``upper``, ``lower``, ``binary_keys``).
            n_latent, latent_dim: The sampled noise has ``latent_dim * n_latent`` columns.
            lr: Learning rate of the generator optimizers.
            lagrange_lr: Learning rate of the multipliers once pre-training is over.
            cost_func: Stored on the instance; not used inside this class.
            tol_coeff: Stored on the instance; not used inside this class.
            diameter: Forwarded to ``SamplesLoss``.
            ade_d0_value, ade_d1_value: The two intervention values compared by
                ``_calc_ade``; the ADE is only logged when ``ade_d1_value`` is set.
            do_var: Passed to ``Generator.do`` as ``do_key``.
            target_var: Index into ``var_dims`` selecting the output columns
                averaged by ``_calc_ade``.
            noise: ``None`` or ``'normal'`` for Gaussian noise, anything else for uniform.
            **kwargs: Accepted and ignored here.
        """
        super().__init__()
        self.save_hyperparameters()
        self.param_fn = param_fn
        self.noise = noise
        self.graph = graph
        self.var_dims = var_dims
        self.n_latent = n_latent
        self.latent_dim = latent_dim
        self.lr = lr
        self.lagrange_lr = lagrange_lr
        self.cost_func = cost_func
        self.tol_coeff = tol_coeff
        # Distance tolerance used in the multiplier update; starts at zero and is
        # never changed inside this class (presumably set externally).
        self.tol = 0.
        self.ade_d1_value = ade_d1_value
        self.ade_d0_value = ade_d0_value
        self.do_var = do_var
        self.target_var = target_var

        generator_params = {'latent_dim': self.latent_dim, 'graph': self.graph, 'var_dims': self.var_dims,
                            'n_hidden': n_hidden, 'n_layers': n_layers, 'upper': upper_bounds,
                            'lower': lower_bounds, 'binary_keys': binary_keys}

        # networks
        self.generator_min = Generator(**generator_params)
        self.generator_max = Generator(**generator_params)

        self.samples_loss = SamplesLoss('sinkhorn', p=1, blur=0.01, scaling=0.9, diameter=diameter,
                                        backend='tensorized')
        self.lagrangian_min = nn.Parameter(torch.ones(1))
        self.lagrangian_max = nn.Parameter(torch.ones(1))

        self.example_input_array = (torch.zeros(2, self.latent_dim * self.n_latent),
                                    torch.zeros(2, sum(self.var_dims)))
        self.metrics = []

        self.best_dist_min = np.inf
        self.best_dist_max = np.inf
        # While True only the generator and multiplier optimizers exist.
        self.pre_train = True
        # True until the optimizers have been rebuilt after pre-training.
        self.flag = True

    def forward(self, z, data):
        """Return the outputs of both generators concatenated along dim 1 (min first)."""
        return torch.cat([self.generator_min(z, data), self.generator_max(z, data)], dim=1)

    def _log_epoch(self, name, value):
        """Log ``value`` as an epoch-level metric shown in the progress bar."""
        self.log(name, value, on_step=False, on_epoch=True, prog_bar=True)

    def _mean_param(self, generator, z, imgs):
        """Return the mean of ``param_fn`` evaluated for ``generator`` on this batch."""
        return torch.mean(self.param_fn(z, generator, self.device, data=imgs))

    def _sample_noise(self, imgs):
        """Sample one noise row per batch element, cast to the type of ``imgs``.

        Gaussian noise is used when ``self.noise`` is ``None`` or ``'normal'``,
        uniform noise on ``[0, 1)`` otherwise.
        """
        shape = (imgs.shape[0], self.latent_dim * self.n_latent)
        if self.noise is None or self.noise == 'normal':
            z = torch.randn(*shape)
        else:
            z = torch.rand(*shape)

        return z.type_as(imgs)

    def _calc_loss(self, fake, real):
        """Return the Sinkhorn loss between ``fake`` and ``real``, ignoring NaN entries.

        Positions where ``real`` is NaN are set to zero in both tensors, so they
        are identical on both sides of the comparison.
        """
        missing = real.isnan()
        # Work on a copy: writing into ``fake`` itself would mutate the caller's
        # generator output.
        fake = fake.clone()
        fake[missing] = 0.
        real = real.nan_to_num(0.)
        return self.samples_loss(fake, real)

    def _calc_ade(self, z, imgs):
        """Return ``(ade_min, ade_max)``, one value per generator.

        For each generator this is the mean of the target columns under
        ``do(do_var = ade_d1_value)`` minus the same mean under
        ``do(do_var = ade_d0_value)``.  Both results are detached.
        """
        # Built-in ``sum`` + ``int`` keeps the slice bounds integral; ``np.sum`` of
        # an empty list slice (``target_var == 0``) is the float ``0.0``.
        target_start = int(sum(self.var_dims[: self.target_var]))
        target_end = int(sum(self.var_dims[: (self.target_var + 1)]))

        ades = []
        for gen in (self.generator_min, self.generator_max):
            means = []
            for value in (self.ade_d0_value, self.ade_d1_value):
                intervention = torch.Tensor([value]).to(self.device)
                out = gen.do(z, x=intervention, do_key=self.do_var, data=imgs)
                means.append(out[:, target_start:target_end].mean().detach())
            ades.append(means[1] - means[0])

        return ades[0], ades[1]

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by ``optimizer_idx``.

        * ``0`` -- generators: minimise the multiplier-weighted Sinkhorn distances.
        * ``1`` -- multipliers: minimise ``-lagrangian * (dist - tol)``, i.e. the
          multiplier grows while the distance exceeds ``tol``.
        * ``2`` -- generators (only after pre-training): minimise
          ``mean(param_min) - mean(param_max)``.

        Returns:
            A dict with key ``'loss'`` (plus the detached distances for
            optimizer 0), or ``None`` when there is nothing to optimise.
        """
        imgs, = batch

        z = self._sample_noise(imgs)

        # Lagrange multipliers must be non-negative, so project them onto the
        # non-negative orthant before every step.
        with torch.no_grad():
            self.lagrangian_min.clamp_(min=0)
            self.lagrangian_max.clamp_(min=0)

        # train generator
        if optimizer_idx == 0:
            fake_min = self.generator_min(z, imgs)
            min_dist = self._calc_loss(fake_min, imgs)
            g_min = self.lagrangian_min[0] * min_dist

            fake_max = self.generator_max(z, imgs)
            max_dist = self._calc_loss(fake_max, imgs)
            g_max = self.lagrangian_max[0] * max_dist

            self._log_epoch('param_min', self._mean_param(self.generator_min, z, imgs))
            self._log_epoch('param_max', self._mean_param(self.generator_max, z, imgs))

            if self.ade_d1_value is not None:
                min_ade, max_ade = self._calc_ade(z, imgs)
                self._log_epoch('min_ade', min_ade)
                self._log_epoch('max_ade', max_ade)
                self._log_epoch('ade_point_0', float(self.ade_d0_value))
                self._log_epoch('ade_point_1', float(self.ade_d1_value))

            self._log_epoch('dist_min', min_dist)
            self._log_epoch('dist_max', max_dist)
            self._log_epoch('tol', self.tol)
            return {'loss': g_min + g_max, 'dist_min': min_dist.detach(), 'dist_max': max_dist.detach()}

        # train lagrange multipliers of constraint that |distance| < tol
        if optimizer_idx == 1:
            d_loss = 0
            for mode, generator, lagrn in [
                ('min', self.generator_min, self.lagrangian_min[0]),
                ('max', self.generator_max, self.lagrangian_max[0])]:
                fake = generator(z, imgs)
                dist = self._calc_loss(fake, imgs)
                d_loss += -(lagrn * (dist - self.tol))

                self._log_epoch('lagrangian' + mode, lagrn)
            return {'loss': d_loss}

        # optimize the ADE
        if optimizer_idx == 2 and not self.pre_train:
            param_min = self._mean_param(self.generator_min, z, imgs)
            param_max = self._mean_param(self.generator_max, z, imgs)
            return {'loss': param_min - param_max}

        # Nothing to optimise for this optimizer index.
        return None

    def validation_step(self, batch, batch_idx):
        """Log the mean Sinkhorn distance of both generators and their ``param_fn`` means.

        Returns:
            ``{'loss': d_loss}`` where ``d_loss`` is the average of the two distances.
        """
        imgs, = batch

        z = self._sample_noise(imgs)
        d_loss_min = self._calc_loss(self.generator_min(z, imgs), imgs)
        d_loss_max = self._calc_loss(self.generator_max(z, imgs), imgs)
        d_loss = (d_loss_max + d_loss_min) / 2.

        param_min = self._mean_param(self.generator_min, z, imgs)
        param_max = self._mean_param(self.generator_max, z, imgs)

        self._log_epoch('val_dist', d_loss)
        self._log_epoch('val_param_min', param_min)
        self._log_epoch('val_param_max', param_max)

        return {'loss': d_loss}

    def training_epoch_end(self, outputs):
        """Track the best epoch distance of each generator and log the worse of the two."""
        metrics = self.trainer.callback_metrics
        self.best_dist_min = min(metrics['dist_min'], self.best_dist_min)
        self.best_dist_max = min(metrics['dist_max'], self.best_dist_max)
        # ``best_dist`` is the larger of the two per-generator bests, so it only
        # drops when both generators fit the data well.
        self._log_epoch('best_dist', max(self.best_dist_min, self.best_dist_max))

    def configure_gradient_clipping(
        self,
        optimizer: Optimizer,
        optimizer_idx: int,
        gradient_clip_val: Optional[Union[int, float]] = None,
        gradient_clip_algorithm: Optional[str] = None,
    ):
        """Clip gradients only for the ADE optimizer (index 2) after pre-training."""
        if optimizer_idx == 2 and not self.pre_train:
            self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val,
                                gradient_clip_algorithm=gradient_clip_algorithm)

    def configure_optimizers(self):
        """Return the optimizer configuration for the current training phase.

        During pre-training: the generator optimizer (with a
        ``ReduceLROnPlateau`` scheduler monitoring ``val_dist``) and the
        multiplier optimizer.  Afterwards: generator, multiplier and ADE
        optimizers, without a scheduler.
        """
        # The multipliers use a fixed rate of 0.01 during pre-training and the
        # configured ``lagrange_lr`` afterwards.
        lagrange_lr = 0.01 if self.pre_train else self.lagrange_lr
        opt_l = Adam([self.lagrangian_min, self.lagrangian_max], lr=lagrange_lr)
        opt_g = Adam(chain(self.generator_max.parameters(), self.generator_min.parameters()), lr=self.lr)

        if self.pre_train:
            lr_scheduler = ReduceLROnPlateau(opt_g, mode='min', factor=0.5,
                                             patience=20, threshold=0.0, threshold_mode='abs',
                                             cooldown=0, min_lr=0, eps=1e-08)

            return (
                {'optimizer': opt_g,
                 'lr_scheduler': {
                     'scheduler': lr_scheduler,
                     'interval': 'epoch',
                     'frequency': 1,
                     'monitor': 'val_dist',
                     'strict': True,
                 }
                 },
                {'optimizer': opt_l}
            )

        # Second optimizer over the same generator weights, at half the rate,
        # dedicated to the ADE objective.
        opt_ade = Adam(chain(self.generator_max.parameters(), self.generator_min.parameters()), lr=self.lr / 2)
        return (
            {'optimizer': opt_g},
            {'optimizer': opt_l},
            {'optimizer': opt_ade}
        )

    def on_epoch_start(self):
        """Rebuild the optimizers once, on the first epoch after pre-training ends."""
        if not self.pre_train and self.flag:
            self.trainer.accelerator.setup_optimizers(self.trainer)
            self.flag = False

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model's CLI options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--n_hidden', type=int, default=64)
        parser.add_argument('--n_layers', type=int, default=3)
        parser.add_argument('--latent_dim', type=int, default=1)
        parser.add_argument('--lr', type=float, default=1e-3, help="learning rate")
        parser.add_argument('--lagrange_lr', type=float, default=1e-1, help="learning rate")

        return parser
