import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn as nn
import matplotlib.pyplot as plt
import random
from .topk_decoder_utils.DecoderRNN import DecoderRNN
import numpy as np
from .topk_decoder_utils.topk_module import *
#import matplotlib.pyplot as plt


def _inflate(tensor, times, dim):
    """
    'Inflate' a tensor along ``dim`` by replicating each slice ``times`` times.

    Every slice along ``dim`` is repeated consecutively, i.e. ``[a, b]`` becomes
    ``[a, a, b, b]`` for ``times=2``. This is the layout the beam search in this
    file indexes with ``batch_index * k + beam_index``. (The previous
    implementation used ``Tensor.repeat``, which tiles the whole tensor as
    ``[a, b, a, b]`` and therefore paired beams with the wrong batch entry
    whenever both the batch size and ``times`` exceeded 1.)

    The input is not modified; a new tensor is returned.

    Args:
        tensor (torch.Tensor): tensor to inflate.
        times (int): number of consecutive copies of each slice.
        dim (int): dimension along which to replicate.

    Returns:
        torch.Tensor: tensor whose size along ``dim`` is ``times`` times larger.

    Example::

        >>> _inflate(torch.tensor([[1, 2], [3, 4]]), 2, 0).tolist()
        [[1, 2], [1, 2], [3, 4], [3, 4]]
    """
    return torch.repeat_interleave(tensor, times, dim=dim)

class Teacher_forcing_scheduler():
    """Schedule for the teacher-forcing ratio as a function of the training step.

    Supported ``tf_type`` values:

    * ``'constant'`` - always ``tf_ratio``.
    * ``'linear'``   - ``1 - step / total_steps`` (``tf_ratio`` is not used for
      the decay itself, and the value is not clamped).
    * ``'expo'``     - ``exp(-step * 3.2188 / total_steps)``, which is about 0.2
      at ``step == total_steps / 2``.

    Any other ``tf_type`` raises ``Exception`` at construction time.
    """

    def __init__(self, tf_type, tf_ratio, total_steps):
        self.tf_type = tf_type
        self.tf_ratio = tf_ratio
        if tf_type not in ('constant', 'linear', 'expo'):
            raise Exception('Teacher forcing type not implemented!')
        if tf_type == 'linear':
            # amount subtracted from the ratio per step
            self.decrease_every = 1./total_steps
        elif tf_type == 'expo':
            # time constant chosen so that the ratio is ~0.2 at total_steps/2
            self.decrease_ratio = total_steps/3.2188

    def get_ratio(self, step):
        """Return the teacher-forcing ratio for ``step``.

        ``step=None`` (no step information) always yields the configured
        ``tf_ratio``, whatever the schedule type.
        """
        if step is None or self.tf_type == 'constant':
            return self.tf_ratio
        if self.tf_type == 'linear':
            return 1. - self.decrease_every * step
        if self.tf_type == 'expo':
            return np.exp(-step/self.decrease_ratio)
        return None

class Epsilon_scheduler():
    """Schedule for the soft top-k ``epsilon`` as a function of the training step.

    Supported ``eps_type`` values:

    * ``'constant'`` - always ``eps_value``.
    * ``'step'``     - piecewise constant over four equal quarters of
      ``total_steps``: 1, 1e-1, 1e-2, then 1e-3 (``eps_value`` is not used).

    Any other ``eps_type`` raises ``Exception`` at construction time.
    """

    def __init__(self, eps_type, eps_value, total_steps):
        self.eps_type = eps_type
        self.eps_value = eps_value
        self.total_steps = total_steps
        if eps_type not in ('constant', 'step'):
            raise Exception('Epsilon type not implemented!')

    def get_epsilon(self, step):
        """Return epsilon for ``step``; ``step=None`` yields ``eps_value``."""
        if step is None or self.eps_type == 'constant':
            return self.eps_value
        if self.eps_type == 'step':
            # Each test only needs the upper bound: the lower bound is implied
            # by the preceding test having failed.
            if step < self.total_steps/4:
                return 1.
            if step < self.total_steps/2:
                return 1e-1
            if step < self.total_steps/4*3:
                return 1e-2
            return 1e-3
        return None


class TopKDecoder(torch.nn.Module):
    r"""
    Top-K decoding with beam search around a :class:`DecoderRNN`.

    ``forward`` runs in one of three modes, chosen on every call:

    * **top-k inference step** - one decoding step that keeps its own beam
      scores in ``self.state`` (enabled by ``change_epsilon_for_inference``);
    * **single pass** - one plain call of the wrapped RNN step (teacher forcing,
      gold-score mode, or ordinary one-token inference);
    * **beam search** - a full search over ``max_length`` steps, optionally
      relaxed with a soft top-k operator, followed by backtracking.

    Args:
        decoder_rnn (DecoderRNN): the RNN used for decoding. Its ``hidden_size``,
            ``output_size`` and ``use_attention`` attributes are read here.
        sos_id (int): id of the start-of-sequence token.
        eos_id (int): id of the end-of-sequence token.
        topk (int): size of the beam when ``soft_topk`` is falsy.
        soft_topk (int): when truthy, this is the beam size and a soft top-k
            module selected by ``initial_method`` is created.
        epsilon (float): ``epsilon`` passed to the soft top-k module.
        max_iter (int): ``max_iter`` passed to the soft top-k module.
        teacher_forcing_scheduler: object with ``get_ratio(step)``; it is
            called on every ``forward``.
        epsilon_scheduler: object with ``get_epsilon(step)``; only kept when
            ``soft_topk`` is truthy.
        initial_method (str): one of ``'none'``, ``'sigmoid'``,
            ``'sigmoid_no_iter'`` or ``'train_tau'``; picks the soft top-k module.
        train_tau_iter (int): for ``'train_tau'``, steps below this value call
            the soft top-k module with ``use_iter=True``.
        use_A_on_hidden (bool): mix hidden states with the soft assignment
            instead of hard-selecting the predecessor beam.
    """

    def __init__(self, decoder_rnn, sos_id, eos_id, topk, soft_topk,
                 epsilon=1e-3, max_iter=200, teacher_forcing_scheduler=None, epsilon_scheduler=None,
                 initial_method='none', train_tau_iter=0, use_A_on_hidden=False):
        super(TopKDecoder, self).__init__()
        self.rnn = decoder_rnn
        self.hidden_size = self.rnn.hidden_size
        self.V = self.rnn.output_size
        self.SOS = sos_id
        self.EOS = eos_id
        self.teacher_forcing_scheduler = teacher_forcing_scheduler
        self.function = F.log_softmax
        self.state = {}
        self.attentional = self.rnn.use_attention
        self.use_A_on_hidden = use_A_on_hidden
        self.use_topk_inference = False
        # Always recorded: ``forward`` reads it on the top-k inference path.
        self.initial_method = initial_method
        if soft_topk:
            self.use_topk_gradient = True
            self.epsilon = epsilon
            self.epsilon_scheduler = epsilon_scheduler
            self.max_iter = max_iter
            self.k = soft_topk
            if initial_method == 'none':
                self.soft_topk = TopK_custom1(self.k, epsilon=self.epsilon, max_iter=self.max_iter)
            elif initial_method == 'sigmoid':
                self.soft_topk = TopK_custom2_w_initial(self.k, self.V*self.k, epsilon=self.epsilon, max_iter=self.max_iter)
            elif initial_method == 'sigmoid_no_iter':
                self.soft_topk = TopK_custom3_only_initial(self.k, epsilon=self.epsilon, max_iter=self.max_iter)
            elif initial_method == 'train_tau':
                self.train_tau_iter = train_tau_iter
                self.soft_topk = TopK_custom4_train_tau(self.k, self.V*self.k, epsilon=self.epsilon, max_iter=self.max_iter)
            # NOTE(review): any other ``initial_method`` leaves ``self.soft_topk``
            # undefined and only fails later, on first use.
        else:
            self.use_topk_gradient = False
            self.epsilon_scheduler = None
            self.k = topk

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor: build the RNN, both schedulers and the decoder from ``opt``."""
        # ``opt.dropout`` may be a list; in that case only its first entry is used.
        dropout = opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout

        decoder_rnn = DecoderRNN(
            embeddings,
            embeddings.word_vocab_size,
            opt.dec_rnn_size,
            max_length=opt.max_length,
            n_layers=opt.dec_layers,
            rnn_cell=opt.rnn_type,
            bidirectional=opt.brnn,
            input_dropout_p=dropout,
            dropout_p=dropout,
            use_attention=opt.global_attention)

        teacher_forcing_scheduler = Teacher_forcing_scheduler(
            opt.teacher_forcing_type, opt.teacher_forcing_ratio, opt.train_steps)
        epsilon_scheduler = Epsilon_scheduler(opt.epsilon_type, opt.epsilon, opt.train_steps)
        return cls(
            decoder_rnn, opt.tgt_sos_id, opt.tgt_eos_id,
            opt.topk, opt.soft_topk, opt.epsilon, opt.topk_iter,
            teacher_forcing_scheduler, epsilon_scheduler, opt.initial_method,
            opt.train_tau_iter, opt.use_A_on_hidden)

    def init_state(self, src, memory_bank, encoder_final):
        """Initialize decoder state with last state of the encoder.

        Also drops any soft-assignment matrix ``'A'`` cached by a previous
        top-k inference run, so the next top-k inference step starts a fresh
        beam. ``src`` and ``memory_bank`` are not used.
        """
        def _fix_enc_hidden(hidden):
            # The encoder hidden is  (layers*directions) x batch x dim.
            # We need to convert it to layers x batch x (directions*dim).
            if self.rnn.bidirectional_encoder:
                hidden = torch.cat([hidden[0:hidden.size(0):2],
                                    hidden[1:hidden.size(0):2]], 2)
            return hidden

        if isinstance(encoder_final, tuple):  # LSTM
            self.state["hidden"] = tuple(_fix_enc_hidden(enc_hid)
                                         for enc_hid in encoder_final)
        else:  # GRU
            self.state["hidden"] = _fix_enc_hidden(encoder_final)
        if 'A' in self.state:
            del self.state['A']

    def map_state(self, fn):
        """Replace the hidden state by ``fn(hidden, 1)`` (1 is the batch dimension)."""
        if isinstance(self.state['hidden'], tuple):
            self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])
        else:
            self.state['hidden'] = fn(self.state['hidden'], 1)

    def detach_state(self):
        """Detach the hidden state from the autograd graph."""
        if isinstance(self.state['hidden'], tuple):
            self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
        else:
            self.state['hidden'] = self.state['hidden'].detach()

    def change_epsilon_for_inference(self, epsilon):
        """Set the soft top-k epsilon and switch ``forward`` to top-k inference steps.

        Requires that the decoder was built with a truthy ``soft_topk``
        (otherwise there is no ``self.soft_topk`` module to configure).
        """
        self.soft_topk.epsilon = epsilon
        self.use_topk_inference = True

    def forward(self, tgt, memory_bank, memory_lengths=None, step=None, gold_score_mode=False):
        """
        Args:
            tgt (LongTensor): sequences of padded tokens
                 ``(tgt_len, batch, nfeats)``.
            memory_bank (FloatTensor): vectors from the encoder
                 ``(src_len, batch, hidden)``.
            memory_lengths (LongTensor): the padded source lengths
                ``(batch,)``. Not used by this decoder.
            step (int or None): training step, forwarded to the teacher-forcing
                and epsilon schedulers.
            gold_score_mode (bool): force the plain single-pass mode.

        Returns:
            In the top-k inference and single-pass modes, a tuple
            ``(dec_outs, attns)`` where ``dec_outs`` is the transposed output of
            the RNN step and ``attns`` is ``{"std": attention}``.
            In beam-search mode, a tuple ``(decoder_outputs, metadata)`` where
            ``decoder_outputs`` stacks the best beam's outputs over time and
            ``metadata`` may contain ``'loss_for_tau'``.
        """
        inputs = tgt.transpose(0, 1)
        encoder_hidden = self.state["hidden"]
        encoder_outputs = memory_bank.transpose(0, 1)
        inputs, batch_size, max_length = self.rnn._validate_args(inputs, encoder_hidden, encoder_outputs,
                                                                 self.function)
        teacher_forcing_ratio = self.teacher_forcing_scheduler.get_ratio(step)
        # One draw per call, whichever mode ends up being used.
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        if self.training and self.epsilon_scheduler is not None:
            self.soft_topk.epsilon = self.epsilon_scheduler.get_epsilon(step)

        # Inference with the soft top-k operator.
        if max_length == 1 and self.use_topk_inference:
            return self._topk_inference_step(inputs, encoder_hidden, encoder_outputs, batch_size)

        # 1. Training with teacher forcing.
        # 2. Inference, computing the gold score of the target sequence.
        # 3. Inference without the soft top-k operator.
        if use_teacher_forcing or gold_score_mode or max_length == 1:
            return self._single_pass(inputs, encoder_hidden, encoder_outputs)

        # Training with beam search (no teacher forcing).
        return self._beam_search(encoder_hidden, encoder_outputs, batch_size, max_length, step)

    def _init_beams(self, n_groups, n_rows, device):
        """Set ``self.pos_index`` and return the initial ``(scores, SOS tokens)`` for a new search."""
        first_rows = torch.arange(n_groups, dtype=torch.long, device=device) * self.k
        # Offset of each batch entry's first beam in the flattened (batch*k) layout.
        self.pos_index = first_rows.view(-1, 1)

        # For the first step, ignore the inflated copies (score -inf) to avoid
        # duplicate entries in the top k.
        sequence_scores = torch.full((n_rows, 1), -float('inf'), device=device)
        sequence_scores.index_fill_(0, first_rows, 0.0)

        sos_tokens = torch.full((n_rows, 1), self.SOS, dtype=torch.long, device=device)
        return sequence_scores, sos_tokens

    def _reorder_hidden(self, hidden, predecessors, A_ori, n_groups):
        """Carry hidden states over to the selected beams (soft mix if ``A_ori`` is given, else hard select)."""
        if A_ori is not None:
            mix = A_ori.sum(dim=2).unsqueeze(0).transpose(-1, -2)  # [1, n_groups, k, k]

            def _apply(h):
                grouped = h.view(-1, n_groups, self.k, self.hidden_size)
                return mix.matmul(grouped).view(-1, n_groups * self.k, self.hidden_size)
        else:
            # view(-1) rather than squeeze(): stays 1-D even for a single row.
            index = predecessors.view(-1)

            def _apply(h):
                return h.index_select(1, index)

        if isinstance(hidden, tuple):
            return tuple(_apply(h) for h in hidden)
        return _apply(hidden)

    def _single_pass(self, inputs, encoder_hidden, encoder_outputs):
        """Run the wrapped RNN step once on ``inputs`` and store the new hidden state."""
        decoder_input = inputs.squeeze(-1)
        decoder_output, decoder_hidden, attn = self.rnn.forward_step(decoder_input, encoder_hidden, encoder_outputs,
                                                                     function=self.function)
        attns = {"std": attn.transpose(0, 1)}
        self.state['hidden'] = decoder_hidden
        return decoder_output.transpose(0, 1), attns

    def _topk_inference_step(self, inputs, encoder_hidden, encoder_outputs, batch_size):
        """Advance the stateful top-k beam by one step; ``batch_size`` already includes the beam factor."""
        n_groups = batch_size // self.k
        decoder_hidden = encoder_hidden
        decoder_input = inputs.squeeze(-1)

        if 'A' in self.state:
            A = self.state['A']
            sequence_scores = self.state['sequence_scores']
        else:
            # First step of a new search: the given input is replaced by SOS.
            A = None
            device = decoder_input.device
            self.V_ones = torch.ones(1, self.V, device=device)
            sequence_scores, decoder_input = self._init_beams(n_groups, batch_size, device)

        decoder_output, decoder_hidden, attn = self.rnn.forward_step(decoder_input, decoder_hidden, encoder_outputs,
                                                                     function=self.function, A=A)

        # Broadcast every beam's running score over the vocabulary, then add
        # the step's log-probabilities.
        sequence_scores = torch.matmul(sequence_scores, self.V_ones)
        sequence_scores += decoder_output.squeeze(1)

        A_ori = None
        if torch.all(sequence_scores == float('-inf')):
            A = None
        else:
            flat_scores = sequence_scores.view(n_groups, -1)
            if self.initial_method == 'train_tau':
                A_ori, _ = self.soft_topk(flat_scores, use_iter=False)
            else:
                A_ori, _ = self.soft_topk(flat_scores)

            A_ori = A_ori.view([n_groups, self.k, -1, self.k])
            A = A_ori.sum(dim=1)

            if bool(torch.any(A != A)):
                raise Exception('nan appeared in A')

        scores, candidates = sequence_scores.view(n_groups, -1).topk(self.k, dim=1)
        # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
        decoder_input = (candidates % self.V).view(batch_size, 1)
        sequence_scores = scores.view(batch_size, 1)

        # Floor division keeps the predecessor indices integral; true division
        # would produce floats that index_select rejects.
        predecessors = (candidates // self.V + self.pos_index.expand_as(candidates)).view(batch_size, 1)
        # Without a soft assignment (all scores -inf) fall back to hard selection.
        mix_with = A_ori if (self.use_A_on_hidden and A is not None) else None
        decoder_hidden = self._reorder_hidden(decoder_hidden, predecessors, mix_with, n_groups)

        # erase scores for end-of-sentence symbol so that they aren't expanded
        eos_indices = decoder_input.data.eq(self.EOS)
        if eos_indices.nonzero().dim() > 0:
            sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

        attns = {"std": attn.transpose(0, 1)}
        self.state['A'] = A
        self.state['sequence_scores'] = sequence_scores
        self.state['hidden'] = decoder_hidden
        return decoder_output.transpose(0, 1), attns

    def _beam_search(self, encoder_hidden, encoder_outputs, batch_size, max_length, step):
        """Run a full beam search of ``max_length`` steps and return the best beam's outputs."""
        self.max_length = max_length
        device = encoder_outputs.device

        # NOTE(review): scores below are laid out batch-major (row = batch*k + beam),
        # so `_inflate` must place the k copies of each batch entry consecutively.
        # Inflate the initial hidden states to be of size: b*k x h
        if isinstance(encoder_hidden, tuple):
            hidden = tuple([_inflate(h, self.k, 1) for h in encoder_hidden])
        else:
            hidden = _inflate(encoder_hidden, self.k, 1)

        # ... same idea for encoder_outputs
        if self.rnn.use_attention:
            inflated_encoder_outputs = _inflate(encoder_outputs, self.k, 0)
        else:
            inflated_encoder_outputs = None

        sequence_scores, input_var = self._init_beams(batch_size, batch_size * self.k, device)

        # Store decisions for backtracking
        stored_outputs = list()
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()
        stored_hidden = list()

        A = None
        V_ones = torch.ones(1, self.V, device=device)
        loss_for_tau_accum = 0
        for _ in range(max_length):
            # Run the RNN one step forward
            log_softmax_output, hidden, _ = self.rnn.forward_step(input_var, hidden,
                                                                  inflated_encoder_outputs, function=self.function, A=A)
            # To get the full sequence scores for the new candidates, add the
            # local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = torch.matmul(sequence_scores, V_ones)
            sequence_scores += log_softmax_output.squeeze(1)

            A_ori = None
            if self.use_topk_gradient:
                if torch.all(sequence_scores == float('-inf')):
                    A = None
                else:
                    flat_scores = sequence_scores.view(batch_size, -1)
                    if self.initial_method == 'train_tau':
                        use_iter = step is not None and step < self.train_tau_iter
                        A_ori, loss_for_tau = self.soft_topk(flat_scores, use_iter)
                    else:
                        A_ori, loss_for_tau = self.soft_topk(flat_scores)
                    if loss_for_tau is not None:
                        loss_for_tau_accum += loss_for_tau
                    A_ori = A_ori.view([batch_size, self.k, -1, self.k])
                    A = A_ori.sum(dim=1)

                    if bool(torch.any(A != A)):
                        raise Exception('nan appeared in A')
            # If doing local backprop (e.g. supervised training), retain the output layer
            stored_outputs.append(log_softmax_output)

            scores, candidates = sequence_scores.view(batch_size, -1).topk(self.k, dim=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            input_var = (candidates % self.V).view(batch_size * self.k, 1)
            sequence_scores = scores.view(batch_size * self.k, 1)

            # Update fields for next timestep. Floor division keeps the indices
            # integral (true division would yield floats).
            predecessors = (candidates // self.V + self.pos_index.expand_as(candidates)).view(batch_size * self.k, 1)
            # A is None both when soft top-k is disabled and when every score is
            # -inf; in either case select the predecessor's hidden state directly
            # instead of reusing a missing or stale assignment.
            mix_with = A_ori if (self.use_A_on_hidden and A is not None) else None
            hidden = self._reorder_hidden(hidden, predecessors, mix_with, batch_size)

            # Update sequence scores and erase scores for end-of-sentence symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            eos_indices = input_var.data.eq(self.EOS)

            if eos_indices.nonzero().dim() > 0:
                sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)
            stored_hidden.append(hidden)

        # Do backtracking to return the optimal values
        output = self._backtrack(stored_outputs, stored_hidden,
                                 stored_predecessors, stored_emitted_symbols,
                                 stored_scores, batch_size, self.hidden_size)[0]

        # Build return objects: keep only the best beam of every batch entry.
        decoder_outputs = torch.stack([beams[:, 0, :] for beams in output])
        metadata = {}
        if loss_for_tau_accum != 0:
            metadata['loss_for_tau'] = loss_for_tau_accum
        return decoder_outputs, metadata

    def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores, b, hidden_size):
        """Backtracks over batch to generate optimal k-sequences.

        Relies on ``self.max_length`` and ``self.pos_index`` set by the beam search.

        Args:
            nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of outputs from network
            nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length: A Tensor of hidden states from network
            predecessors [(batch*k)] * sequence_length: A Tensor of predecessors
            symbols [(batch*k)] * sequence_length: A Tensor of predicted tokens
            scores [(batch*k)] * sequence_length: A Tensor containing sequence scores for every token t = [0, ... , seq_len - 1]
            b: Size of the batch
            hidden_size: Size of the hidden state

        Returns:
            output [(batch, k, vocab_size)] * sequence_length: A list of the output probabilities (p_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_t [(batch, k, hidden_size)] * sequence_length: A list containing the output features (h_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_n(batch, k, hidden_size): A Tensor containing the last hidden state for all top-k sequences.

            score [batch, k]: A list containing the final scores for all top-k sequences

            length [batch, k]: A list specifying the length of each sequence in the top-k candidates

            p (batch, k, sequence_len): A Tensor containing predicted sequence
        """

        lstm = isinstance(nw_hidden[0], tuple)

        # initialize return variables given different types
        output = list()
        h_t = list()
        p = list()
        # Placeholder for last hidden state of top-k sequences.
        # If a (top-k) sequence ends early in decoding, `h_n` contains
        # its hidden state when it sees EOS.  Otherwise, `h_n` contains
        # the last hidden state of decoding.
        # zeros_like keeps the placeholder on the same device as the hidden states.
        if lstm:
            h_n = tuple([torch.zeros_like(h) for h in nw_hidden[0]])
        else:
            h_n = torch.zeros_like(nw_hidden[0])

        # Placeholder for lengths of top-k sequences, similar to `h_n`
        l = [[self.max_length] * self.k for _ in range(b)]

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = scores[-1].view(b, self.k).topk(self.k)
        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        # the number of EOS found in the backward loop below for each batch
        batch_eos_found = [0] * b

        # initialize the back pointer with the sorted order of the last step beams.
        # add self.pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = (sorted_idx + self.pos_index.expand_as(sorted_idx)).view(b * self.k)
        for t in range(self.max_length - 1, -1, -1):
            # Re-order the variables with the back pointer
            current_output = nw_output[t].index_select(0, t_predecessors)
            if lstm:
                current_hidden = tuple([h.index_select(1, t_predecessors) for h in nw_hidden[t]])
            else:
                current_hidden = nw_hidden[t].index_select(1, t_predecessors)
            current_symbol = symbols[t].index_select(0, t_predecessors)
            # Re-order the back pointer of the previous step with the back pointer of
            # the current step
            t_predecessors = predecessors[t].index_select(0, t_predecessors).squeeze(-1)

            # This tricky block handles dropped sequences that see EOS earlier.
            # The basic idea is summarized below:
            #
            #   Terms:
            #       Ended sequences = sequences that see EOS early and dropped
            #       Survived sequences = sequences in the last step of the beams
            #
            #       Although the ended sequences are dropped during decoding,
            #   their generated symbols and complete backtracking information are still
            #   in the backtracking variables.
            #   For each batch, everytime we see an EOS in the backtracking process,
            #       1. If there is survived sequences in the return variables, replace
            #       the one with the lowest survived sequence score with the new ended
            #       sequences
            #       2. Otherwise, replace the ended sequence with the lowest sequence
            #       score with the new ended sequence
            #
            eos_indices = symbols[t].data.squeeze(1).eq(self.EOS).nonzero()
            if eos_indices.dim() > 0:
                for i in range(eos_indices.size(0) - 1, -1, -1):
                    # Row of the EOS symbol in the flattened (b*k) layout, and
                    # the batch entry it belongs to
                    src_row = int(eos_indices[i][0])
                    b_idx = src_row // self.k
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = self.k - (batch_eos_found[b_idx] % self.k) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * self.k + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = predecessors[t][src_row]
                    current_output[res_idx, :] = nw_output[t][src_row, :]
                    if lstm:
                        current_hidden[0][:, res_idx, :] = nw_hidden[t][0][:, src_row, :]
                        current_hidden[1][:, res_idx, :] = nw_hidden[t][1][:, src_row, :]
                        h_n[0][:, res_idx, :] = nw_hidden[t][0][:, src_row, :].data
                        h_n[1][:, res_idx, :] = nw_hidden[t][1][:, src_row, :].data
                    else:
                        current_hidden[:, res_idx, :] = nw_hidden[t][:, src_row, :]
                        h_n[:, res_idx, :] = nw_hidden[t][:, src_row, :].data
                    current_symbol[res_idx, :] = symbols[t][src_row]
                    s[b_idx, res_k_idx] = scores[t][src_row].data[0]
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            output.append(current_output)
            h_t.append(current_hidden)
            p.append(current_symbol)

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(self.k)
        for b_idx in range(b):
            l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]]

        re_sorted_idx = (re_sorted_idx + self.pos_index.expand_as(re_sorted_idx)).view(b * self.k)

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        output = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(output)]
        p = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(p)]
        if lstm:
            h_t = [tuple([h.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for h in step]) for step in reversed(h_t)]
            h_n = tuple([h.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size) for h in h_n])
        else:
            h_t = [step.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for step in reversed(h_t)]
            h_n = h_n.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size)
        s = s.data

        return output, h_t, h_n, s, l, p

    def _mask_symbol_scores(self, score, idx, masking_score=-float('inf')):
        """Overwrite ``score[idx]`` with ``masking_score`` in place."""
        score[idx] = masking_score

    def _mask(self, tensor, idx, dim=0, masking_score=-float('inf')):
        """Fill ``tensor`` with ``masking_score`` at the positions in the first column of ``idx``, along ``dim``."""
        if idx.dim() > 0:
            indices = idx[:, 0]
            tensor.index_fill_(dim, indices, masking_score)

    def update_dropout(self, dropout):
        """Forward a new dropout value to the embeddings.

        NOTE(review): nothing in this class assigns ``self.embeddings``, so
        this call looks like it will raise ``AttributeError`` unless the
        attribute is attached elsewhere - confirm the intended target.
        """
        self.embeddings.update_dropout(dropout)
