import torch.nn as nn
from torch.nn import functional as F
import torch.nn.utils.parametrize as P

from .cross_attention import CrossAttentionBlock
from .activations import create_activation
from .fourier_mapping import FourierMapping
from .configs import DecoderWithCrossAttentionConfig


class RowNormalize(nn.Module):
    """Parametrization module that L2-normalizes a weight along ``dim=1``.

    For a 2-D weight this makes every row a unit vector. Rows whose norm is
    (near) zero are left (near) zero because the divisor is clamped by ``eps``.
    """

    def forward(self, weight):
        """Return ``weight`` rescaled to unit L2 norm along dimension 1.

        Args
            weight (torch.Tensor): Tensor with at least two dimensions.

        Returns
            torch.Tensor: Tensor of the same shape as ``weight``.
        """
        # Spell out the defaults of F.normalize so the contract is explicit.
        unit_rows = F.normalize(weight, p=2.0, dim=1, eps=1e-12)
        return unit_rows


class DecoderWithCrossAttention(nn.Module):
    """Coordinate decoder mixing linear layers with cross-attention layers.

    The network is a stack of ``config.n_layer`` layers. Layers whose index is
    listed in ``config.attention_layer_idxs`` are ``CrossAttentionBlock``s that
    attend to the latents passed to ``forward``; every other layer is an
    ``nn.Linear``. An optional Fourier feature mapping is applied to the input
    coordinates before the first layer.
    """

    def __init__(self, config: "DecoderWithCrossAttentionConfig"):
        """Build the layer stack described by ``config``.

        Args
            config: Decoder configuration. The fields read here are
                ``n_layer``, ``hidden_dim``, ``attention_layer_idxs``,
                ``fourier_mapping``, ``input_dim``, ``output_dim``,
                ``latent_dim``, ``cross_attention``, ``use_bias``,
                ``normalize_mlp_weights``, ``activation`` and ``output_bias``.

        Raises
            AssertionError: If ``hidden_dim`` has neither one entry nor
                ``n_layer - 1`` entries, if ``attention_layer_idxs`` is empty,
                or if any attention index is outside ``[0, n_layer - 1)``.
        """
        super().__init__()
        self.config = config
        self.num_layer = config.n_layer
        self.hidden_dims = list(config.hidden_dim)
        if len(self.hidden_dims) == 1:
            # A single width is broadcast to every hidden layer
            # (the output layer is excluded, hence ``num_layer - 1``).
            self.hidden_dims = [self.hidden_dims[0]] * (self.num_layer - 1)
        else:
            assert len(self.hidden_dims) == self.num_layer - 1, (
                "hidden_dim must have either 1 or n_layer - 1 entries"
            )

        self.attention_layer_idxs = list(config.attention_layer_idxs)
        assert len(self.attention_layer_idxs) > 0, (
            "attention_layer_idxs must not be empty"
        )
        # The lower bound matters: a negative index never equals a layer index,
        # so it would otherwise pass validation and silently create no
        # cross-attention layer at all.
        assert all(
            0 <= idx < self.num_layer - 1 for idx in self.attention_layer_idxs
        ), "attention_layer_idxs entries must lie in [0, n_layer - 1)"

        # Fourier layer definition
        ff_config = config.fourier_mapping
        if ff_config:
            self.fourier_mapping = FourierMapping(
                ff_type=ff_config.type,
                input_dim=self.config.input_dim,
                ff_dim=ff_config.ff_dim,
                ff_sigma=ff_config.ff_sigma,
                trainable=ff_config.trainable,
            )
            # NOTE(review): presumably the mapping emits 2 * ff_dim features
            # (e.g. sin and cos) -- confirm against FourierMapping.
            first_layer_dim = ff_config.ff_dim * 2
        else:
            self.fourier_mapping = nn.Identity()
            first_layer_dim = self.config.input_dim

        # Remaining layer definitions: layer i maps in_dims[i] -> out_dims[i].
        in_dims = [first_layer_dim] + self.hidden_dims
        out_dims = self.hidden_dims + [self.config.output_dim]

        layers = []
        for layer_idx, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
            if layer_idx in self.attention_layer_idxs:
                xattn_config = self.config.cross_attention
                layer = CrossAttentionBlock(
                    embed_dim=xattn_config.embed_dim,
                    n_head=xattn_config.n_head,
                    input_dim=in_dim,
                    context_dim=self.config.latent_dim,
                    output_dim=out_dim,
                    dropout=xattn_config.dropout,
                    bias=xattn_config.bias,
                    input_layernorm=xattn_config.input_layernorm,
                    residual=xattn_config.residual,
                )
            else:
                layer = nn.Linear(in_dim, out_dim, bias=self.config.use_bias)
                if self.config.use_bias:
                    nn.init.zeros_(layer.bias)
                if config.normalize_mlp_weights:
                    # Re-normalize the weight along dim=1 on every access.
                    P.register_parametrization(layer, "weight", RowNormalize())

            layers.append(layer)

        self.layers = nn.ModuleList(layers)

        self.activation = create_activation(self.config.activation)
        self.output_bias = config.output_bias

    def forward(self, coord, latents):
        """Computes the signal value for each coordinate.
        Note: `assert outputs.shape[:-1] == coord.shape[:-1]`

        Args
            coord (torch.Tensor): Input coordinates of shape
                ``(batch, *coord_shape, input_dim)``. Does not need to be
                contiguous.
            latents (torch.Tensor): Latent vectors to be cross-attended.
                Currently, all cross-attention layers uses the same latents as context.

        Returns
            outputs (torch.Tensor): evaluated values by INR, shaped
                ``(batch, *coord_shape, -1)``.
        """

        batch_size, coord_shape, input_dim = coord.shape[0], coord.shape[1:-1], coord.shape[-1]
        # Flatten the coordinates to (batch, num_points, input_dim). ``reshape``
        # is used instead of ``view`` so that non-contiguous coordinate grids
        # (e.g. produced by transpose/permute) are accepted as well.
        coord = coord.reshape(batch_size, -1, input_dim)
        hidden = self.fourier_mapping(coord)

        last_layer_idx = self.num_layer - 1
        for layer_idx, layer in enumerate(self.layers):
            if layer_idx in self.attention_layer_idxs:
                hidden = layer(hidden, latents)
            else:
                hidden = layer(hidden)

            # No activation after the output layer.
            if layer_idx != last_layer_idx:
                hidden = self.activation(hidden)

        outputs = hidden + self.output_bias
        # Restore the original coordinate layout in front of the channel axis.
        outputs = outputs.reshape(batch_size, *coord_shape, -1)
        return outputs

    def compute_modulated_params_dict(self, modulation_params_dict):
        """Not supported by this decoder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def forward_with_params(self, coord, params_dict):
        """Not supported by this decoder; always raises ``NotImplementedError``."""
        raise NotImplementedError
