from typing import Dict, List, Optional

import numpy as np
import torch

from disentangle.AbstractTrainer import AbstractLitModule
from losses import LpSimCLRLoss, UnifiedCLLoss, infonce_loss
from models.encoders import MLP
from utils import xavier_init


class ContrastiveLitModule(AbstractLitModule):
    """Contrastive (InfoNCE) module that encodes state windows into parameter codes.

    An MLP encoder maps each input window (its last two axes flattened) to a
    ``param_dim``-sized output. Training minimises ``infonce_loss`` over the
    encoded views, using ``code_sharing`` to select the content indices/subsets.
    """

    def __init__(
        self,
        state_dim: int,
        n_step: int,
        param_dim: int,
        n_views: int,
        learning_rate: float = 1e-5,
        eval_metrics: Optional[List] = None,
        code_sharing: Optional[Dict[int, List[int]]] = None,
        tau: float = 0.1,
        factor_type: str = "discrete",
        **kwargs,
    ):
        """Build the encoder and the contrastive-loss components.

        Args:
            state_dim: Size of the last axis of ``states`` (forwarded to the base class).
            n_step: Forwarded to the base class (presumably the number of time
                steps per window -- confirm against AbstractLitModule).
            param_dim: Output size of the encoder.
            n_views: Number of views; forwarded to the base class.
            learning_rate: Forwarded to the base class and stored on the module.
            eval_metrics: Forwarded to the base class. ``None`` (the default)
                becomes a fresh empty list, so instances never share a mutable
                default object.
            code_sharing: Forwarded to the base class. In ``loss`` its keys are
                used as ``estimated_content_indices`` and its values as ``subsets``.
            tau: Temperature for ``infonce_loss`` (also given to ``LpSimCLRLoss``).
            factor_type: Forwarded to the base class.
            **kwargs: Extra options; only ``hidden_dim`` (default 1024) is read here.
        """
        # Avoid a shared mutable default; rebinding the local before
        # save_hyperparameters() keeps the recorded value identical ([]).
        if eval_metrics is None:
            eval_metrics = []

        super().__init__(
            state_dim=state_dim,
            n_step=n_step,
            n_iv_steps=0,
            param_dim=param_dim,
            n_views=n_views,
            learning_rate=learning_rate,
            eval_metrics=eval_metrics,
            code_sharing=code_sharing,
            factor_type=factor_type,
        )
        self.state_dim = state_dim
        self.param_dim = param_dim
        self.n_views = n_views
        self.learning_rate = learning_rate

        self.save_hyperparameters()

        # NOTE(review): ``self.input_dim`` is not assigned in this class; it is
        # presumably set by AbstractLitModule -- confirm it equals the flattened
        # size (shape[-2] * shape[-1]) that ``forward`` feeds to the encoder.
        self.encoder = MLP(
            input_dim=self.input_dim,
            output_dim=param_dim,
            hidden_dim=kwargs.get("hidden_dim", 1024),
            num_layers=5,
        )

        # Always Xavier-initialise the freshly built encoder. The previous guard
        # ``if self.train():`` called nn.Module.train(), which switches to training
        # mode and returns the module itself (always truthy), instead of reading
        # ``self.training``; it therefore never skipped initialisation.
        xavier_init(self.encoder)

        # Accumulators: validation_step appends to "pred_params" and "gt_params";
        # the other keys are presumably filled/consumed elsewhere -- verify.
        self.misc = {
            "pred_params": [],
            "pred_states": [],
            "gbt": [],
            "gt_params": [],
            "r2_linear": [],
            "r2_nonlinear": [],
        }

        self.sim_metric = torch.nn.CosineSimilarity(dim=-1)
        self.criterion = torch.nn.CrossEntropyLoss()

        # Alternative loss; currently not used by ``loss`` (see the TODO there).
        self.LpSimLoss = UnifiedCLLoss(LpSimCLRLoss(tau=tau)).loss
        self.tau = tau

    def loss(self, z_rec_tuple: torch.Tensor) -> torch.Tensor:
        """Compute the InfoNCE loss for the encoded views.

        Args:
            z_rec_tuple: Encoder output as reshaped by ``training_step`` to
                ``(states.shape[0], states.shape[1], -1)``, i.e.
                ``(n_views, bs, param_dim)`` under the layout noted in ``forward``.

        Returns:
            Whatever ``infonce_loss`` returns (logged and optimised as the
            training loss).
        """
        # TODO: switch to self.LpSimLoss once key and value in code_sharing are
        # swapped to the layout it expects, e.g.
        # self.LpSimLoss(self.code_sharing, z_rec_tuple, torch.rand_like(z_rec_tuple))
        return infonce_loss(
            z_rec_tuple,
            sim_metric=self.sim_metric,
            criterion=self.criterion,
            tau=self.tau,
            projector=(lambda x: x),  # identity: no projection head
            estimated_content_indices=list(self.code_sharing.keys()),
            subsets=list(self.code_sharing.values()),
        )

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Encode state windows.

        Args:
            states: Tensor laid out as ``[n_views, bs, ts, state_dim]``. Only the
                last two axes are flattened, so the leading axes are folded into
                one sample axis.

        Returns:
            Encoder output with one row per flattened sample; callers reshape it
            back to the leading axes.
        """
        # ``dct_layer`` / ``state_transform`` are not defined in this class
        # (presumably provided by AbstractLitModule).
        if self.dct_layer:
            states = self.state_transform(states)
        flat = states.reshape(-1, states.shape[-2] * states.shape[-1])
        return self.encoder(flat)

    def training_step(self, batch, batch_idx):
        """Run one contrastive training step and log ``train_loss``.

        ``batch["states"]`` is cast to float, encoded, and reshaped to
        ``(states.shape[0], states.shape[1], -1)`` before the loss is computed.
        """
        states = batch["states"].float()
        params_hat = self.forward(states).reshape(states.shape[0], states.shape[1], -1)
        loss = self.loss(params_hat)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Encode a validation batch and cache predictions (and ground truth if present).

        Predictions go through ``forward`` so that validation inputs receive the
        same ``state_transform`` as training inputs when ``dct_layer`` is enabled
        (calling ``self.encoder`` directly skipped that transform).
        """
        states = batch["states"].float()
        params_hat = self.forward(states).reshape(*states.shape[:-2], -1)
        # Store predicted parameters for the whole dataset (earth); detach() so
        # .numpy() also works outside a no-grad context.
        self.misc["pred_params"].append(params_hat.detach().cpu().numpy())
        if "gt_params" in batch:
            # Stack the dict's tensors along a new last axis.
            gt_params = torch.stack(list(batch["gt_params"].values()), -1)
            self.misc["gt_params"].append(gt_params.cpu().numpy())
