import torch
import torch.nn as nn
import torch.nn.functional as F
import exp_utils as PQ
import rl_utils
import pytorch_lightning as pl


class TransitionModel(pl.LightningModule):
    """Probabilistic one-step dynamics model trained with PyTorch Lightning.

    ``self.net`` is called on (normalized state, action). Its last axis is split at
    ``dim_state``: the first part is a state delta, the rest a raw log standard
    deviation. NOTE(review): presumably the net's output width is ``2 * dim_state``
    (set via ``n_units``) — confirm against ``rl_utils.MLP``.
    """

    class FLAGS(PQ.BaseFLAGS):
        batch_size = 256
        weight_decay = 0.000075
        lr = 0.001
        mul_std = 1  # when truthy, the predicted delta is rescaled by ``normalizer.std``

    def __init__(self, dim_state, normalizer, n_units, *, name=''):
        super().__init__()
        self.dim_state = dim_state
        self.normalizer = normalizer
        self.net = rl_utils.MLP(n_units, activation=nn.SiLU)
        # Learnable soft bounds for the predicted log-std (initialized to 0.5 and -10).
        self.max_log_std = nn.Parameter(torch.full([dim_state], 0.5), requires_grad=True)
        self.min_log_std = nn.Parameter(torch.full([dim_state], -10.), requires_grad=True)
        self.training_loss = 0.
        self.val_loss = 0.
        self.name = name
        self.mul_std = self.FLAGS.mul_std

    def forward(self, states, actions, det=True):
        """Predict the next state.

        Args:
            states: current states; the mean is predicted as ``states + delta``.
            actions: actions passed to ``self.net`` alongside the normalized states.
            det: if True, return only the predicted mean tensor.

        Returns:
            The predicted mean when ``det`` is True, otherwise a
            ``torch.distributions.Normal`` with that mean and a soft-clamped std.
        """
        output = self.net(self.normalizer(states), actions)
        mean, log_std = output[..., :self.dim_state], output[..., self.dim_state:]
        if self.mul_std:
            # The net predicts the delta in normalized units; map it back to state scale.
            mean = mean * self.normalizer.std
        mean = mean + states
        if det:
            return mean
        # Smoothly clamp log_std into (min_log_std, max_log_std) using softplus,
        # so gradients still flow near the bounds.
        log_std = self.max_log_std - F.softplus(self.max_log_std - log_std)
        log_std = self.min_log_std + F.softplus(log_std - self.min_log_std)
        return torch.distributions.Normal(mean, log_std.exp())

    def log_std_loss(self):
        """Small penalty on the gap between the learnable log-std bounds."""
        return 0.001 * (self.max_log_std - self.min_log_std).mean()

    def configure_optimizers(self):
        """Adam over all parameters, configured from ``FLAGS``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.FLAGS.lr, weight_decay=self.FLAGS.weight_decay)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Negative log-likelihood of ``batch['next_state']`` plus the log-std penalty.

        ``batch`` must provide the keys ``'state'``, ``'action'`` and ``'next_state'``.
        """
        predictions: torch.distributions.Normal = self(batch['state'], batch['action'], det=False)
        targets = batch['next_state']
        loss = -predictions.log_prob(targets).mean() + self.log_std_loss()
        # Log a detached tensor instead of ``loss.item()``: Lightning aggregates it per
        # epoch itself, and ``.item()`` would force a device sync on every step.
        self.log(f'{self.name}/training_loss', loss.detach(), on_step=False, on_epoch=True)
        return {
            'loss': loss,
        }


class ModelTrainer(nn.Module, rl_utils.BaseTrainer):
    """Hand-written training loop for a ``TransitionModel`` (no Lightning).

    NOTE(review): ``__init__`` ends with ``assert False``, so constructing this class
    always fails unless Python runs with ``-O``. It appears to be intentionally
    disabled in favour of ``TransitionModel.training_step``; confirm before re-enabling.
    ``self.n_batches`` and ``self.device`` are not set here; presumably they are
    provided by ``rl_utils.BaseTrainer`` / ``init_trainer``. TODO confirm.
    """
    n_batches_per_epoch: int

    class FLAGS(PQ.BaseFLAGS):
        batch_size = 256
        weight_decay = 0.00075
        lr = 0.001

    def __init__(self, model: TransitionModel, buf: rl_utils.TorchReplayBuffer, buf_dev, *, device, name):
        super().__init__()
        self.model = model
        self.buf = buf          # training data, sampled in minibatches
        self.buf_dev = buf_dev  # held-out data, evaluated in full by ``validate``
        self.name = name

        # TODO: different weight decay for different block
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.FLAGS.lr, weight_decay=self.FLAGS.weight_decay)
        self.init_trainer(device=device)
        self.training_loss = PQ.utils.AverageMeter()
        assert False  # deliberately disables this trainer (see class docstring)

    def configure_train_dataloader(self):
        """Yield an endless stream of minibatches sampled from the training buffer."""
        while True:
            yield self.buf.sample(self.FLAGS.batch_size)

    def training_step(self, batch, batch_idx):
        """One Adam step on the Gaussian NLL plus the log-std penalty.

        Gradients are clipped to a total norm of 10. Returns ``{'loss': float}``.
        """
        predictions: torch.distributions.Normal = self.model(batch['state'], batch['action'], det=False)
        targets = batch['next_state']
        loss = -predictions.log_prob(targets).mean()
        loss = loss + self.model.log_std_loss()

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 10)
        self.optimizer.step()
        return {'loss': loss.item()}

    def post_step(self, output):
        """Accumulate the step loss; log the running mean every 1,000 batches."""
        self.training_loss += output['loss']
        if self.n_batches % 1_000 == 0:
            PQ.log.info(f"[{self.name}] # iter {self.n_batches}: training loss = {self.training_loss.mean:.6f}")
            PQ.writer.add_scalar(f'model/{self.name}/training_loss', self.training_loss.mean, global_step=self.n_batches)

    def validate(self):
        """Every 5,000 batches, log NLL statistics of the model on the whole dev buffer.

        Runs under ``torch.no_grad()``. The values are only logged, so building an
        autograd graph over the full dev set would just waste memory.
        """
        if self.n_batches % 5_000 != 0:
            return
        buf = self.buf_dev
        with torch.no_grad():
            dist = self.model(buf.state.to(self.device), buf.action.to(self.device), det=False)
            nll = -dist.log_prob(buf.next_state.to(self.device))
        nll_mean = nll.mean().item()  # computed once, reused for log and writer
        PQ.log.info(f"[{self.name}] dev NLL: mean = {nll_mean:.6f}, "
                    f"max = {nll.max().item():.6f}, median = {nll.median().item():.6f}")
        PQ.writer.add_scalar(f'model/{self.name}/test_loss', nll_mean, global_step=self.n_batches)


def train_models(model_trainers, n_steps=1):
    """Advance every trainer in ``model_trainers`` by ``n_steps`` calls to ``.step()``.

    NOTE(review): the leading ``assert False`` makes every call fail (unless Python
    runs with ``-O``). This helper appears to be intentionally disabled; confirm
    before relying on it.
    """
    assert False
    for trainer in model_trainers:
        steps_left = range(n_steps)
        for _ in steps_left:
            trainer.step()
