from copy import deepcopy
import numpy as np
import torch
from torch import nn


class RNNEnoder(nn.Module):
    """Recurrent language encoder: word ids -> embedding -> dropout -> MLP -> RNN.

    All hyper-parameters are read from ``cfg.MODEL.LANGUAGE_BACKBONE``.
    Token id ``0`` is treated as padding: sequence lengths are computed as
    the number of non-zero ids per row.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        lang_cfg = cfg.MODEL.LANGUAGE_BACKBONE
        self.rnn_type = lang_cfg.RNN_TYPE
        self.variable_length = lang_cfg.VARIABLE_LENGTH
        self.word_embedding_size = lang_cfg.WORD_EMBEDDING_SIZE
        self.word_vec_size = lang_cfg.WORD_VEC_SIZE
        self.hidden_size = lang_cfg.HIDDEN_SIZE
        self.bidirectional = lang_cfg.BIDIRECTIONAL
        self.input_dropout_p = lang_cfg.INPUT_DROPOUT_P
        self.dropout_p = lang_cfg.DROPOUT_P
        self.n_layers = lang_cfg.N_LAYERS
        self.corpus_path = lang_cfg.CORPUS_PATH
        self.vocab_size = lang_cfg.VOCAB_SIZE

        # language encoder
        self.embedding = nn.Embedding(self.vocab_size, self.word_embedding_size)
        self.input_dropout = nn.Dropout(self.input_dropout_p)
        self.mlp = nn.Sequential(nn.Linear(self.word_embedding_size, self.word_vec_size), nn.ReLU())
        # PyTorch applies RNN dropout only between stacked layers; with a single
        # layer a non-zero value has no effect and merely emits a warning.
        rnn_dropout = self.dropout_p if self.n_layers > 1 else 0.0
        # RNN_TYPE selects the torch.nn class by name (e.g. 'lstm' -> nn.LSTM, 'gru' -> nn.GRU).
        self.rnn = getattr(nn, self.rnn_type.upper())(self.word_vec_size,
                                                      self.hidden_size,
                                                      self.n_layers,
                                                      batch_first=True,
                                                      bidirectional=self.bidirectional,
                                                      dropout=rnn_dropout)
        self.num_dirs = 2 if self.bidirectional else 1

    def forward(self, input, mask=None):
        """Encode a batch of padded word-id sequences.

        Args:
            input: LongTensor (batch, seq_len) of word ids, 0 = padding.
                NOTE(review): padding is assumed to be trailing, since lengths
                are counted as the number of non-zero ids — confirm with callers.
            mask: unused; accepted only for interface compatibility.

        Returns:
            dict with keys 'hidden', 'output', 'embedded' and 'final_output'
            (shapes documented on ``encode``).
        """
        word_id = input
        # Trim padding columns beyond the longest sequence; sort_inputs relies on
        # the tensor width matching the longest length exactly.
        max_len = (word_id != 0).sum(1).max().item()
        word_id = word_id[:, :max_len]
        output, hidden, embedded, final_output = self.encode(word_id)
        return {
            'hidden': hidden,
            'output': output,
            'embedded': embedded,
            'final_output': final_output,
        }

    def encode(self, input_labels):
        """Run embedding + MLP + RNN over ``input_labels``.

        Inputs:
            - input_labels: LongTensor (batch, seq_len), 0 = padding. When
              ``variable_length`` is set, seq_len must equal the longest
              sequence length (checked in ``sort_inputs``).
        Outputs:
            - output      : float (batch, seq_len, hidden_size * num_dirs)
            - hidden      : float (batch, num_layers * num_dirs * hidden_size);
                            for LSTM only h_n is used (c_n is discarded)
            - embedded    : float (batch, seq_len, word_vec_size); in
                            variable-length mode positions past each
                            sequence's length are zero-filled by
                            pad_packed_sequence
            - final_output: float (batch, hidden_size * num_dirs), the RNN
                            output at index ``length - 1`` of each sequence

        NOTE(review): in variable-length mode, sequences with zero non-pad
        tokens are rejected by ``pack_padded_sequence``.
        """
        if self.variable_length:
            input_lengths_list, sorted_lengths_list, sort_idxs, recover_idxs = self.sort_inputs(input_labels)
            # pack_padded_sequence (as called here) expects descending lengths.
            input_labels = input_labels[sort_idxs]
        else:
            input_lengths_list = (input_labels != 0).sum(1).tolist()

        embedded = self.embedding(input_labels)  # (n, seq_len, word_embedding_size)
        embedded = self.input_dropout(embedded)  # (n, seq_len, word_embedding_size)
        embedded = self.mlp(embedded)  # (n, seq_len, word_vec_size)

        if self.variable_length:
            embedded = nn.utils.rnn.pack_padded_sequence(embedded,
                                                         sorted_lengths_list,
                                                         batch_first=True)
        # forward rnn (flatten_parameters keeps cuDNN weights contiguous)
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(embedded)

        if self.variable_length:
            # Unpack and restore the caller's original batch order.
            embedded, _ = nn.utils.rnn.pad_packed_sequence(embedded, batch_first=True)
            embedded = embedded[recover_idxs]  # (batch, max_len, word_vec_size)
            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
            output = output[recover_idxs]  # (batch, max_len, hidden_size * num_dirs)

        # LSTM returns (h_n, c_n); GRU/RNN return h_n directly. Checking the
        # type (not the config string) is robust to 'lstm' vs 'LSTM'.
        if isinstance(hidden, tuple):
            hidden = hidden[0]
        if self.variable_length:
            hidden = hidden[:, recover_idxs, :]  # (num_layers * num_dirs, batch, hidden_size)
        hidden = hidden.transpose(0, 1).contiguous()  # (batch, num_layers * num_dirs, hidden_size)
        hidden = hidden.view(hidden.size(0), -1)  # (batch, num_layers * num_dirs * hidden_size)

        # Gather the output at the last real token of every sequence.
        # NOTE(review): for a bidirectional RNN the backward half at that
        # position has only seen the last token.
        lengths = torch.as_tensor(input_lengths_list, dtype=torch.long, device=output.device)
        batch_idx = torch.arange(output.size(0), device=output.device)
        final_output = output[batch_idx, lengths - 1]  # (batch, num_dirs * hidden_size)

        return output, hidden, embedded, final_output

    def sort_inputs(self, input_labels):
        """Sort a batch by descending non-pad length.

        Returns:
            - input_lengths_list: per-row lengths in the ORIGINAL order
            - sorted_input_lengths_list: the same lengths in descending order
            - sort_idxs: LongTensor; ``input_labels[sort_idxs]`` is the sorted batch
            - recover_idxs: LongTensor, inverse permutation of ``sort_idxs``
        """
        device = input_labels.device
        input_lengths_list = (input_labels != 0).sum(1).tolist()
        sort_idxs = np.argsort(input_lengths_list)[::-1].tolist()
        sorted_input_lengths_list = [input_lengths_list[i] for i in sort_idxs]
        # argsort of a permutation is its inverse: maps original row -> sorted position.
        recover_idxs = np.argsort(sort_idxs).tolist()
        assert max(input_lengths_list) == input_labels.size(1), \
            "input width must equal the longest sequence length"
        sort_idxs = torch.tensor(sort_idxs, dtype=torch.long, device=device)
        recover_idxs = torch.tensor(recover_idxs, dtype=torch.long, device=device)
        return input_lengths_list, sorted_input_lengths_list, sort_idxs, recover_idxs
