import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PegasusForConditionalGeneration, T5ForConditionalGeneration
import torch.nn.functional as F


class CLGenerator(nn.Module):
    def __init__(self, PTM, model_name, pad_id, args):
        """Load the backbone seq2seq model and create the projection layer.

        Args:
            PTM: Backbone family name; one of ``"pegasus"``, ``"t5"`` or
                ``"codet5"``.
            model_name: Checkpoint name or path handed to ``from_pretrained``.
            pad_id: Token id that the methods of this class treat as padding.
            args: Configuration object; stored as ``self.args`` and read by
                the other methods of this class.

        Raises:
            NotImplementedError: If ``PTM`` is not one of the supported names.
        """
        super().__init__()
        self.PTM = PTM
        if self.PTM == "pegasus":
            self.generator = PegasusForConditionalGeneration.from_pretrained(model_name)
        elif self.PTM in ("t5", "codet5"):
            # A membership test is required here: `PTM == "t5" or "codet5"` is
            # always truthy and would make the error branch unreachable.
            self.generator = T5ForConditionalGeneration.from_pretrained(model_name)
        else:
            raise NotImplementedError("not support this PTM yet")
        self.pad_id = pad_id
        self.hidden_size = self.generator.config.hidden_size
        self.vocab_size = self.generator.config.vocab_size
        # Projection applied to encoder/decoder states before pooling.
        self.linear_layer = nn.Linear(self.hidden_size, self.hidden_size)
        nn.init.xavier_uniform_(self.linear_layer.weight)
        # Label value skipped by the cross-entropy loss.
        self.ignore_index = -100
        self.loss_fct = CrossEntropyLoss(ignore_index=self.ignore_index)
        self.args = args

    def form_ngram(self, input_tensor, n=2):
        """Return every contiguous n-gram of each sequence.

        Args:
            input_tensor: Tensor of shape (batch, sample_num, seq_len).
            n: N-gram order; expected to satisfy ``1 <= n <= seq_len``.

        Returns:
            Tensor of shape (batch, sample_num, seq_len - n + 1, n) with the
            dtype and device of ``input_tensor``, where ``out[b, s, i]`` equals
            ``input_tensor[b, s, i:i + n]``.
        """
        # A sliding window over the sequence dimension yields the n-grams
        # directly, without building a (seq_len - n + 1) x seq_len repeated
        # copy and a band mask. `.contiguous()` returns an independent tensor
        # rather than a view into the input.
        return input_tensor.unfold(2, n, 1).contiguous()

    def torch_bleu(self, ref_tensor, sys_tensor):
        """Score candidates by n-gram overlap with the reference.

        Args:
            ref_tensor: Reference token ids, shape (batch, seq_len1).
            sys_tensor: Candidate token ids, shape (batch, sample_num, seq_len2).

        Returns:
            Float tensor of shape (batch, sample_num). Each entry is the number
            of candidate n-grams that also occur in the reference, divided by
            (candidate n-gram count + reference n-gram count). A candidate
            identical to the reference therefore scores 0.5.

        N-grams are compared by exact token equality. Any n-gram that contains
        a ``pad_id`` position is excluded on both sides, so padding never
        contributes to the match count.
        """
        ref_is_token = ref_tensor != self.pad_id
        sys_is_token = sys_tensor != self.pad_id
        # Never ask for n-grams longer than either sequence.
        n = min(self.args.n_gram, ref_tensor.size(-1), sys_tensor.size(-1))
        # Number of n-grams among the non-pad tokens, floored at 1 so the
        # final division is always defined.
        ref_lengths = (ref_is_token.float().sum(dim=-1) - n + 1).clamp(min=1)
        sys_lengths = (sys_is_token.float().sum(dim=-1) - n + 1).clamp(min=1)

        # Pair the single reference with every candidate of the same row.
        sample_num = sys_tensor.size(1)
        ref_repeated = ref_tensor[:, None, :].repeat(1, sample_num, 1)
        ref_mask_repeated = ref_is_token[:, None, :].repeat(1, sample_num, 1)

        ref_ngrams = self.form_ngram(ref_repeated, n)  # batch x sample_num x len1' x n
        sys_ngrams = self.form_ngram(sys_tensor, n)  # batch x sample_num x len2' x n
        # An n-gram is usable only if none of its positions is padding.
        ref_valid = self.form_ngram(ref_mask_repeated.long(), n).bool().all(dim=-1)
        sys_valid = self.form_ngram(sys_is_token.long(), n).bool().all(dim=-1)

        # matches[b, s, i, j]: candidate n-gram i equals reference n-gram j.
        matches = (sys_ngrams.unsqueeze(3) == ref_ngrams.unsqueeze(2)).all(dim=-1)
        matches = matches & sys_valid.unsqueeze(3) & ref_valid.unsqueeze(2)
        # Count candidate n-grams that appear anywhere in the reference.
        hit_count = matches.any(dim=-1).sum(dim=-1)
        length = sys_lengths + ref_lengths.unsqueeze(1)
        return hit_count / length  # batch x sample_num

    def affine_transformation(self, input_features, padding_mask, axis=1):
        """Project features and mean-pool them over the unmasked positions.

        Args:
            input_features: Float tensor whose last dimension is fed to
                ``self.linear_layer``; pooled along ``axis``.
            padding_mask: Mask with the shape of ``input_features`` minus its
                last dimension; nonzero/True marks positions to keep.
            axis: Dimension of ``input_features`` to pool over.

        Returns:
            ``relu(linear_layer(input_features))`` averaged over the kept
            positions along ``axis`` (that dimension is removed). Rows with no
            kept position yield zeros.
        """
        projected = F.relu(self.linear_layer(input_features))
        keep = padding_mask.unsqueeze(-1).float()
        # Count kept positions along the same axis that is pooled, and floor
        # the count at 1 so an all-padding row gives zeros instead of NaN.
        length = keep.sum(dim=axis).clamp(min=1)
        pooled = (projected * keep).sum(dim=axis)
        return pooled * (1 / length)

    @torch.no_grad()
    def sample_from_model(self, src_inp, src_pad_mask):
        """Generate ``beam_size`` candidates per source with the backbone.

        All generation settings are read from ``self.args``. The beam count,
        the number of beam groups and the number of returned sequences are all
        set to ``args.beam_size``.

        Args:
            src_inp: Source token ids; the first dimension is the batch.
            src_pad_mask: Attention mask passed alongside ``src_inp``.

        Returns:
            The generated ids reshaped to (batch, beam_size, -1).
        """
        cfg = self.args
        generation_kwargs = {
            "input_ids": src_inp,
            "attention_mask": src_pad_mask,
            "num_return_sequences": cfg.beam_size,
            "num_beam_groups": cfg.beam_size,
            "diversity_penalty": cfg.diversity_pen,
            "num_beams": cfg.beam_size,
            # Configured lengths are shifted by +2 / +1 before being passed on.
            "max_length": cfg.max_length + 2,
            "min_length": cfg.min_length + 1,
            "no_repeat_ngram_size": cfg.no_repeat_ngram,
            "length_penalty": cfg.length_pen,
            "early_stopping": cfg.early_stop,
        }
        flat_candidates = self.generator.generate(**generation_kwargs)
        batch_size = src_inp.size(0)
        return flat_candidates.view(batch_size, self.args.beam_size, -1)

    def pad2max_len(self, input_tensor, max_len):
        """Right-pad the last dimension of a 3-D id tensor up to ``max_len``.

        Args:
            input_tensor: Token ids of shape (dim0, dim1, seq_len).
            max_len: Target size of the last dimension.

        Returns:
            ``input_tensor`` with ``self.pad_id`` appended along the last
            dimension until it has ``max_len`` entries. If the tensor is
            already at least ``max_len`` long it is returned unchanged.
        """
        pad_size = max_len - input_tensor.shape[-1]
        if pad_size <= 0:
            return input_tensor
        # Fill with the real padding id so the `== self.pad_id` masks built
        # elsewhere in this class recognise these positions as padding.
        pad_tensor = torch.full(
            (input_tensor.shape[0], input_tensor.shape[1], pad_size),
            self.pad_id,
            dtype=torch.long,
            device=input_tensor.device,
        )
        return torch.cat([input_tensor, pad_tensor], dim=-1)

    def ranking_loss(self, cos_distance, bleu_distance):
        """Pairwise margin ranking loss over score columns.

        For every column gap ``g`` (1 .. sample_num - 1), column ``j`` is
        required to score at least ``0.01 * g`` higher than column ``j + g``.
        A pair only contributes when the BLEU value of its earlier column is
        strictly positive.

        Args:
            cos_distance: Scores of shape (batch, sample_num).
            bleu_distance: BLEU values aligned with ``cos_distance``.

        Returns:
            Scalar tensor: the sum over all gaps of the mean masked pair loss.
        """
        base_margin = 0.01
        device = cos_distance.device
        # Comparing the scores with themselves at margin 0 gives a zero loss
        # tensor, so a valid value is returned even when the loop below never
        # runs (a single column).
        total_loss = F.margin_ranking_loss(
            cos_distance,
            cos_distance,
            torch.ones(cos_distance.size(), device=device),
            margin=0.0,
        )

        # choice: constant, ranking -- read from the config but not used below.
        _margin_mode = self.args.margin_func

        num_columns = cos_distance.size(1)
        for gap in range(1, num_columns):
            higher = cos_distance[:, :-gap]
            lower = cos_distance[:, gap:]
            # 1.0 where the earlier column has non-zero BLEU, else 0.0.
            pair_weight = (bleu_distance[:, :-gap] > 0.0).float()
            target = torch.ones(higher.size(), device=device)
            pair_loss = F.margin_ranking_loss(
                higher, lower, target, margin=base_margin * gap, reduction='none'
            )
            total_loss = total_loss + (pair_loss * pair_weight).mean()

        return total_loss

    @torch.no_grad()
    def generate(self, input_ids, attention_mask, args):
        """Beam-search candidates and pick one per input, optionally re-ranked.

        Generation settings are taken from the ``args`` parameter (not from
        ``self.args``). When ``args.alpha == 0.0`` the first returned sequence
        of each input is returned as is. Otherwise every candidate is scored as
        ``(1 - alpha) * normalised sequence score + alpha * cosine similarity``
        between the pooled encoder features and the pooled decoder features of
        that candidate, and the highest-scoring candidate is returned.

        Side effect: switches ``self.generator`` to eval mode and does not
        switch it back.

        Args:
            input_ids: Source token ids; the first dimension is the batch.
            attention_mask: Source attention mask passed to the generator, to
                the decoder as ``encoder_attention_mask`` and to the pooling.
            args: Object providing ``beam_size``, ``diversity_pen``,
                ``max_length``, ``min_length``, ``no_repeat_ngram``,
                ``length_pen``, ``early_stop``, ``alpha`` and ``PTM``.

        Returns:
            Tensor of shape (batch, seq_len) holding the selected candidate
            ids for every input.
        """
        self.generator.eval()
        decoder = self.generator.get_decoder()
        # Use one beam group per beam only when a diversity penalty is set;
        # otherwise run with a single group.
        if args.diversity_pen > 0:
            num_group = args.beam_size
        else:
            num_group = 1
        ret_dict = self.generator.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beam_groups=num_group,
            diversity_penalty=args.diversity_pen,
            num_return_sequences=args.beam_size,
            num_beams=args.beam_size,
            max_length=args.max_length + 2,
            # +2 from original because we start at step=1 and stop before max_length
            min_length=args.min_length + 1,  # +1 from original because we start at step=1
            no_repeat_ngram_size=args.no_repeat_ngram,
            length_penalty=args.length_pen,
            early_stopping=args.early_stop,
            output_scores=True,
            return_dict_in_generate=True,
            output_hidden_states=True
        )
        # Regroup the flat list of returned sequences: batch x beam_size x seq_len
        cand_ids = ret_dict["sequences"].view(input_ids.size(0), args.beam_size, -1)

        # No re-ranking requested: return the first candidate of every input.
        if args.alpha == 0.0:
            return cand_ids[:, 0, :]
        scores = ret_dict["sequences_scores"].view(input_ids.size(0), -1)  # batch x beam_size
        cand_mask = (cand_ids != self.pad_id).long()
        cand_len = torch.sum(cand_mask, dim=-1)
        # Largest number of non-pad tokens found in any candidate.
        max_len = torch.max(cand_len).item()
        # NOTE(review): if candidates begin with a start token equal to pad_id,
        # this count is one less than the span of the longest candidate, so the
        # slice below would drop its final token -- confirm.
        cand_ids = cand_ids[:, :, :max_len]
        # cand_len = torch.sum((cand_ids != self.pad_id), dim=-1)
        # Last entry of the encoder hidden states returned by generate();
        # presumably batch x src_len x hidden, since it is pooled over dim 1
        # with `attention_mask` below -- TODO confirm.
        encoder_hidden_states = ret_dict["encoder_hidden_states"][-1]
        encoder_feature = self.affine_transformation(encoder_hidden_states, attention_mask)  # batch x h
        decoder_hidden_states = []
        # Re-run the decoder on every candidate to obtain one pooled feature each.
        for sample_idx in range(cand_ids.size(1)):
            sampled_input_dec = cand_ids[:, sample_idx]
            sample_pad_mask = ~(sampled_input_dec == self.pad_id)
            # Except for codet5, always keep the first position unmasked.
            if args.PTM != "codet5":
                sample_pad_mask[:, 0] = 1
            decoder_out = decoder(input_ids=sampled_input_dec, attention_mask=sample_pad_mask,
                                  encoder_hidden_states=encoder_hidden_states,
                                  encoder_attention_mask=attention_mask)  # last layer
            decoder_feature = decoder_out[0]  # batch x tgt_len x hidden
            # if "t5" in self.PTM:
            #     decoder_feature = decoder_feature * (self.generator.model_dim ** -0.5)

            decoder_feature = self.affine_transformation(decoder_feature, sample_pad_mask)  # batch x h
            decoder_hidden_states.append(decoder_feature.unsqueeze(1))
        decoder_feature = torch.cat(decoder_hidden_states, dim=1)  # batch x sample_num x h
        cos_distance = torch.cosine_similarity(encoder_feature.unsqueeze(1), decoder_feature,
                                               dim=-1)  # batch x sample_num
        # Divide each sequence score by the sum of the negated scores of its
        # row before mixing it with the similarity term.
        normalize = torch.sum(0 - scores, keepdim=True, dim=-1)
        score = (1 - args.alpha) * (scores / normalize) + args.alpha * cos_distance
        # Pick the best candidate per row and gather its token ids.
        max_indices = torch.argmax(score, dim=-1)[:, None, None]
        dummy = max_indices.repeat(1, 1, cand_ids.size(2))
        return torch.gather(cand_ids, 1, dummy).squeeze(1)  # batch x seq_len

    def forward(self, src_inp, target_inp, target_outp):
        """Compute the cross-entropy loss plus the contrastive ranking loss.

        Steps:
          1. Encode the source and generate ``args.beam_size`` candidates.
          2. Build a sample pool from those candidates and every target of
             the batch, and score the pool against the row's own target with
             ``torch_bleu``.
          3. For each row keep its own target (assigned BLEU 0.5) followed by
             the highest-BLEU remaining samples, run the decoder on each and
             take the cosine similarity between the pooled encoder features
             and the pooled decoder features.
          4. Rank those similarities with ``ranking_loss`` and add the result
             to the token-level cross-entropy of the target.

        Args:
            src_inp: Source token ids, indexed as (batch, src_len).
            target_inp: Token ids fed to the decoder, (batch, tgt_len).
            target_outp: Label ids for the cross-entropy loss; positions equal
                to ``pad_id`` are ignored.

        Returns:
            Dict with ``'loss'`` (cross-entropy + contrastive loss) and
            ``'cl_loss'`` (the contrastive loss alone).
        """
        encoder = self.generator.get_encoder()
        decoder = self.generator.get_decoder()

        batch_size = src_inp.size(0)
        # Padding positions in the labels are excluded from the cross-entropy.
        target_outp = target_outp.masked_fill(target_outp == self.pad_id, self.ignore_index)
        src_pad_mask = ~(src_inp == self.pad_id)

        encoder_hidden_states = encoder(src_inp, src_pad_mask)['last_hidden_state']

        cand_ids = self.sample_from_model(src_inp, src_pad_mask)  # batch x beam_size x seq_len
        # prepare contrastive learning
        # Every row receives all targets of the batch as additional samples.
        samples_from_batch = target_inp[None, :, :].repeat(batch_size, 1, 1)
        cand_len = cand_ids.size(2)
        samples_len = samples_from_batch.size(2)
        # Pad or truncate the targets so they have the candidates' length.
        if samples_len < cand_len:
            samples_from_batch = self.pad2max_len(samples_from_batch, cand_len)
        else:
            samples_from_batch = samples_from_batch[:, :, :cand_len]
        assert cand_ids.dtype == torch.int64
        assert samples_from_batch.dtype == torch.int64
        samples_all = torch.cat([cand_ids, samples_from_batch], dim=1)  # batch x total_sample_num x seq_len
        assert samples_all.dtype == torch.int64
        bleu_distance = self.torch_bleu(target_inp, samples_all)  # batch x total_sample_num
        # Scores of 0.5 or more are zeroed so that they sort last.
        bleu_mask = (bleu_distance < 0.5)  # use to mask the gold
        bleu_distance_masked = bleu_distance * bleu_mask.float()
        # Leave one slot for the row's own target, which is prepended below.
        sample_num = min(self.args.max_sample_num - 1, bleu_distance_masked.size(1) - 1)
        bleu_distance, bleu_indices = torch.sort(bleu_distance_masked, dim=-1, descending=True)
        sampled_bleu_distance = bleu_distance[:, :sample_num]
        sampled_bleu_indices = bleu_indices[:, :sample_num]
        # concat itself
        # Row b's own target sits at index beam_size + b inside samples_all.
        self_indices = torch.arange(0, batch_size).reshape(batch_size, 1).to(
            sampled_bleu_indices.device) + cand_ids.size(1)  # manually add gold
        sampled_indices = torch.cat([self_indices, sampled_bleu_indices], dim=-1)

        # The row's own target is given a fixed BLEU value of 0.5.
        self_bleu = torch.full([batch_size, 1], 0.5, device=sampled_bleu_distance.device)
        sampled_bleu_distance = torch.cat([self_bleu, sampled_bleu_distance], dim=-1)
        dummy = sampled_indices.unsqueeze(-1).repeat(1, 1, samples_all.size(2))
        sampled_input = torch.gather(samples_all, 1, dummy)  # batch x sample_num x seq_len

        # print("sampled_bleu_distance sort", torch.sort(sampled_bleu_distance, dim=-1, descending=True).values)
        # Feed every selected sample through the decoder to obtain its feature.
        decoder_hidden_states = []
        for sample_idx in range(sampled_indices.size(-1)):
            sampled_input_dec = sampled_input[:, sample_idx, :]
            sample_pad_mask = ~(sampled_input_dec == self.pad_id)
            # Except for codet5, always keep the first position unmasked.
            if self.args.PTM != "codet5":
                sample_pad_mask[:, 0] = 1
            decoder_out = decoder(input_ids=sampled_input_dec, attention_mask=sample_pad_mask,
                                  encoder_hidden_states=encoder_hidden_states,
                                  encoder_attention_mask=src_pad_mask)  # last layer
            decoder_feature = decoder_out[0]  # batch x tgt_len x hidden
            # if "t5" in self.PTM:
            #     decoder_feature = decoder_feature * (self.generator.model_dim ** -0.5)
            decoder_feature = self.affine_transformation(decoder_feature, sample_pad_mask)  # batch x h
            decoder_hidden_states.append(decoder_feature.unsqueeze(1))

        encoder_feature = self.affine_transformation(encoder_hidden_states, src_pad_mask)  # batch x h
        decoder_feature = torch.cat(decoder_hidden_states, dim=1)  # batch x sample_num x h
        cos_distance = torch.cosine_similarity(encoder_feature.unsqueeze(1), decoder_feature,
                                               dim=-1)  # batch x sample_num
        cl_loss = self.ranking_loss(cos_distance, sampled_bleu_distance)
        # Token-level cross-entropy on the target sequence.
        tgt_pad_mask = ~(target_inp == self.pad_id)
        if self.args.PTM != "codet5":
            tgt_pad_mask[:, 0] = 1
        decoder_out = decoder(input_ids=target_inp, attention_mask=tgt_pad_mask,
                              encoder_hidden_states=encoder_hidden_states,
                              encoder_attention_mask=src_pad_mask)  # last layer

        decoder_last_layer = decoder_out[0]  # batch x tgt_len x hidden

        # For T5-family backbones the decoder output is rescaled by
        # model_dim ** -0.5 before the LM head.
        if "t5" in self.PTM:
            decoder_last_layer = decoder_last_layer * (self.generator.model_dim ** -0.5)

        lm_logits = self.generator.lm_head(decoder_last_layer)
        nll_loss = self.loss_fct(lm_logits.view(-1, self.vocab_size), target_outp.view(-1))

        return {'loss': nll_loss + cl_loss, "cl_loss": cl_loss}


class GeneratorBaseline(nn.Module):
    """Seq2seq baseline trained with token-level cross-entropy only."""

    def __init__(self, PTM, model_name, pad_id, scratch):
        """Load the backbone model.

        Args:
            PTM: Backbone family name; one of ``"pegasus"``, ``"t5"`` or
                ``"codet5"``.
            model_name: Checkpoint name or path handed to ``from_pretrained``.
            pad_id: Token id treated as padding.
            scratch: If truthy and the backbone is T5/CodeT5, discard the
                pretrained weights and re-create the model from its config.
                Has no effect for ``"pegasus"``.

        Raises:
            NotImplementedError: If ``PTM`` is not one of the supported names.
        """
        super().__init__()
        # Stored because forward() needs it to decide on the T5 rescaling.
        self.PTM = PTM
        if PTM == "pegasus":
            self.backbone = PegasusForConditionalGeneration.from_pretrained(model_name)
        elif PTM in ("t5", "codet5"):
            # A membership test is required here: `PTM == "t5" or "codet5"` is
            # always truthy and would make the error branch unreachable.
            self.backbone = T5ForConditionalGeneration.from_pretrained(model_name)
            if scratch:
                print("random initialize...")
                self.backbone = T5ForConditionalGeneration(self.backbone.config)
        else:
            raise NotImplementedError("not support this PTM")
        self.vocab_size = self.backbone.config.vocab_size
        self.hidden_size = self.backbone.config.hidden_size
        self.pad_id = pad_id
        # Label value skipped by the cross-entropy loss.
        self.ignore_index = -100
        self.loss_fct = CrossEntropyLoss(ignore_index=self.ignore_index)

    def forward(self, src_inp, target_inp, target_outp):
        """Compute the token-level cross-entropy loss.

        Args:
            src_inp: Source token ids, indexed as (batch, src_len).
            target_inp: Token ids fed to the decoder, (batch, tgt_len).
            target_outp: Label ids; positions equal to ``pad_id`` are ignored.

        Returns:
            Dict with a single key ``'loss'`` holding the cross-entropy loss.
        """
        decoder = self.backbone.get_decoder()
        encoder = self.backbone.get_encoder()
        target_outp = target_outp.masked_fill(target_outp == self.pad_id, self.ignore_index)

        src_pad_mask = ~(src_inp == self.pad_id)
        tgt_pad_mask = ~(target_inp == self.pad_id)
        # Always keep the first decoder position unmasked.
        tgt_pad_mask[:, 0] = 1
        encoder_out = encoder(src_inp, attention_mask=src_pad_mask)['last_hidden_state']  # last layer
        dec_out = decoder(input_ids=target_inp, attention_mask=tgt_pad_mask,
                          encoder_hidden_states=encoder_out,
                          encoder_attention_mask=src_pad_mask, past_key_values=None)  # last layer
        target_out = dec_out[0]  # batch x tgt_len x hidden
        # For T5-family backbones the decoder output is rescaled by
        # model_dim ** -0.5 before the LM head.
        if "t5" in self.PTM:
            target_out = target_out * (self.backbone.model_dim ** -0.5)
        lm_logits = self.backbone.lm_head(target_out)
        nll_loss = self.loss_fct(lm_logits.view(-1, self.vocab_size), target_outp.view(-1))
        return {'loss': nll_loss}
