import os
import pickle
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from datasets import load_dataset
import random
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import Dataset
from transformers import ViTModel, WhisperModel, AutoFeatureExtractor, AutoConfig
import os
import pickle
from transformers import (
    TrainingArguments,
    Trainer,
    GPTNeoForCausalLM,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModel,
    PreTrainedModel,
)


class AdaptorAudioModel(PreTrainedModel):
    """Binary classifier on top of an encoder-decoder audio backbone.

    The backbone's ``encoder_last_hidden_state`` is mean-pooled over dim 1 and
    passed through a two-layer MLP head that emits a single logit per example.
    When ``is_adaptor`` is true the backbone parameters are frozen, so only the
    head is trained.
    """

    def __init__(self, config, model_backbone, is_adaptor=False):
        super().__init__(config)
        self.config = config
        self.backbone = model_backbone
        self.num_labels = 1
        self.lm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, 1),
        )
        # Optionally freeze the backbone (adaptor mode: train the head only).
        if is_adaptor:
            print("freezing the backbone")
            for param in self.backbone.parameters():
                param.requires_grad = False
        else:
            print("not freezing the backbone")

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, logits)``, or ``(logits,)`` when ``labels`` is None.

        Args:
            input_ids: audio input features, passed positionally to the backbone.
            attention_mask: despite its name, the decoder prompt ids forwarded
                as ``decoder_input_ids``. All leading axes are flattened into
                the batch axis, so both a ``(batch, seq)`` training batch and a
                ``(1, 1, seq)`` single example become 2-D.
            labels: float targets with the same shape as ``logits``
                (presumably ``(batch, 1)`` -- TODO confirm), or None.
        """
        decoder_input_ids = attention_mask.reshape(-1, attention_mask.shape[-1])

        outputs = self.backbone(input_ids, decoder_input_ids=decoder_input_ids)

        # Mean-pool the encoder states over dim 1 (the time axis, presumably).
        pooled = torch.mean(outputs.encoder_last_hidden_state, dim=1)
        logits = self.lm_head(pooled)

        if labels is None:
            return (logits,)

        # BCEWithLogitsLoss fuses sigmoid + BCE in a numerically stable way;
        # applying sigmoid first and then BCELoss saturates for large |logit|.
        loss = nn.BCEWithLogitsLoss()(logits, labels)
        return (loss, logits)


class AdaptorImageModel(PreTrainedModel):
    """Binary classifier on top of a vision backbone.

    The backbone's ``last_hidden_state`` is mean-pooled over dim 1 and passed
    through a two-layer MLP head that emits a single logit per example. When
    ``is_adaptor`` is true the backbone parameters are frozen, so only the head
    is trained.
    """

    def __init__(self, config, model_backbone, is_adaptor=False):
        super().__init__(config)
        self.config = config
        self.backbone = model_backbone
        self.num_labels = 1
        self.lm_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, 1),
        )
        # Optionally freeze the backbone (adaptor mode: train the head only).
        if is_adaptor:
            print("freezing the backbone")
            for param in self.backbone.parameters():
                param.requires_grad = False
        else:
            print("not freezing the backbone")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Return ``(loss, logits)``, or ``(logits,)`` when ``labels`` is None.

        Args:
            input_ids: pixel tensor. ``squeeze(1)`` removes the extra axis that
                single-example evaluation adds (``(1, 1, C, H, W)``); it is a
                no-op for a ``(batch, C, H, W)`` batch as long as C != 1.
            attention_mask: ignored; accepted so the shared collator and
                evaluation loop can pass it.
            labels: float targets with the same shape as ``logits``
                (presumably ``(batch, 1)`` -- TODO confirm), or None.
        """
        pixel_values = input_ids.squeeze(1)

        outputs = self.backbone(pixel_values)

        # Mean-pool the backbone states over dim 1 (the token axis, presumably).
        pooled = torch.mean(outputs.last_hidden_state, dim=1)
        logits = self.lm_head(pooled)

        if labels is None:
            return (logits,)

        # BCEWithLogitsLoss fuses sigmoid + BCE in a numerically stable way;
        # applying sigmoid first and then BCELoss saturates for large |logit|.
        loss = nn.BCEWithLogitsLoss()(logits, labels)
        return (loss, logits)


def _binarize_split(split, neg_label, pos_label):
    """Keep only rows labelled ``neg_label`` or ``pos_label``; relabel them 0 / 1."""
    kept = split.filter(lambda example: example["label"] in (neg_label, pos_label))
    return kept.map(
        lambda example: {"label": 0 if example["label"] == neg_label else 1}
    )


def load_data_mm(
    data_path, dataset="sms", input_key="sms", seed=42, num_per_class=128
):
    """Load a Hugging Face dataset and reduce it to a binary classification task.

    Supported datasets and the original labels kept as classes 0 / 1:
      * ``"cifar10"``: labels 0 vs 2, input column ``"img"``
      * ``"mnist"``: labels 0 vs 8, input column ``"image"``
      * ``"speech_commands"`` (config ``"v0.01"``): labels 0 vs 1, input
        column taken from ``input_key``

    Args:
        data_path: unused; kept so existing positional calls keep working.
        dataset: one of the dataset names above.
        input_key: name of the input column; overridden for cifar10 and mnist.
        seed: shuffle seed used when sampling training examples per class.
        num_per_class: number of training examples drawn per class (default 128).

    Returns:
        ``((X_train_1, y_train_1, X_train_0, y_train_0), (X_test, y_test))``,
        where each element is a column of the corresponding split. The test
        part is the full binarized test split.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # name -> (load_dataset args, input column override or None, label -> 0, label -> 1)
    specs = {
        "cifar10": (("cifar10",), "img", 0, 2),
        "mnist": (("mnist",), "image", 0, 8),
        "speech_commands": (("speech_commands", "v0.01"), None, 0, 1),
    }
    if dataset not in specs:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(specs)}"
        )
    load_args, column, neg_label, pos_label = specs[dataset]
    if column is not None:
        input_key = column

    raw = load_dataset(*load_args)
    dataset_train = _binarize_split(raw["train"], neg_label, pos_label)
    dataset_test = _binarize_split(raw["test"], neg_label, pos_label)

    def sample_class(label):
        # Shuffle one class with the shared seed and keep its first num_per_class rows.
        return (
            dataset_train.filter(lambda example: example["label"] == label)
            .shuffle(seed=seed)
            .select(range(num_per_class))
        )

    dataset_train_1 = sample_class(1)
    dataset_train_0 = sample_class(0)

    return (
        dataset_train_1[input_key],
        dataset_train_1["label"],
        dataset_train_0[input_key],
        dataset_train_0["label"],
    ), (dataset_test[input_key], dataset_test["label"])


def get_k_split(X_train_1, y_train_1, X_train_0, y_train_0, k, seed):
    """Build a shuffled, class-balanced k-shot training split.

    Takes the first ``k // 2`` examples of class 0 and of class 1 (class 0
    first) and shuffles them jointly. The permutation comes from a private
    ``random.Random(seed)``, so it matches ``random.seed(seed)`` followed by
    ``random.shuffle`` without disturbing the global RNG.

    Args:
        X_train_1, y_train_1: inputs / labels of the positive class (sliceable).
        X_train_0, y_train_0: inputs / labels of the negative class (sliceable).
        k: requested split size. An odd k yields ``k - 1`` examples, and fewer
            examples are returned if a class has less than ``k // 2`` samples.
        seed: seed for the shuffle.

    Returns:
        ``(x_k, y_k)``: two lists of equal length, aligned element-wise.
    """
    half = k // 2
    x_k = X_train_0[:half] + X_train_1[:half]
    y_k = y_train_0[:half] + y_train_1[:half]

    # Permute over the examples actually collected, not over k, so odd k or
    # short classes cannot index past the end.
    order = list(range(len(x_k)))
    random.Random(seed).shuffle(order)
    return [x_k[i] for i in order], [y_k[i] for i in order]


class ImageDataset(Dataset):
    """In-memory dataset of preprocessed images paired with binary labels.

    Every image is resized to 224x224, converted to 3-channel grayscale and
    turned into a tensor with an extra leading axis of size 1. Each label is
    stored as a 1-element float tensor.

    ``__getitem__`` returns ``(image, label, label)``. The label is repeated in
    the middle slot, which this file's training collator forwards as
    ``attention_mask``.
    """

    def __init__(self, image_list, label_list):
        self.images = []
        self.labels = []
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
            ]
        )

        # Preprocess everything eagerly so item access is just a list lookup.
        for picture, target in zip(image_list, label_list):
            self.images.append(self.transform(picture).unsqueeze(0))
            self.labels.append(torch.Tensor([target]))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        picture = self.images[idx]
        target = self.labels[idx]
        return picture, target, target


class AudioDataset(Dataset):
    """Audio clips featurised up front, paired with decoder prompt ids and labels.

    NOTE(review): ``AudioDataset`` is defined twice in this module and the
    later definition rebinds the name. One of the two copies should be deleted.

    For every record, ``feature_extractor`` is applied to ``record["array"]``
    (declared as 16 kHz audio) and the leading axis of the resulting
    ``input_features`` is squeezed off. Each item also carries a ``(1, 2)``
    tensor filled with ``model.config.decoder_start_token_id`` and the label as
    a 1-element float tensor.

    Items are ``(input_features, decoder_input_ids, label)``.
    """

    def __init__(
        self,
        audio_list,
        label_list,
        feature_extractor,
        model,
    ):
        print("audio_list", len(audio_list))
        self.input_features = []
        self.decoder_input_ids_lst = []
        self.labels = []

        for rec, label in zip(audio_list, label_list):
            inputs = feature_extractor(
                rec["array"], sampling_rate=16000, return_tensors="pt"
            )
            input_features = inputs.input_features
            # Two copies of the decoder start token, used as the decoder prompt.
            decoder_input_ids = (
                torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
            )

            label = torch.Tensor([label])
            self.input_features.append(input_features.squeeze(0))
            self.labels.append(label)
            self.decoder_input_ids_lst.append(decoder_input_ids)

    def __len__(self):
        return len(self.input_features)

    def __getitem__(self, idx):
        # No debugger hook here: item access must not block training/eval.
        return (
            self.input_features[idx],
            self.decoder_input_ids_lst[idx],
            self.labels[idx],
        )


class AudioDataset(Dataset):
    """Audio clips featurised eagerly and paired with binary labels.

    For each record, ``feature_extractor`` turns ``record["array"]`` (declared
    as 16 kHz audio) into input features and their leading axis is dropped.
    Alongside them, each item stores a ``(1, 2)`` tensor filled with
    ``model.config.decoder_start_token_id`` to seed the decoder, and the label
    as a 1-element float tensor.

    Items are ``(input_features, decoder_input_ids, label)``.
    """

    def __init__(self, audio_list, label_list, feature_extractor, model):
        print("audio_list", len(audio_list))
        self.input_features, self.decoder_input_ids_lst, self.labels = [], [], []

        for record, target in zip(audio_list, label_list):
            encoded = feature_extractor(
                record["array"], sampling_rate=16000, return_tensors="pt"
            )
            prompt = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
            self.input_features.append(encoded.input_features.squeeze(0))
            self.labels.append(torch.Tensor([target]))
            self.decoder_input_ids_lst.append(prompt)

    def __len__(self):
        return len(self.input_features)

    def __getitem__(self, idx):
        features = self.input_features[idx]
        prompt = self.decoder_input_ids_lst[idx]
        target = self.labels[idx]
        return features, prompt, target


def load_model_and_tokenizer(modality, model_name, is_adaptor=False):
    """Load a pretrained backbone, wrap it in an adaptor head and move it to CUDA.

    Args:
        modality: ``"image"`` (ViTModel backbone wrapped in AdaptorImageModel) or
            ``"audio"`` (WhisperModel backbone wrapped in AdaptorAudioModel).
        model_name: model id/path used for the config, the backbone and, for
            audio, the feature extractor.
        is_adaptor: if true, freeze the backbone and train only the head.

    Returns:
        ``(model, feature_extractor)``; ``feature_extractor`` is None for images.

    Raises:
        ValueError: for any other modality. This is checked before anything is
            loaded.
    """
    if modality not in ("image", "audio"):
        raise ValueError(
            f"Unsupported modality {modality!r}; expected 'image' or 'audio'"
        )
    config = AutoConfig.from_pretrained(model_name)

    if modality == "image":
        print("loading image model")
        backbone = ViTModel.from_pretrained(model_name).cuda()
        model = AdaptorImageModel(
            config=config, model_backbone=backbone, is_adaptor=is_adaptor
        ).cuda()
        return model, None

    print("loading audio model")
    backbone = WhisperModel.from_pretrained(model_name).cuda()
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    print("audio model")
    model = AdaptorAudioModel(
        config=config, model_backbone=backbone, is_adaptor=is_adaptor
    ).cuda()
    return model, feature_extractor


def _collate_batch(data):
    """Stack ``(inputs, aux, ..., label)`` dataset items into a Trainer batch.

    The leading singleton axis is squeezed off the first two fields of each
    item before stacking. The second field is passed to the model as
    ``attention_mask``: Whisper decoder prompt ids for ``AudioDataset`` items,
    and a label copy (ignored by ``AdaptorImageModel``) for ``ImageDataset``
    items. The last field becomes ``labels``.
    """
    return {
        "input_ids": torch.stack([item[0].squeeze(0) for item in data]),
        "attention_mask": torch.stack([item[1].squeeze(0) for item in data]),
        "labels": torch.stack([item[-1] for item in data]),
    }


def finetune_model(
    model_name,
    dataset="cifar10",
    key="img",
    lr=1e-4,
    epoch_range=(10,),
    k_range=(18, 32, 48, 64),
    seed=42,
    modality="image",
    save_dir=".",
    is_adaptor=False,
):
    """Fine-tune a freshly loaded model for every (epochs, k) pair and evaluate it.

    Args:
        model_name: backbone id/path passed to ``load_model_and_tokenizer``.
        dataset: dataset name passed to ``load_data_mm``.
        key: ignored. The input column is always derived from ``modality``
            (``"img"`` for image, ``"audio"`` otherwise).
        lr: constant learning rate.
        epoch_range: iterable of epoch counts to try.
        k_range: iterable of k-shot training-set sizes to try.
        seed: seed for data sampling and the k-shot shuffle.
        modality: ``"image"`` or ``"audio"``.
        save_dir: Trainer ``output_dir`` (checkpoints are disabled).
        is_adaptor: freeze the backbone and train only the head.

    Returns:
        ``{num_epoch: {k: eval_adaptor(...) output dict}}``.
    """
    import gc

    key = "img" if modality == "image" else "audio"

    (X_train_1, y_train_1, X_train_0, y_train_0), (X_test, y_test) = load_data_mm(
        dataset, dataset, key, seed
    )
    results = {}
    print("dataset loaded")

    # The test split depends on neither epochs nor k, so it is built only once.
    test_dataset = None

    for num_epoch in epoch_range:
        results[num_epoch] = {}
        for k in k_range:
            # Reload the model for every run so each starts from pretrained weights.
            model, feature_extractor = load_model_and_tokenizer(
                model_name=model_name,
                modality=modality,
                is_adaptor=is_adaptor,
            )
            print("model loaded")
            print("Epoch: ", num_epoch)

            x_k, y_k = get_k_split(X_train_1, y_train_1, X_train_0, y_train_0, k, seed)
            if modality == "image":
                train_dataset = ImageDataset(x_k, y_k)
                if test_dataset is None:
                    test_dataset = ImageDataset(X_test, y_test)
            else:
                train_dataset = AudioDataset(x_k, y_k, feature_extractor, model)
                if test_dataset is None:
                    test_dataset = AudioDataset(
                        X_test, y_test, feature_extractor, model
                    )

            training_args = TrainingArguments(
                num_train_epochs=num_epoch,
                logging_steps=10,
                lr_scheduler_type="constant",
                save_strategy="no",
                save_total_limit=1,
                # NOTE(review): newer transformers releases rename this to
                # ``eval_strategy`` -- confirm against the installed version.
                evaluation_strategy="no",
                per_device_train_batch_size=4,
                per_device_eval_batch_size=1,
                gradient_accumulation_steps=1,
                warmup_steps=0,
                weight_decay=0.01,
                logging_dir="logs",
                learning_rate=lr,
                output_dir=save_dir,
                logging_strategy="epoch",
                report_to="none",
            )

            print(f"start training: epoch {num_epoch}, k {k}, seed {seed}")

            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                data_collator=_collate_batch,
            )
            trainer.train()

            model.eval()

            print("evaluating")
            eval_outputs = eval_adaptor(model, test_dataset, modality=modality)
            print(eval_outputs["accuracy"])
            results[num_epoch][k] = eval_outputs

            # The Trainer still references the model (and, after train(), its
            # optimizer state). Drop those references as well and collect
            # garbage before releasing cached CUDA memory; otherwise the
            # previous run stays resident while the next model is loaded.
            del model.backbone
            del trainer, model, feature_extractor, train_dataset
            gc.collect()
            torch.cuda.empty_cache()
    return results


def call_finetune(
    seeds=(42, 69, 128, 512, 1024),
    model_name="openai/whisper-large",
    dataset="speech_commands",
    key="img",
    lr=1e-3,
    epoch_range=(10,),
    k_range=(64, 18, 32, 48, 64),
    output_dir="./evals",
    modality="audio",
    save_dir=".",
    is_adaptor=False,
):
    """Run ``finetune_model`` once per seed and pickle all results to disk.

    Results are written to
    ``<output_dir>/<dataset>/<last part of model_name>/<ft_adaptor|ft_full>/
    train_<dataset>_eval_<dataset>_lr<lr>.pkl`` (with "/" in the file name
    replaced by "-").

    Args:
        seeds: seeds to run; one ``finetune_model`` call each.
        model_name, dataset, key, lr, epoch_range, modality, save_dir,
        is_adaptor: forwarded to ``finetune_model``.
        k_range: k-shot sizes. Duplicates are dropped, keeping first
            occurrence order, since a repeated k would retrain and overwrite the
            same result slot.
        output_dir: root directory for the pickle file.

    Returns:
        The ``{seed: finetune_model results}`` dict that was pickled.
    """
    model_type = "ft_adaptor" if is_adaptor else "ft_full"
    eval_dataset = dataset
    file_name = f"train_{dataset}_eval_{eval_dataset}_lr{lr}.pkl".replace("/", "-")
    unique_k = list(dict.fromkeys(k_range))

    final_results = {}
    for seed in seeds:
        final_results[seed] = finetune_model(
            model_name,
            dataset=dataset,
            key=key,
            lr=lr,
            epoch_range=epoch_range,
            k_range=unique_k,
            seed=seed,
            modality=modality,
            save_dir=save_dir,
            is_adaptor=is_adaptor,
        )

    model_name_split = model_name.split("/")[-1]
    target_dir = os.path.join(output_dir, dataset, model_name_split, model_type)
    os.makedirs(target_dir, exist_ok=True)
    # Use a context manager so the file is flushed and closed deterministically.
    with open(os.path.join(target_dir, file_name), "wb") as handle:
        pickle.dump(final_results, handle)
    return final_results


def eval_adaptor(
    model,
    test_ds,
    modality="image",
    device="cuda",
):
    """Evaluate a binary adaptor model one example at a time.

    Each item ``(inputs, aux, label)`` of ``test_ds`` gets a batch axis added,
    is moved to ``device`` and is passed to ``model`` as
    ``model(inputs, attention_mask=aux, labels=label)``. The last element of
    the returned tuple is taken as the logit. The prediction is 1 when
    ``sigmoid(logit) > 0.5``.

    Args:
        model: adaptor model (callable with an ``eval()`` method).
        test_ds: sized, indexable dataset of ``(inputs, aux, label)`` tensors.
        modality: unused; kept for call compatibility.
        device: device each example is moved to (default ``"cuda"``).

    Returns:
        dict with ``predicted_label`` (0/1 ints), ``gt_label`` (ints),
        ``predicted_scores`` (sigmoid probabilities as floats) and
        ``accuracy``, which is NaN when ``test_ds`` is empty.
    """
    print("eval")
    gt_label, predicted_label, predicted_scores = [], [], []
    model.eval()
    with torch.no_grad():
        for idx in range(len(test_ds)):
            inputs, aux, label = test_ds[idx]

            sample_outputs = model(
                inputs.unsqueeze(0).to(device),
                attention_mask=aux.unsqueeze(0).to(device),
                labels=label.unsqueeze(0).to(device),
            )

            pred_score = torch.sigmoid(sample_outputs[-1])
            pred = 1 if pred_score > 0.5 else 0

            predicted_label.append(pred)
            gt_label.append(int(label))
            # Store the probability itself, not the thresholded label.
            predicted_scores.append(pred_score.item())

    correct = sum(1 for gt, pr in zip(gt_label, predicted_label) if gt == pr)
    accuracy = correct / len(gt_label) if gt_label else float("nan")
    return {
        "predicted_label": predicted_label,
        "gt_label": gt_label,
        "predicted_scores": predicted_scores,
        "accuracy": accuracy,
    }


if __name__ == "__main__":
    # Backbones to sweep. "google/vit-large-patch16-224-in21k" (image, MNIST)
    # is the alternative that has also been used here.
    for m in ["openai/whisper-large"]:
        if m == "openai/whisper-large":
            dataset = "speech_commands"
            modality = "audio"
            key = "audio"
        else:
            dataset = "mnist"
            modality = "image"
            key = "img"

        call_finetune(
            model_name=m,
            dataset=dataset,
            key=key,
            lr=0.00005,  # other settings noted earlier: 1e-05 for whisper; ViT (1e-05, 20)
            epoch_range=[10],
            k_range=[128, 64, 18, 32, 48, 256],
            seeds=[42, 69, 128],
            modality=modality,
            save_dir=".",
            is_adaptor=False,
            output_dir="./outputs",
        )
