import torch
import torch.nn as nn


class DecoderBase(nn.Module):
    """Abstract interface shared by the decoders.

    Apart from :meth:`freeze`, every method here is a stub that raises
    ``NotImplementedError``; concrete decoders are expected to override them.
    """

    def __init__(self):
        super().__init__()

    def freeze(self):
        """Disable gradient tracking for every parameter of this module."""
        for weight in self.parameters():
            weight.requires_grad_(False)

    def decode(self, x, z):
        """Compute output logits for ``x`` conditioned on ``z``.

        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)

        Returns:
            Tensor: the output logits with size
                (batch_size * n_sample, seq_len, vocab_size)
        """
        raise NotImplementedError

    def reconstruct_error(self, x, z):
        """Reconstruction loss.

        Args:
            x: (batch_size, *)
            z: (batch_size, n_sample, nz)

        Returns:
            loss: (batch_size, n_sample). Loss across different
                sentences and z
        """
        raise NotImplementedError

    def beam_search_decode(self, z, K):
        """Beam search decoding.

        Args:
            z: (batch_size, nz)
            K: the beam size

        Returns:
            List: the decoded word sentence list
        """
        raise NotImplementedError

    def sample_decode(self, z):
        """Decode by sampling, conditioned on ``z``.

        Args:
            z: (batch_size, nz)

        Returns:
            List: the decoded word sentence list
        """
        raise NotImplementedError

    def greedy_decode(self, z):
        """Greedy decoding, conditioned on ``z``.

        Args:
            z: (batch_size, nz)

        Returns:
            List: the decoded word sentence list
        """
        raise NotImplementedError

    def log_probability(self, x, z):
        """Log-likelihood of ``x`` given ``z``.

        Args:
            x: (batch_size, *)
            z: (batch_size, n_sample, nz)

        Returns:
            log_p: (batch_size, n_sample).
                log_p(x|z) across different x and z
        """
        raise NotImplementedError




# import torch

import time
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import numpy as np

from .decoder_helper import BeamSearchNode

class LSTMDecoder(DecoderBase):
    """LSTM decoder with constant-length data"""
    def __init__(self, args, vocab, model_init, emb_init):
        """Build the embedding, LSTM, output layer and loss.

        Args:
            args: configuration object. This constructor reads ``ni``,
                ``dec_nh``, ``nz``, ``device``, ``add_skip``, ``add_bow``,
                ``dec_dropout_in`` and ``dec_dropout_out`` from it.
            vocab: vocabulary; ``len(vocab)`` sizes the embedding and output
                layers, and the object is kept for the decoding methods.
            model_init: callable applied in place to every parameter.
            emb_init: callable applied in place to the embedding weight
                after ``model_init``.
        """
        super(LSTMDecoder, self).__init__()
        self.ni = args.ni
        self.nh = args.dec_nh
        self.nz = args.nz
        self.vocab = vocab
        self.device = args.device
        self.add_skip = args.add_skip
        self.add_bow = args.add_bow

        # No padding row. padding_idx must be None for that: a negative value
        # is counted from the end by nn.Embedding, so -1 would silently stop
        # the last vocabulary entry from ever receiving a gradient.
        self.embed = nn.Embedding(len(vocab), self.ni, padding_idx=None)

        self.dropout_in = nn.Dropout(args.dec_dropout_in)
        self.dropout_out = nn.Dropout(args.dec_dropout_out)

        # maps z to the initial cell state (hidden state is tanh of it)
        self.trans_linear = nn.Linear(self.nz, self.nh, bias=False)

        # z is concatenated to the word embedding at every time step
        self.lstm = nn.LSTM(input_size=self.ni + self.nz,
                            hidden_size=self.nh,
                            num_layers=1,
                            batch_first=True)

        # prediction layer; with the skip connection z is appended to the
        # LSTM output before projecting to the vocabulary
        pred_in = self.nh + (self.nz if self.add_skip else 0)
        self.pred_linear = nn.Linear(pred_in, len(vocab), bias=False)

        # all-ones class weight: no token is excluded from the loss
        vocab_mask = torch.ones(len(vocab))
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        if self.add_bow:
            # bag-of-words head fed with z, so it is sized with nz like
            # every other layer that consumes the latent code
            self.z_to_bow = nn.Sequential(
                nn.Linear(self.nz, self.nz),
                nn.Sigmoid(),
                nn.Linear(self.nz, len(vocab)),
            )

        self.reset_parameters(model_init, emb_init)

    def reset_parameters(self, model_init, emb_init):
        """Re-initialise the weights of the decoder.

        Args:
            model_init: callable invoked on every parameter of the module
                (the embedding weight included).
            emb_init: callable invoked on the embedding weight afterwards,
                so it has the final say for that tensor.
        """
        every_param = self.parameters()
        for tensor in every_param:
            model_init(tensor)

        embedding_weight = self.embed.weight
        emb_init(embedding_weight)

    def sample_text(self, input, z, EOS, device):
        """Sample a token sequence step by step from one latent code.

        The logits of each step are flattened into a single distribution
        and one token is drawn from it, so this effectively handles one
        sentence per call.

        Args:
            input: LongTensor with the start token; its embedding is
                concatenated with a length-1 expansion of ``z``, so it is
                expected to be (1, 1) -- TODO confirm against callers.
            z: (batch_size, n_sample, nz) latent code; the expansion below
                requires n_sample == 1.
            EOS: id of the end-of-sentence token; sampling stops once it
                has been drawn.
            device: device the sampled token is moved to before being fed
                back into the decoder.

        Returns:
            list: ``input`` followed by the sampled token ids, each a
                LongTensor of shape (1,). At most 99 tokens are sampled.
        """
        with torch.no_grad():
            batch_size = z.size(0)

            # (batch_size, 1, nz): z is fed to the LSTM at every step
            z_step = z.expand(batch_size, 1, self.nz)

            # initial state, computed once: (1, batch_size * n_sample, nh)
            c_init = self.trans_linear(z.reshape(-1, self.nz)).unsqueeze(0)
            hidden = (torch.tanh(c_init), c_init)

            sentence = [input]
            input_word = input
            # explicit flag instead of a sentinel token id, so EOS == 0 works
            finished = False
            while not finished and len(sentence) < 100:
                # (batch_size, 1, ni + nz)
                word_embed = torch.cat((self.embed(input_word), z_step), -1)
                output, hidden = self.lstm(word_embed, hidden)

                # same skip connection as in decode / sample_decode
                if self.add_skip:
                    output = torch.cat((output, z_step), -1)

                # one distribution over the vocabulary: (vocab_size,)
                probs = torch.softmax(self.pred_linear(output).view(-1), dim=0)
                max_index = torch.multinomial(probs, num_samples=1)

                sentence.append(max_index)
                input_word = max_index.view(1, 1).to(device)
                finished = bool(max_index.item() == EOS)

        return sentence

    def decode(self, input, z):
        """Teacher-forced pass producing vocabulary logits.

        Args:
            input: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)

        Returns:
            Tensor: logits of size (batch_size * n_sample, seq_len, vocab_size)
        """
        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)
        flat_batch = batch_size * n_sample

        # (batch_size, seq_len, ni); dropout is drawn once per sentence and
        # therefore shared by all latent samples of that sentence
        embedded = self.dropout_in(self.embed(input))

        # copy every sentence once per latent sample:
        # (batch_size * n_sample, seq_len, ni)
        embedded = embedded.unsqueeze(1) \
                           .expand(batch_size, n_sample, seq_len, self.ni) \
                           .contiguous() \
                           .view(flat_batch, seq_len, self.ni)

        # copy every latent code once per time step:
        # (batch_size * n_sample, seq_len, nz)
        z_steps = z.unsqueeze(2) \
                   .expand(batch_size, n_sample, seq_len, self.nz) \
                   .contiguous() \
                   .view(flat_batch, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        lstm_input = torch.cat((embedded, z_steps), -1)

        # initial state from z: cell = linear(z), hidden = tanh(cell)
        z_flat = z.view(flat_batch, self.nz)
        c_init = self.trans_linear(z_flat).unsqueeze(0)
        h_init = torch.tanh(c_init)

        states, _ = self.lstm(lstm_input, (h_init, c_init))
        states = self.dropout_out(states)

        # optional skip connection: append z to every LSTM output
        if self.add_skip:
            states = torch.cat((states, z_steps), -1)

        # (batch_size * n_sample, seq_len, vocab_size)
        return self.pred_linear(states)

    def reconstruct_error(self, x, z):
        """Token-level cross entropy summed over each sentence.

        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)

        Returns:
            loss: (batch_size, n_sample). Loss across different
                sentences and z
        """
        # inputs drop the last token, targets drop the first one
        src, tgt = x[:, :-1], x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)
        flat_batch = batch_size * n_sample

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode(src, z)
        vocab_size = output_logits.size(2)

        # targets copied once per latent sample: (batch_size * n_sample * seq_len)
        flat_tgt = tgt.unsqueeze(1) \
                      .expand(batch_size, n_sample, seq_len) \
                      .contiguous() \
                      .view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, vocab_size), flat_tgt)

        if self.add_bow:
            # bag-of-words head: the same logits predicted from z alone are
            # scored against every target position
            bow_logits = self.z_to_bow(z.view(flat_batch, z.size(2)))
            bow_logits = bow_logits.unsqueeze(-1).repeat(1, 1, seq_len)
            loss_bow = self.loss(bow_logits, flat_tgt.view(flat_batch, seq_len)) / seq_len
            loss_bow = loss_bow.view(-1)
            # contributes its gradient but not its value, so the returned
            # number is still the plain reconstruction loss at evaluation
            loss = loss + (loss_bow - loss_bow.detach())

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)


    def log_probability(self, x, z):
        """Log-likelihood of ``x`` given ``z`` (negated reconstruction loss).

        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)

        Returns:
            log_p: (batch_size, n_sample).
                log_p(x|z) across different x and z
        """
        negative_log_p = self.reconstruct_error(x, z)
        return -negative_log_p

    def beam_search_decode(self, z, K=5):
        """Beam search decoding, code is based on
        https://github.com/pcyin/pytorch_basic_nmt/blob/master/nmt.py

        Sentences are decoded one at a time; batching across sentences
        would be faster.

        Args:
            z: (batch_size, nz)
            K: the beam width

        Returns:
            List: one list of word strings per row of ``z`` -- the highest
                scoring hypothesis, start token included.
        """
        batch_size, _ = z.size()

        # pure inference: only strings are returned, so no graph is needed
        with torch.no_grad():
            # (1, batch_size, nh)
            c_init = self.trans_linear(z).unsqueeze(0)
            h_init = torch.tanh(c_init)

            return [
                self._beam_search_one(
                    z[idx],
                    (h_init[:, idx, :].unsqueeze(1), c_init[:, idx, :].unsqueeze(1)),
                    K)
                for idx in range(batch_size)
            ]

    def _beam_search_one(self, z_single, hidden, K):
        """Run beam search for one latent code and return the best word list."""
        vocab_size = len(self.vocab)
        end_id = self.vocab["</s>"]
        nz = z_single.size(-1)

        # every hypothesis starts from the start-of-sentence token
        start_word = torch.tensor([[self.vocab["<s>"]]], dtype=torch.long, device=self.device)
        live_hypotheses = [BeamSearchNode(hidden, None, start_word, 0., 1)]
        completed_hypotheses = []

        t = 0
        while live_hypotheses and len(completed_hypotheses) < K and t < 100:
            t += 1
            n_live = len(live_hypotheses)

            # (n_live, 1)
            decoder_input = torch.cat([node.wordid for node in live_hypotheses], dim=0)

            # (1, n_live, nh) each
            decoder_hidden = (
                torch.cat([node.h[0] for node in live_hypotheses], dim=1),
                torch.cat([node.h[1] for node in live_hypotheses], dim=1),
            )

            # (n_live, 1, ni) --> (n_live, 1, ni + nz)
            z_step = z_single.view(1, 1, -1).expand(n_live, 1, nz)
            word_embed = torch.cat((self.embed(decoder_input), z_step), dim=-1)

            output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

            # same skip connection as in decode / sample_decode; without it
            # pred_linear receives the wrong input width when add_skip is set
            if self.add_skip:
                output = torch.cat((output, z_step), dim=-1)

            # (n_live, 1, vocab_size)
            step_logp = F.log_softmax(self.pred_linear(output), dim=-1)

            prev_logp = torch.tensor([node.logp for node in live_hypotheses],
                                     dtype=torch.float, device=self.device)

            # cumulative scores, flattened to (n_live * vocab_size)
            scores = (step_logp + prev_logp.view(n_live, 1, 1)).view(-1)

            # only as many slots as the beam still has free, and never more
            # candidates than actually exist
            n_expand = min(K - len(completed_hypotheses), scores.numel())
            top_scores, top_indexes = torch.topk(scores, n_expand)

            # (n_expand, 1, 1) so that indexing yields a (1, 1) word tensor
            word_tensors = (top_indexes % vocab_size).view(-1, 1, 1)

            live_hypotheses_new = []
            candidates = zip(top_indexes.tolist(), top_scores.tolist())
            for rank, (flat_index, score) in enumerate(candidates):
                live_id, word_id = divmod(flat_index, vocab_size)
                node = BeamSearchNode(
                    (decoder_hidden[0][:, live_id, :].unsqueeze(1),
                     decoder_hidden[1][:, live_id, :].unsqueeze(1)),
                    live_hypotheses[live_id], word_tensors[rank], score, t)

                if word_id == end_id:
                    completed_hypotheses.append(node)
                else:
                    live_hypotheses_new.append(node)

            live_hypotheses = live_hypotheses_new

        # unfinished hypotheses still compete for the best score
        completed_hypotheses.extend(live_hypotheses)
        best = max(completed_hypotheses, key=lambda node: node.logp)

        # back trace from the last word to the start token
        utterance = []
        node = best
        while node is not None:
            utterance.append(self.vocab.id2word(node.wordid.item()))
            node = node.prevNode
        utterance.reverse()

        return utterance


    def greedy_decode(self, z):
        """Decode by taking the arg-max word at every step.

        Thin wrapper that calls ``sample_decode`` with ``greedy=True``.

        Args:
            z: (batch_size, nz)

        Returns:
            List: the decoded word sentence list
        """
        greedy_sentences = self.sample_decode(z, greedy=True)
        return greedy_sentences

    def sample_decode(self, z, greedy=False):
        """Sample/greedy decoding from z.

        Args:
            z: (batch_size, nz)
            greedy: if True take the arg-max word at every step, otherwise
                sample from the softmax distribution.

        Returns:
            List: one list of word strings per row of ``z``. A sentence
                stops growing after its end token (which is included);
                at most 99 words are produced.
        """
        batch_size = z.size(0)
        decoded_batch = [[] for _ in range(batch_size)]

        start_id = self.vocab["<s>"]
        end_id = self.vocab["</s>"]

        # pure inference: only strings are returned, so no graph is needed
        with torch.no_grad():
            # (1, batch_size, nh)
            c_init = self.trans_linear(z).unsqueeze(0)
            decoder_hidden = (torch.tanh(c_init), c_init)

            # (batch_size, 1)
            decoder_input = torch.tensor([start_id] * batch_size, dtype=torch.long,
                                         device=self.device).unsqueeze(1)

            # (batch_size, 1, nz): z is fed to the LSTM at every step
            z_step = z.unsqueeze(1)

            # which sentences have not produced their end token yet
            active = [True] * batch_size
            length_c = 1
            while any(active) and length_c < 100:
                # (batch_size, 1, ni) --> (batch_size, 1, ni + nz)
                word_embed = torch.cat((self.embed(decoder_input), z_step), dim=-1)

                output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

                if self.add_skip:
                    output = torch.cat((output, z_step), -1)

                # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
                output_logits = self.pred_linear(output).squeeze(1)

                # (batch_size)
                if greedy:
                    max_index = torch.argmax(output_logits, dim=1)
                else:
                    probs = F.softmax(output_logits, dim=1)
                    max_index = torch.multinomial(probs, num_samples=1).squeeze(1)

                decoder_input = max_index.unsqueeze(1)
                length_c += 1

                # one host transfer per step instead of one per sentence
                for i, word_id in enumerate(max_index.tolist()):
                    if active[i]:
                        decoded_batch[i].append(self.vocab.id2word(word_id))
                        active[i] = word_id != end_id

        return decoded_batch
    
    def forward(self, x, z):
        """Reconstruction loss per sentence, averaged over the latent samples.

        Calls ``reconstruct_error`` and takes the mean over dimension 1.
        """
        per_sample_loss = self.reconstruct_error(x, z)
        return torch.mean(per_sample_loss, dim=1)