import numpy as np
import torch
from tqdm import trange
from torch import nn
import time
import warnings
warnings.simplefilter(action='ignore', category=DeprecationWarning)

class EncoderTrainer():
    """Alternately train a trajectory encoder and a learnable anchor ``w``.

    Each ``train_step`` draws two batches with ``get_batch`` and, row by row,
    labels the trajectory with the larger ``rtg[:, -1, 0]`` as "positive" and
    the other as "negative". It then

    1. updates the encoder (``et_optimizer``) with a triplet loss whose anchor
       is ``w`` detached, plus representation-specific terms, and
    2. re-encodes the same batches with the updated encoder and updates the
       anchor (``w_optimizer``) with a non-detached triplet loss.

    Supported ``repre_type`` values and the encoder output each expects:
        'vec'    -- one embedding tensor per batch.
        'vq_vec' -- ``(phi_e, phi_q, phi)``; triplet terms use ``phi`` and a
                    VQ loss on ``phi_e``/``phi_q`` is added to the encoder loss.
        'dist'   -- ``(mean, std)``; ``exp(std)`` is used as the diagonal of
                    the covariance (i.e. ``std`` acts as a log-variance), and
                    KL terms against ``N(w, diag(exp(clamp(w_std, -5, 2))))``
                    are added to both losses.
    """

    # Representation types understood by ``train_step``.
    _REPRE_TYPES = ('vec', 'vq_vec', 'dist')

    def __init__(
        self,
        en_model,
        et_optimizer,
        w,
        w_std,
        w_optimizer,
        batch_size,
        get_batch,
        device,
        repre_type,
        phi_norm_loss_ratio=0.1
    ):
        """Store models, optimizers and loss settings.

        Args:
            en_model: encoder called as
                ``en_model.forward(states, actions, timesteps, attention_mask)``;
                see the class docstring for its expected output.
            et_optimizer: optimizer stepped on the encoder loss.
            w: anchor tensor; expanded with ``w.expand(n, -1)`` to match the
                positives.
            w_std: anchor log-variance; only read when ``repre_type == 'dist'``.
            w_optimizer: optimizer stepped on the anchor loss.
            batch_size: number of trajectories requested per ``get_batch`` call.
            get_batch: callable ``get_batch(batch_size)`` returning
                ``(states, actions, rtg, timesteps, attention_mask)``.
            device: stored as ``self.device``; the losses here create their
                targets on the device of the encoder outputs.
            repre_type: one of ``'vec'``, ``'vq_vec'``, ``'dist'``.
            phi_norm_loss_ratio: weight of the unit-norm penalty in the encoder
                loss (applied for ``'vec'`` and ``'dist'``).

        Raises:
            ValueError: if ``repre_type`` is unsupported, or is ``'dist'``
                while ``w_std`` is None.
        """
        if repre_type not in self._REPRE_TYPES:
            raise ValueError(
                f'repre_type must be one of {self._REPRE_TYPES}, got {repre_type!r}')
        if repre_type == 'dist' and w_std is None:
            raise ValueError("repre_type='dist' requires w_std")
        self.en_model = en_model
        self.et_optimizer = et_optimizer
        self.repre_type = repre_type
        self.w = w
        self.w_std = w_std  # only used by 'dist'
        self.w_optimizer = w_optimizer
        self.batch_size = batch_size
        self.get_batch = get_batch
        self.diagnostics = dict()
        self.device = device
        # MSE between embedding L2 norms and 1 (unit-norm penalty).
        self.phi_loss = nn.MSELoss()
        # Euclidean (p=2) triplet loss with margin 1.0.
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.phi_norm_loss_ratio = phi_norm_loss_ratio
        # Number of completed ``train_iteration`` calls.
        self.count = 0

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` calls to ``train_step`` and summarize the losses.

        Args:
            num_steps: number of optimization steps to run.
            iter_num: iteration index, only used in the printed header.
            print_logs: if True, print every log entry to stdout.

        Returns:
            dict mapping ``'training/...'`` keys (plus every entry of
            ``self.diagnostics``) to scalar summaries of this iteration.
        """
        regress_losses, phi_norm_losses, w_losses = [], [], []
        positive_kls, negative_kls, kl_losses = [], [], []
        logs = dict()
        train_start = time.time()

        self.en_model.train()
        for _ in trange(num_steps, desc='train_step', smoothing=0.1):
            (regress_loss, phi_norm_loss, w_loss,
             positive_kl, negative_kl, kl_loss) = self.train_step()
            regress_losses.append(regress_loss)
            phi_norm_losses.append(phi_norm_loss)
            w_losses.append(w_loss)
            positive_kls.append(positive_kl)
            negative_kls.append(negative_kl)
            kl_losses.append(kl_loss)

        self.count += 1

        logs['training/time'] = time.time() - train_start
        logs['training/pref_loss_mean'] = np.mean(regress_losses)
        logs['training/pref_loss_std'] = np.std(regress_losses)
        logs['training/phi_norm_loss_mean'] = np.mean(phi_norm_losses)
        logs['training/phi_norm_loss_std'] = np.std(phi_norm_losses)
        logs['training/w_loss_mean'] = np.mean(w_losses)
        logs['training/w_loss_std'] = np.std(w_losses)
        logs['training/positive_kl_mean'] = np.mean(positive_kls)
        logs['training/negative_kl_mean'] = np.mean(negative_kls)
        logs['training/kl_loss_mean'] = np.mean(kl_losses)
        logs.update(self.diagnostics)

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    @staticmethod
    def _pair(x_1, x_2, lb, rb):
        """Split paired rows into ``(positive, negative)`` using the masks."""
        positive = torch.cat((x_1[lb], x_2[rb]), 0)
        negative = torch.cat((x_2[lb], x_1[rb]), 0)
        return positive, negative

    def _norm_loss(self, phi_1, phi_2):
        """Sum of MSEs pulling each row's L2 norm of ``phi_1``/``phi_2`` to 1."""
        norm_1 = torch.norm(phi_1, dim=1)
        norm_2 = torch.norm(phi_2, dim=1)
        # ones_like follows the actual row count, dtype and device of the norms.
        return (self.phi_loss(norm_1, torch.ones_like(norm_1))
                + self.phi_loss(norm_2, torch.ones_like(norm_2)))

    @staticmethod
    def _vq_loss(phi_e, phi_q):
        """VQ loss: full-weight term on ``phi_e`` plus 0.25-weighted term on ``phi_q``."""
        # NOTE(review): VQ-VAE conventionally weights the codebook term
        # (gradient to phi_q) by 1 and the commitment term (gradient to phi_e)
        # by 0.25; here the weights are the other way round -- confirm intent.
        return (torch.mean((phi_q.detach() - phi_e) ** 2)
                + 0.25 * torch.mean((phi_q - phi_e.detach()) ** 2))

    def _vec_triplet_loss(self, phi_1, phi_2, lb, rb, detach_anchor):
        """Triplet loss with anchor ``w`` for vector representations."""
        positive, negative = self._pair(phi_1, phi_2, lb, rb)
        anchor = self.w.expand(positive.shape[0], -1)
        if detach_anchor:
            anchor = anchor.detach()
        return self.triplet_loss(anchor, positive, negative)

    def _dist_losses(self, phi_1_mean, phi_1_std, phi_2_mean, phi_2_std, lb, rb,
                     detach_anchor):
        """Return ``(positive_kl, negative_kl, kl_loss, trip_loss)`` for 'dist'.

        KLs are ``KL(anchor || positive)`` and ``KL(anchor || negative)``
        averaged over rows; ``kl_loss`` rewards a small positive KL and a large
        negative KL. ``detach_anchor`` only affects the triplet term: the anchor
        Gaussian always keeps its graph to ``w``/``w_std``.
        """
        positive_mean, negative_mean = self._pair(phi_1_mean, phi_2_mean, lb, rb)
        positive_std, negative_std = self._pair(phi_1_std, phi_2_std, lb, rb)
        mvn = torch.distributions.MultivariateNormal
        positive_dist = mvn(loc=positive_mean,
                            covariance_matrix=torch.diag_embed(torch.exp(positive_std)))
        negative_dist = mvn(loc=negative_mean,
                            covariance_matrix=torch.diag_embed(torch.exp(negative_std)))
        # Clamp the anchor log-variance to keep its covariance well-conditioned.
        w_std = torch.clamp(self.w_std, min=-5, max=2)
        anchor_dist = mvn(loc=self.w, covariance_matrix=torch.diag_embed(torch.exp(w_std)))
        positive_kl = torch.distributions.kl.kl_divergence(anchor_dist, positive_dist).mean()
        negative_kl = torch.distributions.kl.kl_divergence(anchor_dist, negative_dist).mean()
        kl_loss = positive_kl + 1.0 / negative_kl
        anchor_mean = self.w.expand(positive_mean.shape[0], -1)
        if detach_anchor:
            anchor_mean = anchor_mean.detach()
        trip_loss = self.triplet_loss(anchor_mean, positive_mean, negative_mean)
        return positive_kl, negative_kl, kl_loss, trip_loss

    def _embed(self, states, actions, timesteps, attention_mask):
        """Encode one batch to the vector fed to the triplet loss ('vec'/'vq_vec')."""
        out = self.en_model.forward(states, actions, timesteps, attention_mask)
        # 'vq_vec' encoders return (phi_e, phi_q, phi); triplet terms use phi.
        return out[2] if self.repre_type == 'vq_vec' else out

    def train_step(self):
        """Run one encoder update followed by one anchor update.

        Returns:
            tuple of six floats ``(pref_loss, phi_norm_loss, w_loss,
            positive_kl, negative_kl, kl_loss)``. The KL entries are the values
            computed during the anchor update for ``'dist'`` and ``0.0`` for
            the other representation types.

        Raises:
            ValueError: if ``self.repre_type`` is unsupported.
        """
        states_1, actions_1, rtg_1, timesteps_1, attention_mask_1 = self.get_batch(self.batch_size)
        states_2, actions_2, rtg_2, timesteps_2, attention_mask_2 = self.get_batch(self.batch_size)
        inputs_1 = (states_1, actions_1, timesteps_1, attention_mask_1)
        inputs_2 = (states_2, actions_2, timesteps_2, attention_mask_2)

        # lb: batch 1 preferred (>=, so ties go to batch 1); rb: batch 2 strictly
        # preferred. With margin 0 they cover every row pair (barring NaNs).
        margin = 0
        lb = (rtg_1[:, -1, 0] - rtg_2[:, -1, 0]) >= margin
        rb = (rtg_2[:, -1, 0] - rtg_1[:, -1, 0]) > margin

        # ---- Encoder update (anchor detached in the triplet term). ----
        if self.repre_type == 'vq_vec':
            phi_e_1, phi_q_1, phi_1 = self.en_model.forward(*inputs_1)
            phi_e_2, phi_q_2, phi_2 = self.en_model.forward(*inputs_2)
            vq_loss = self._vq_loss(phi_e_1, phi_q_1) + self._vq_loss(phi_e_2, phi_q_2)
            pref_loss = self._vec_triplet_loss(phi_1, phi_2, lb, rb, detach_anchor=True) + vq_loss
            # NOTE(review): unlike 'vec'/'dist', the norm penalty is only
            # reported here, not added to pref_loss -- confirm this is intended.
            phi_norm_loss = self._norm_loss(phi_1, phi_2)
        elif self.repre_type == 'dist':
            phi_1_mean, phi_1_std = self.en_model.forward(*inputs_1)
            phi_2_mean, phi_2_std = self.en_model.forward(*inputs_2)
            # kl_loss here is not detached from w/w_std; those gradients are
            # presumably cleared by w_optimizer.zero_grad() below -- verify
            # that w_optimizer owns w and w_std.
            _, _, enc_kl_loss, enc_trip_loss = self._dist_losses(
                phi_1_mean, phi_1_std, phi_2_mean, phi_2_std, lb, rb, detach_anchor=True)
            phi_norm_loss = self._norm_loss(phi_1_mean, phi_2_mean)
            pref_loss = enc_trip_loss + enc_kl_loss + self.phi_norm_loss_ratio * phi_norm_loss
        elif self.repre_type == 'vec':
            phi_1 = self.en_model.forward(*inputs_1)
            phi_2 = self.en_model.forward(*inputs_2)
            phi_norm_loss = self._norm_loss(phi_1, phi_2)
            trip_loss = self._vec_triplet_loss(phi_1, phi_2, lb, rb, detach_anchor=True)
            pref_loss = trip_loss + self.phi_norm_loss_ratio * phi_norm_loss
        else:
            raise ValueError(
                f'repre_type must be one of {self._REPRE_TYPES}, got {self.repre_type!r}')

        self.et_optimizer.zero_grad()
        pref_loss.backward()
        self.et_optimizer.step()

        # ---- Anchor update: re-encode with the freshly updated encoder. ----
        positive_kl = negative_kl = kl_loss = None
        if self.repre_type == 'dist':
            phi_1_mean, phi_1_std = self.en_model.forward(*inputs_1)
            phi_2_mean, phi_2_std = self.en_model.forward(*inputs_2)
            positive_kl, negative_kl, kl_loss, trip_loss = self._dist_losses(
                phi_1_mean, phi_1_std, phi_2_mean, phi_2_std, lb, rb, detach_anchor=False)
            w_loss = trip_loss + kl_loss
        else:
            # 'vec' and 'vq_vec' ('vq_vec' previously had no branch here and
            # crashed because w_loss was never assigned).
            phi_1 = self._embed(*inputs_1)
            phi_2 = self._embed(*inputs_2)
            w_loss = self._vec_triplet_loss(phi_1, phi_2, lb, rb, detach_anchor=False)

        self.w_optimizer.zero_grad()
        w_loss.backward()
        self.w_optimizer.step()

        def to_float(value):
            return 0.0 if value is None else value.item()

        return (pref_loss.item(), phi_norm_loss.item(), w_loss.item(),
                to_float(positive_kl), to_float(negative_kl), to_float(kl_loss))
