import pytorch_lightning as pl
from torch import optim, nn
import os
import torch
import numpy as np
import pdb
import h5py

from functorch import vmap
from functorch import combine_state_for_ensemble
from typing import List, Union, Optional, Tuple
import copy
from torch.linalg import matrix_norm
from torch.nn.utils.parametrizations import spectral_norm
from torchvision.ops import MLP
from torch.distributions import Categorical
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch import Tensor
# from .. import functional as F
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.types import _dtype as DType

from datasets import loadMatFile


class LinearizedForwardModel(nn.Module):
    """
    Linearized forward model (LFM) for one slice / time index.

    ``forward`` computes ``ats + (x - ssp) @ mE``: the perturbation of the
    input sound-speed profile (SSP) from a reference SSP is mapped through a
    precomputed matrix ``mE`` and added to reference arrival times ``ats``.
    All three tensors are loaded from .mat files and stored as buffers.
    """
    def __init__(
        self,
        basepath,
        slice_num,
        time_idx,
        nan_to_zero: bool = False,
        ssp_transform=None,
        at_transform=None):
        """
        basepath - root directory containing ``Slice_<slice_num>/Acoustic_Data/``
        slice_num - assumes 1 indexed
        time_idx - assumes 1 indexed
        nan_to_zero - replace NaN entries of ``mE`` with 0.0
        ssp_transform - optional object exposing ``unnormalize``; applied to the
            input in ``forward`` (a deep copy is stored)
        at_transform - optional callable applied to the predicted arrival times
            in ``forward`` (a deep copy is stored)
        """
        super().__init__()
        self.basepath = os.path.join(basepath, f'Slice_{slice_num}/Acoustic_Data/')
        self.slice_num = slice_num
        self.time_idx = time_idx
        # Deep copies so later mutation of the caller's transforms does not leak in.
        self.ssp_transform = copy.deepcopy(ssp_transform)
        self.at_transform = copy.deepcopy(at_transform)

        # Load mE
        e_path = os.path.join(self.basepath, f'time_saves/{time_idx}_m_E.mat')
        mE = loadMatFile(e_path, 'm_E')
        # NOTE(review): the leading reshape dimension is guessed as 13 or 12
        # from divisibility of shape[0] -- confirm against the data layout.
        num_d = 13 if mE.shape[0] % 13 == 0 else 12
        # Keep the first 11 x 231 SSP grid points and the first 2 of the 5
        # trailing columns, giving a (11*231, 800) matrix.
        mE = mE.reshape(num_d, -1, 400, 5)[:11, :231, :, :2].reshape(11 * 231, 800)
        mE = torch.Tensor(mE)
        if nan_to_zero:
            mE[torch.isnan(mE)] = 0.0
        self.register_buffer('mE', mE, persistent=True)

        # Load the reference SSPs, cropped to 11 x 231.
        ssp_path = os.path.join(self.basepath, 'cm_ssp.mat')
        ssp = loadMatFile(ssp_path, 'cm_ssp', time_idx)[:11, :231]
        self.register_buffer('ssp', torch.Tensor(ssp), persistent=True)

        # Load the reference arrival times; index 0 of the leading axis is
        # 'cm_tau_dir', index 1 is 'cm_tau_sur'.
        dir_ats = loadMatFile(os.path.join(self.basepath, 'cm_tau_dir.mat'), 'cm_tau_dir', time_idx)
        sur_ats = loadMatFile(os.path.join(self.basepath, 'cm_tau_sur.mat'), 'cm_tau_sur', time_idx)
        self.register_buffer('ats', torch.Tensor(np.concatenate((dir_ats[None], sur_ats[None]), 0)), persistent=True)

    # NOTE: no ``__call__`` override -- ``nn.Module.__call__`` dispatches to
    # ``forward`` and also runs any registered forward hooks.
    def forward(self, x):
        """
        x - (bs, 11, 231); unnormalized via ``ssp_transform`` when one is set

        out - (bs, 2, 20, 20); passed through ``at_transform`` when one is set
        """
        if self.ssp_transform is not None:
            x = self.ssp_transform.unnormalize(x)

        bs = x.size(0)
        # Perturbation from the reference SSP, flattened to (bs, 11*231).
        ssps = (x - self.ssp[None]).view(bs, -1)
        # (bs, 2541) @ (2541, 800) -> (bs, 800)
        athats = torch.matmul(ssps, self.mE)
        # (bs, 800) -> (bs, 20, 20, 2) -> swap dims 1 and 3 -> (bs, 2, 20, 20),
        # then add the reference arrival times.
        athats = torch.transpose(athats.view(bs, 20, 20, 2), 1, 3) + self.ats[None]
        if self.at_transform is not None:
            return self.at_transform(athats)
        return athats


class SharedQKAttention(nn.Module):
    """
    Simplified version of MHA with shared q/k projection.

    Same constructor/forward signature as ``nn.MultiheadAttention``, but a
    single ``qk_proj_weight`` is passed to ``F.multi_head_attention_forward``
    as both the query and the key projection; values use ``v_proj_weight``.

    Notes:
        * ``F.multi_head_attention_forward`` expects the query projection to
          be (embed_dim, embed_dim) and the key projection (embed_dim, kdim),
          so the shared matrix only works when ``kdim == embed_dim``.
        * With ``bias=True`` the packed ``in_proj_bias`` (3 * embed_dim) is
          split into q/k/v chunks, so q and k get separate biases even though
          their weights are shared.
    """
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.,
        bias=False,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # One matrix used for both the query and the key projection.
        self.qk_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
        self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
        # No packed projection: separate (shared) weights are always used.
        self.register_parameter('in_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise weights (Xavier) and biases (zero / Xavier-normal for bias_k, bias_v)."""
        xavier_uniform_(self.qk_proj_weight)
        xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        # bias_k / bias_v are allocated with torch.empty; without explicit
        # initialisation they would contain whatever memory was there.
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def _none_or_dtype(self, input: Optional[Tensor]) -> Optional[DType]:
        """Return ``input.dtype``, or None when ``input`` is None."""
        if input is None:
            return None
        elif isinstance(input, torch.Tensor):
            return input.dtype
        raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")

    def _canonical_mask(
        self,
        mask: Optional[Tensor],
        mask_name: str,
        other_type: Optional[DType],
        other_name: str,
        target_type: DType,
        check_other: bool = True,
    ) -> Optional[Tensor]:
        """
        Convert a bool mask to an additive float mask of ``target_type``
        (True -> -inf, False -> 0). Float masks are returned unchanged.

        Raises AssertionError for masks that are neither bool nor floating
        point, and warns when the mask dtype differs from ``other_type``.
        """
        import warnings  # used only on the dtype-mismatch path below

        if mask is not None:
            _mask_dtype = mask.dtype
            _mask_is_float = torch.is_floating_point(mask)
            if _mask_dtype != torch.bool and not _mask_is_float:
                raise AssertionError(
                    f"only bool and floating types of {mask_name} are supported")
            if check_other and other_type is not None:
                if _mask_dtype != other_type:
                    warnings.warn(
                        f"Support for mismatched {mask_name} and {other_name} "
                        "is deprecated. Use same type for both instead."
                    )
            if not _mask_is_float:
                mask = (
                    torch.zeros_like(mask, dtype=target_type)
                    .masked_fill_(mask, float("-inf"))
                )
        return mask

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
            is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Multi-head attention where q and k share one projection matrix.

        Inputs are (batch, seq, feature) when ``batch_first`` and batched,
        otherwise (seq, batch, feature). Returns ``(attn_output,
        attn_output_weights)`` from ``F.multi_head_attention_forward``, with
        the output transposed back to batch-first when applicable.

        ``is_causal`` is only validated against ``attn_mask``; it is not
        forwarded, so no causal mask is applied.
        """
        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        is_batched = query.dim() == 3

        key_padding_mask = self._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=self._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if query.is_nested or key.is_nested or value.is_nested:
            raise AssertionError(
                "SharedQKAttention does not support NestedTensor inputs.")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            # The shared matrix is passed as both the q and the k projection.
            q_proj_weight=self.qk_proj_weight, k_proj_weight=self.qk_proj_weight,
            v_proj_weight=self.v_proj_weight,
            average_attn_weights=average_attn_weights,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
        
class WeightedAverageBase(nn.Module):
    """
    Extension of Linearized Forward Model. Initializes arrival time prediction
    with a set of reference LFM passes and weights them based on the attention w.r.t.
    the reference SSPs.

    Base class: provides the LFM ensemble (``build_ensemble`` / ``LFMForward``);
    subclasses implement ``forward``.
    """
    def __init__(self):
        super().__init__()

    def build_ensemble(self, basepath, slice_nums, time_idxs, ssp_transform=None, at_transform=None):
        """
        Build one LinearizedForwardModel per (slice_num, time_idx) pair and
        store them in ``self.models``.

        Also sets ``self.key`` to the stacked reference SSPs of the ensemble
        (passed through ``ssp_transform`` when given), flattened to
        (1, num_models, -1). ``self.key`` is a plain attribute, not a buffer.
        """
        models = [
            LinearizedForwardModel(
                basepath, s, t, nan_to_zero=True,
                ssp_transform=ssp_transform, at_transform=at_transform)
            for s, t in zip(slice_nums, time_idxs)
        ]
        if not models:
            raise ValueError("build_ensemble requires at least one (slice_num, time_idx) pair")
        self.models = nn.ModuleList(models)

        # Stack each model's ``ssp`` buffer directly rather than relying on the
        # position of that buffer in combine_state_for_ensemble's output.
        ref_ssps = torch.stack([m.ssp for m in self.models]).detach()
        if ssp_transform is not None:
            ref_ssps = ssp_transform(ref_ssps)
        self.key = ref_ssps.view(1, len(self.models), -1)

    def fill_nan_with_avg(self, athats):
        """
        In-place: replace NaNs along dim 1 (the ensemble axis) with the
        nan-mean over that axis; entries that are NaN for every member become 0.
        Returns the same tensor.
        """
        with torch.no_grad():
            nanmean = athats.nanmean(dim=1, keepdim=True)
            mask = athats.isnan()
        athats[mask] = nanmean.expand(*athats.size())[mask]

        # Positions with no finite prediction in any member are still NaN.
        athats[athats.isnan()] = 0.0
        return athats

    def LFMForward(self, x):
        """
        Evaluate every LFM in ``self.models`` on ``x`` using vmap.

        Returns a tensor with the batch dim first and the ensemble dim second,
        with NaNs filled by ``fill_nan_with_avg``.
        """
        # The LFMs only hold buffers, so the (empty) params are passed as ().
        fmodel, _, buffers = combine_state_for_ensemble(self.models)

        athats = vmap(fmodel, in_dims=(0, 0, None))((), buffers, x)

        # vmap stacks models on dim 0; move bs to first.
        athats.transpose_(0, 1)

        athats = self.fill_nan_with_avg(athats)
        return athats

    def forward(self, x):
        """
        x: bs x r x d

        returns:
        y: bs x numat x num_s x num_r

        Must be implemented by subclasses.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")
    
class WeightedAverageNet(WeightedAverageBase):
    """
    Attention-weighted ensemble of linearized forward models.

    Query: the input SSP projected by ``query_proj``. Keys: the ensemble's
    reference SSPs projected by the same layer. Values: the ensemble's
    arrival-time predictions for the input. The attended value is mapped back
    to ``at_size`` by ``linear_out``.
    """
    def __init__(
        self,
        basepath: str,
        slice_nums: List[int],
        time_idxs: List[int],
        embed_dim: int = 512,
        num_heads: int = 1,
        ssp_depth: int = 231,
        ssp_range: int = 11,
        at_size: int = 800,
        dropout: float = 0.0,
        tanh: bool = False,
        ssp_transform: Union[nn.Module, None] = None,
        at_transform: Union[nn.Module, None] = None,
    ):
        """
        Extension of Linearized Forward Model. Initializes arrival time prediction
        with a set of reference LFM passes and weights them based on the attention w.r.t.
        the reference SSPs.

        basepath / slice_nums / time_idxs: one LFM is built per (slice, time) pair;
            the two lists must have the same length.
        embed_dim, num_heads, dropout: attention configuration.
        ssp_depth * ssp_range: flattened SSP size fed to ``query_proj``.
        at_size: size of one flattened arrival-time prediction (value dim / output dim).
        tanh: insert ``nn.Tanh`` before the output linear layer.
        ssp_transform / at_transform: forwarded to every LFM.
        """
        super().__init__()
        self.ssp_size = ssp_range * ssp_depth
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.at_size = at_size
        assert len(slice_nums) == len(time_idxs)
        self.num_pred = len(slice_nums)
        self.dropout = dropout

        # Load the linear forward models.
        self.build_ensemble(basepath, slice_nums, time_idxs, ssp_transform=ssp_transform, at_transform=at_transform)

        # Projects a flattened SSP (query and reference keys) to embed_dim.
        self.query_proj = nn.Linear(self.ssp_size, self.embed_dim, bias=False)

        self.mha = SharedQKAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            bias=False,
            batch_first=True,
            dropout=self.dropout,
            vdim=self.at_size,
        )

        # Output layer: optional Tanh followed by a linear map to at_size.
        out = []
        if tanh:
            out += [nn.Tanh()]
        out += [nn.Linear(self.embed_dim, self.at_size, bias=False)]
        self.linear_out = nn.Sequential(*out)

    def forward(self, x):
        """
        x: bs x r x d SSPs (any layout with ssp_size elements per sample)

        returns (y, probs):
        y: (bs, 2, 20, 20) predicted arrival times
        probs: attention weights from ``self.mha`` -- (bs, 1, num_pred) for the
            single query with head-averaged weights
        """
        bs = x.size(0)
        # reshape (not view) so non-contiguous inputs are accepted.
        q = self.query_proj(x.reshape(bs, 1, -1))
        # Reference SSPs projected with the same layer serve as keys.
        k = self.query_proj(self.key.expand(bs, *self.key.size()[1:]).to(q.device))
        # Each ensemble member's prediction for x is one value.
        v = self.LFMForward(x).reshape(bs, self.num_pred, -1)

        out, probs = self.mha(q, k, v)

        out = self.linear_out(out)
        return out.view(bs, 2, 20, 20), probs
    



class WeightedAverageBaseModule(pl.LightningModule):
    """
    Lightning wrapper around an attention-weighted LFM ensemble network.

    Subclasses implement ``build_net(ssp_transform, at_transform)``, which must
    assign the network to ``self.net``. All losses are MSE computed only over
    finite target entries.

    NOTE(review): no ``configure_optimizers`` is defined in this class and
    ``learning_rate`` is only stored -- confirm where the optimizer is set up.
    """
    def __init__(
        self,
        basepath: str,
        slice_nums: List[int],
        time_idxs: List[int],
        embed_dim: int = 512,
        num_heads: int = 1,
        ssp_depth: int = 231,
        ssp_range: int = 11,
        at_size: int = 800,
        dropout: float = 0.0,
        batch_size=1000,
        learning_rate=1e-3,
        ssp_transform: Union[nn.Module, None] = None,
        at_transform: Union[nn.Module, None] = None,
        log_matrix_norms: bool = True,
        spectral_norm_query: bool = False,
        spectral_norm_out: bool = False,
        spectral_norm_mha: bool = False,
        store_at_transform = True,
        tanh: bool = False,
    ):
        """
        Network arguments are stored as attributes and consumed by ``build_net``.
        ``store_at_transform`` is accepted (and saved as a hyperparameter) but
        not otherwise used here. ``tanh`` is stored as ``self.tanh`` for
        ``build_net`` implementations that support a Tanh before the output layer.
        """
        super().__init__()
        # NOTE(review): learning_rate / batch_size are stored but not used in this class.
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        self.basepath = basepath
        self.slice_nums = slice_nums
        self.time_idxs = time_idxs
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ssp_depth = ssp_depth
        self.ssp_range = ssp_range
        self.at_size = at_size
        self.dropout = dropout
        # Must be set before build_net, which may read it.
        self.tanh = tanh

        self.spectral_norm_query = spectral_norm_query
        self.spectral_norm_out = spectral_norm_out
        self.spectral_norm_mha = spectral_norm_mha

        self.log_matrix_norms = log_matrix_norms

        # Used in validation_step to report the loss in unnormalized units.
        self.at_transform = at_transform

        self.build_net(ssp_transform, at_transform)
        self.apply_sn()
        self.save_hyperparameters()

    def apply_sn(self):
        """Register spectral-norm parametrizations on the layers selected by the flags."""
        if self.spectral_norm_query:
            spectral_norm(self.net.query_proj)
        if self.spectral_norm_out:
            if isinstance(self.net.linear_out, nn.Linear):
                spectral_norm(self.net.linear_out)
            else:
                for m in self.net.linear_out.children():
                    if isinstance(m, nn.Linear):
                        spectral_norm(m)
        if self.spectral_norm_mha:
            spectral_norm(self.net.mha, 'qk_proj_weight')
            spectral_norm(self.net.mha, 'v_proj_weight')
            spectral_norm(self.net.mha.out_proj)

    def _log_matrix_norms(self):
        """Log Frobenius norms (matrix_norm default) of the main weight matrices."""
        if not self.log_matrix_norms:
            return
        with torch.no_grad():
            norm = matrix_norm(self.net.query_proj.weight).item()
            self.log('matrix_norm/query_proj', norm)

            out_layer = self.net.linear_out
            # The last module of a Sequential output head is the Linear layer.
            if isinstance(out_layer, nn.Sequential):
                out_layer = out_layer[-1]
            norm = matrix_norm(out_layer.weight).item()
            self.log('matrix_norm/linear_out', norm)

            for matrix in ('v_proj_weight', 'qk_proj_weight'):
                norm = matrix_norm(getattr(self.net.mha, matrix)).item()
                self.log(f'matrix_norm/{matrix}', norm)

    @staticmethod
    def _masked_mse(yhat, y):
        """MSE restricted to the entries where the target ``y`` is finite."""
        mask = torch.isfinite(y)
        return nn.functional.mse_loss(yhat[mask], y[mask])

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat, probs = self(x)
        # Train only on finite (non-NaN) targets.
        loss = self._masked_mse(yhat, y)
        self.log("train_loss", loss)

        with torch.no_grad():
            # Entropy of the attention over the reference models;
            # probs is bs x 1 x num_ref.
            entropy = Categorical(probs=probs[:, 0]).entropy().mean().item()
            self.log("entropy", entropy)

        self._log_matrix_norms()

        return loss

    def validation_step(self, val_batch, batch_idx):
        """
        Log ``val_loss`` and, when ``at_transform`` is set, ``un_val_loss``
        (the loss after ``at_transform.unnormalize``).

        Returns a dict with ``val_loss`` and, only when ``at_transform`` is
        set, ``denorm_val_loss``.
        """
        x, y = val_batch
        yhat, probs = self(x)
        loss = self._masked_mse(yhat, y)
        self.log("val_loss", loss, prog_bar=True)
        outputs = {"val_loss": loss}

        if self.at_transform is not None:
            un_y = self.at_transform.unnormalize(y)
            un_yhat = self.at_transform.unnormalize(yhat)
            un_loss = self._masked_mse(un_yhat, un_y)
            self.log("un_val_loss", un_loss)
            outputs["denorm_val_loss"] = un_loss
        return outputs

    def validation_epoch_end(self, outputs):
        # TODO: aggregate epoch-level validation metrics.
        # NOTE(review): this hook was removed in PyTorch Lightning 2.0, where
        # overriding it raises -- drop it if the project upgrades.
        return None

    def test_step(self, batch, batch_idx):
        x, y = batch
        yhat, probs = self(x)
        loss = self._masked_mse(yhat, y)
        self.log("test_loss", loss)

class WeightedAverageModule(WeightedAverageBaseModule):
    """Lightning module whose network is a WeightedAverageNet."""

    def build_net(self, ssp_transform=None, at_transform=None):
        """Instantiate ``self.net`` from the hyperparameters stored on ``self``."""
        self.net = WeightedAverageNet(
            basepath=self.basepath,
            slice_nums=self.slice_nums,
            time_idxs=self.time_idxs,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            ssp_depth=self.ssp_depth,
            ssp_range=self.ssp_range,
            at_size=self.at_size,
            dropout=self.dropout,
            # ``self.tanh`` may not be set by the base __init__; default to
            # no Tanh instead of raising AttributeError.
            tanh=getattr(self, "tanh", False),
            ssp_transform=ssp_transform,
            at_transform=at_transform)
            

class PETALNet(WeightedAverageNet):
    """
    WeightedAverageNet with an extra linear decoder that maps a query
    embedding back to a flattened SSP.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # embed_dim -> ssp_size reconstruction head.
        self.decode = nn.Linear(self.embed_dim, self.ssp_size)

    def _attend(self, q, lfm_input):
        """
        Attend from query embeddings ``q`` (bs, 1, embed_dim) to the projected
        reference SSPs, using the LFM ensemble's predictions on ``lfm_input``
        as values. Returns (output of linear_out, attention probs).
        """
        bs = q.size(0)
        # Reference SSPs projected with query_proj serve as keys.
        k = self.query_proj(self.key.expand(bs, *self.key.size()[1:]).to(q.device))
        v = self.LFMForward(lfm_input).reshape(bs, self.num_pred, -1)

        out, probs = self.mha(q, k, v)
        return self.linear_out(out), probs

    def forward(self, x, decode=False):
        """
        x: bs x r x d SSPs

        returns:
        (y, probs) or, when ``decode`` is True, (y, probs, xhat)
        y: (bs, 2, 20, 20) predicted arrival times
        probs: attention weights over the reference models
        xhat: (bs, 11, 231) SSP reconstructed from the query embedding
        """
        bs = x.size(0)
        q = self.query_proj(x.reshape(bs, 1, -1))
        out, probs = self._attend(q, x)
        y = out.view(bs, 2, 20, 20)
        if not decode:
            return y, probs
        xhat = self.decode(q[:, 0])
        return y, probs, xhat.view(bs, 11, 231)

    def na_forward(self, z, decode=False):
        """
        Forward pass starting from an embedding instead of an SSP.

        z: bs x embed_dim query embedding; it is decoded to an SSP estimate,
        which is fed to the LFM ensemble to produce the attention values.

        returns:
        (y, probs) or, when ``decode`` is True, (y, probs, xhat)
        y: (bs, 2, 20, 20); xhat: (bs, 11, 231)
        """
        bs = z.size(0)
        q = z.reshape(bs, 1, -1)
        xhat = self.decode(q[:, 0]).view(bs, 11, 231)
        out, probs = self._attend(q, xhat)
        y = out.view(bs, 2, 20, 20)
        if decode:
            return y, probs, xhat
        return y, probs

class PETALModule(WeightedAverageBaseModule):
    """
    Lightning module for PETALNet: arrival-time regression plus an SSP
    reconstruction loss scaled by ``lambda_rec``.
    """
    def __init__(
        self,
        basepath: str,
        slice_nums: List[int],
        time_idxs: List[int],
        embed_dim: int = 512,
        num_heads: int = 1,
        ssp_depth: int = 231,
        ssp_range: int = 11,
        at_size: int = 800,
        dropout: float = 0.0,
        batch_size=1000,
        learning_rate=1e-3,
        ssp_transform: Union[nn.Module, None] = None,
        at_transform: Union[nn.Module, None] = None,
        log_matrix_norms: bool = True,
        spectral_norm_query: bool = False,
        spectral_norm_out: bool = False,
        spectral_norm_mha: bool = False,
        lambda_rec: float = 1.0,
    ):
        super().__init__(
            basepath = basepath,
            slice_nums = slice_nums,
            time_idxs = time_idxs,
            embed_dim = embed_dim,
            num_heads = num_heads,
            ssp_depth = ssp_depth,
            ssp_range = ssp_range,
            at_size = at_size,
            dropout = dropout,
            batch_size = batch_size,
            learning_rate = learning_rate,
            ssp_transform = ssp_transform,
            at_transform = at_transform,
            log_matrix_norms = log_matrix_norms,
            spectral_norm_query = spectral_norm_query,
            spectral_norm_out = spectral_norm_out,
            spectral_norm_mha = spectral_norm_mha,
        )
        # Weight of the reconstruction term in training_step.
        self.lambda_rec = lambda_rec

    def build_net(self, ssp_transform=None, at_transform=None):
        """Instantiate ``self.net`` as a PETALNet from the stored hyperparameters."""
        self.net = PETALNet(
            basepath=self.basepath,
            slice_nums=self.slice_nums,
            time_idxs=self.time_idxs,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            ssp_depth=self.ssp_depth,
            ssp_range=self.ssp_range,
            at_size=self.at_size,
            dropout=self.dropout,
            ssp_transform=ssp_transform,
            at_transform=at_transform)

    def apply_sn(self):
        """
        Spectral norm for PETAL. Unlike the base class, ``spectral_norm_out``
        wraps the decoder (``net.decode``), and ``spectral_norm_mha`` wraps only
        ``qk_proj_weight`` and ``out_proj`` (not ``v_proj_weight``).
        """
        if self.spectral_norm_query:
            spectral_norm(self.net.query_proj)
        if self.spectral_norm_out:
            spectral_norm(self.net.decode)
        if self.spectral_norm_mha:
            spectral_norm(self.net.mha, 'qk_proj_weight')
            # should turn off for future work?
            spectral_norm(self.net.mha.out_proj)

    def forward(self, x, decode=False):
        return self.net(x, decode)

    def na_forward(self, z, decode=False):
        return self.net.na_forward(z, decode)

    def encoder(self, x):
        """Project SSPs to embeddings: (..., ssp_range, ssp_depth) -> (N, embed_dim)."""
        return self.net.query_proj(x.reshape(-1, self.ssp_depth * self.ssp_range))

    def decoder(self, z):
        """Decode embeddings back to SSPs of shape (N, ssp_range, ssp_depth)."""
        return self.net.decode(z).view(-1, self.ssp_range, self.ssp_depth)

    def training_step(self, batch, batch_idx):
        x, y = batch
        yhat, probs, xhat = self(x, decode=True)
        # Observation loss only over finite (non-NaN) targets.
        yidx = torch.isfinite(y)
        obs_loss = nn.functional.mse_loss(yhat[yidx], y[yidx])
        self.log("obs_loss", obs_loss)

        rec_loss = self.lambda_rec * nn.functional.mse_loss(xhat, x)
        self.log("rec_loss", rec_loss)
        loss = obs_loss + rec_loss
        # Same key as the base module so loss monitors work for both.
        self.log("train_loss", loss)

        with torch.no_grad():
            # Entropy of the attention over reference models; probs is bs x 1 x num_ref.
            entropy = Categorical(probs=probs[:, 0]).entropy().mean().item()
            self.log("entropy", entropy)

        self._log_matrix_norms()

        return loss


