import pdb
import string

import torch

from src.nllb.Model_generator import nllb_generate_one, mistral_generate_one
from collections import Counter


class Ensembler_generator():
    """Greedy, token-by-token decoding that combines two models via an ensembler.

    At every step each model produces an output for the next token (through
    ``mistral_generate_one`` / ``nllb_generate_one``). The "primary" model's
    arg-max token is used directly when it is a lone punctuation character;
    otherwise ``ensembler.ensemble(primary, secondary, learning_rate=...)`` is
    called and the arg-max of its result is used.

    Attributes:
        model1, tokenizer1: first model and its tokenizer.
        model2, tokenizer2: second model and its tokenizer.
        ensembler: object exposing ``ensemble(out1, out2, learning_rate=...)``.
        device: torch device string that tensors are moved to.
    """

    # Token id on which the 2-terminated loops stop.
    # NOTE(review): presumably the EOS id of the tokenizers in use - confirm.
    _STOP_TOKEN_ID = 2
    # Token id on which ``mistral_nllb_ensemble_translate`` stops.
    # NOTE(review): presumably the newline token of the LLM tokenizer - confirm.
    _TRANSLATION_STOP_TOKEN_ID = 13
    # Upper bounds on the number of decoding steps.
    _MAX_STEPS = 400
    _NLLB_PRIMARY_MAX_STEPS = 1000

    def __init__(self, model1, tokenizer1, model2, tokenizer2, ensembler, device="cuda:0"):
        self.model1 = model1
        self.tokenizer1 = tokenizer1

        self.model2 = model2
        self.tokenizer2 = tokenizer2

        self.ensembler = ensembler
        self.device = device

    @staticmethod
    def _is_lone_punctuation(token):
        """Return True if *token* is exactly one ASCII punctuation character.

        The "▁" word-boundary marker is replaced by a space and surrounding
        whitespace is stripped before the check.
        """
        text = token.replace("▁", " ").strip()
        return len(text) == 1 and text in string.punctuation

    def _choose_next_token(self, primary_output, secondary_output, learning_rate, stop_token_id):
        """Pick the next token id, or return None when decoding should stop.

        Args:
            primary_output: output of the primary model for the next token.
            secondary_output: output of the secondary model for the next token.
            learning_rate: forwarded to ``self.ensembler.ensemble``.
            stop_token_id: token id that terminates decoding.

        Returns:
            The chosen token id tensor, or ``None`` if either the primary
            model's arg-max or the finally chosen token equals *stop_token_id*.
        """
        candidate = torch.argmax(primary_output, dim=-1)
        if candidate.item() == stop_token_id:
            return None

        # The primary model is always paired with tokenizer1 in this class.
        candidate_token = self.tokenizer1.convert_ids_to_tokens(candidate)[0]
        if self._is_lone_punctuation(candidate_token):
            # Punctuation is taken from the primary model without ensembling.
            chosen = candidate
        else:
            ensemble_output = self.ensembler.ensemble(primary_output, secondary_output,
                                                      learning_rate=learning_rate)
            chosen = torch.argmax(ensemble_output, dim=-1)

        if chosen.item() == stop_token_id:
            return None
        return chosen

    def learining_rate_auto_cal(self, llm1_input_text, llm2_input_text):
        """Run the ensembler repeatedly on the first-step outputs of both models.

        Currently a placeholder: the ensemble results are discarded and the
        method always returns ``1``.

        Args:
            llm1_input_text: prompt for ``model1`` (encoded with ``tokenizer1``).
            llm2_input_text: prompt for ``model2`` (encoded with ``tokenizer2``).

        Returns:
            The constant ``1``.
        """
        LLM1_input_ids = self.tokenizer1.encode(llm1_input_text, return_tensors="pt").to(self.device)
        # Fixed: the second prompt must be encoded with the second model's tokenizer.
        LLM2_input_ids = self.tokenizer2.encode(llm2_input_text, return_tensors="pt").to(self.device)

        LLM1_output = mistral_generate_one(self.model1, LLM1_input_ids).to(self.device)
        LLM2_output = mistral_generate_one(self.model2, LLM2_input_ids).to(self.device)

        # The loop is kept because ``ensemble`` may have side effects on the
        # ensembler (not visible here); its return value is unused.
        for _ in range(self._MAX_STEPS):
            self.ensembler.ensemble(LLM1_output, LLM2_output, learning_rate=0)

        return 1

    def mistral_llama_ensemble_output(self, llm1_input_text, llm2_input_text, learning_rate=0):
        """Generate a continuation with two causal LLMs, ``model1`` being primary.

        Args:
            llm1_input_text: prompt for ``model1``.
            llm2_input_text: prompt for ``model2`` (used for the first step).
            learning_rate: forwarded to ``self.ensembler.ensemble``.

        Returns:
            The text decoded by ``tokenizer1`` from the tokens appended after
            the original ``model1`` prompt.
        """
        LLM1_input_ids = self.tokenizer1.encode(llm1_input_text, return_tensors="pt").to(self.device)
        LLM1_input_ids_length = LLM1_input_ids.shape[1]
        LLM2_input_ids = self.tokenizer2.encode(llm2_input_text, return_tensors="pt").to(self.device)

        for _ in range(self._MAX_STEPS):
            LLM1_output = mistral_generate_one(self.model1, LLM1_input_ids).to(self.device)
            LLM2_output = mistral_generate_one(self.model2, LLM2_input_ids).to(self.device)

            next_token_idx = self._choose_next_token(LLM1_output, LLM2_output, learning_rate,
                                                     self._STOP_TOKEN_ID)
            if next_token_idx is None:
                break

            LLM1_input_ids = torch.cat((LLM1_input_ids, next_token_idx.view(1, -1)), dim=1)
            current_input_string = self.tokenizer1.decode(LLM1_input_ids.tolist()[0])

            # Fixed: re-encode as a tensor on the target device, matching the
            # initial encoding above (previously a plain Python list was produced).
            # NOTE(review): from here on model2 sees the decoded model1 context,
            # not llm2_input_text - confirm this is intended.
            LLM2_input_ids = self.tokenizer2.encode(current_input_string,
                                                    return_tensors="pt").to(self.device)

        return self.tokenizer1.decode(LLM1_input_ids[:, LLM1_input_ids_length:].tolist()[0])

    def mistral_nllb_ensemble_translate(self, llm_input_text, nllb_input_text, nllb_tgt_lang, learning_rate=0):
        """Translate with an LLM (``model1``, primary) ensembled with NLLB (``model2``).

        Args:
            llm_input_text: prompt for the LLM.
            nllb_input_text: source text for the NLLB encoder.
            nllb_tgt_lang: key into ``tokenizer2.lang_code_to_id`` for the target language.
            learning_rate: forwarded to ``self.ensembler.ensemble``.

        Returns:
            The text decoded by ``tokenizer1`` from the tokens appended after
            the original LLM prompt.
        """
        LLM_input_ids = self.tokenizer1.encode(llm_input_text, return_tensors="pt").to(self.device)
        LLM_input_ids_length = LLM_input_ids.shape[1]
        translator_input_ids = self.tokenizer2.encode(nllb_input_text, return_tensors="pt").to(self.device)
        # Decoder prefix: [eos, target-language id].
        translator_decoder_prefix_input_ids = torch.tensor(
            [self.tokenizer2.eos_token_id, self.tokenizer2.lang_code_to_id[nllb_tgt_lang]]).to(self.device)

        translator_decoder_input_ids = translator_decoder_prefix_input_ids
        for _ in range(self._MAX_STEPS):
            NLLB_output = nllb_generate_one(self.model2, translator_input_ids, translator_decoder_input_ids).to(
                self.device)
            LLM_output = mistral_generate_one(self.model1, LLM_input_ids).to(self.device)

            next_token_idx = self._choose_next_token(LLM_output, NLLB_output, learning_rate,
                                                     self._TRANSLATION_STOP_TOKEN_ID)
            if next_token_idx is None:
                break

            LLM_input_ids = torch.cat((LLM_input_ids, next_token_idx.view(1, -1)), dim=1)

            # Rebuild the NLLB decoder input from the text generated so far,
            # since the two tokenizers do not share a vocabulary.
            current_translation = self.tokenizer1.decode(LLM_input_ids[:, LLM_input_ids_length:].tolist()[0])
            translator_decoder_current_translation_ids = self.tokenizer2.encode(
                " " + current_translation,
                return_tensors="pt",
                add_special_tokens=False).to(self.device)
            translator_decoder_input_ids = torch.cat(
                (translator_decoder_prefix_input_ids,
                 torch.squeeze(translator_decoder_current_translation_ids, dim=0)),
                dim=0)

        return self.tokenizer1.decode(LLM_input_ids[:, LLM_input_ids_length:].tolist()[0])

    def nllb_mistral_ensemble_translate(self, nllb_input_text, nllb_tgt_lang, llm_input_text, learning_rate=0):
        """Translate with NLLB (``model1``, primary) ensembled with an LLM (``model2``).

        Args:
            nllb_input_text: source text for the NLLB encoder.
            nllb_tgt_lang: key into ``tokenizer1.lang_code_to_id`` for the target language.
            llm_input_text: prompt for the LLM; generated token text is appended to it.
            learning_rate: forwarded to ``self.ensembler.ensemble``.

        Returns:
            The NLLB decoder sequence decoded by ``tokenizer1`` with special
            tokens skipped.
        """
        translator_input_ids = self.tokenizer1.encode(nllb_input_text, return_tensors="pt").to(self.device)
        LLM_input_ids = self.tokenizer2.encode(llm_input_text, return_tensors="pt").to(self.device)
        # Decoder prefix: [eos, target-language id].
        translator_decoder_prefix_input_ids = torch.tensor(
            [self.tokenizer1.eos_token_id, self.tokenizer1.lang_code_to_id[nllb_tgt_lang]]).to(self.device)

        translator_decoder_input_ids = translator_decoder_prefix_input_ids
        first_flag = True
        for _ in range(self._NLLB_PRIMARY_MAX_STEPS):
            NLLB_output = nllb_generate_one(self.model1, translator_input_ids, translator_decoder_input_ids).to(
                self.device)
            LLM_output = mistral_generate_one(self.model2, LLM_input_ids).to(self.device)

            next_token_idx = self._choose_next_token(NLLB_output, LLM_output, learning_rate,
                                                     self._STOP_TOKEN_ID)
            if next_token_idx is None:
                break

            translator_decoder_input_ids = torch.cat((translator_decoder_input_ids, next_token_idx), dim=0)

            next_token_to_string = self.tokenizer1.convert_ids_to_tokens(next_token_idx)[0].replace("▁", " ")
            if first_flag:
                # The first generated piece is appended without its leading space.
                next_token_to_string = next_token_to_string.strip()
                first_flag = False
            llm_input_text = llm_input_text + next_token_to_string

            LLM_input_ids = self.tokenizer2.encode(llm_input_text,
                                                   return_tensors="pt",
                                                   add_special_tokens=True).to(self.device)

        return self.tokenizer1.decode(translator_decoder_input_ids, skip_special_tokens=True)
