import copy
import os.path

import optuna
import pandas as pd
import numpy as np
from typing import Any, Callable, Optional, Dict, Sequence
import torch
from torch.utils.data import DataLoader

from adl4cv.classification.sampling.sampler import SubsetSequentialSampler
from adl4cv.classification.log.logger import init_logger
from adl4cv.classification.model.graph._graph.graph_builder import GraphType
from adl4cv.parameters.params import HyperParameterSet, HyperParameterSpace, DefinitionSpace
from adl4cv.active_learning.query.queries import Query, QueryDefinition, QueryType


class AttentionQueryHyperParameterSet(HyperParameterSet):
    """Hyperparameters that configure an attention-based query."""

    def __init__(
            self,
            num_of_estimations: int = 10,
            log_query: bool = False,
            **kwargs: Any,
    ):
        """
        :param num_of_estimations: Number of estimation rounds; stored as ``num_of_estimations``
        :param log_query: Flag stored as ``log_query``
        :param kwargs: Forwarded unchanged to the base-class initializer
        """
        super().__init__(**kwargs)
        self.num_of_estimations = num_of_estimations
        self.log_query = log_query

    def definition_space(self):
        """
        Build the hyperparameter space belonging to this set
        :return: An AttentionQueryHyperParameterSpace that uses this set as its default
        """
        space = AttentionQueryHyperParameterSpace(self)
        return space


class AttentionQueryDefinition(QueryDefinition):
    """Definition (query type plus hyperparameters) from which an AttentionQuery is built."""

    def __init__(self, hyperparams: Optional[AttentionQueryHyperParameterSet] = None):
        """
        :param hyperparams: Hyperparameters of the query; a fresh default set is created when omitted
        """
        # A default written as ``AttentionQueryHyperParameterSet()`` in the signature is
        # evaluated once and then shared by every definition; build it per call instead.
        if hyperparams is None:
            hyperparams = AttentionQueryHyperParameterSet()
        super().__init__(QueryType.AttentionQuery, hyperparams)

    @property
    def _instantiate_func(self) -> Optional[Callable]:
        """Not provided for this definition; ``instantiate`` constructs the query directly."""
        raise NotImplementedError()

    def instantiate(self, log_folder, *args, **kwargs):
        """
        Instantiates the module
        :param log_folder: Folder handed to the created AttentionQuery
        :param args: Accepted but not used
        :param kwargs: Accepted but not used
        :return: An AttentionQuery built from ``log_folder`` and this definition's hyperparameters
        """
        return AttentionQuery(log_folder, self.hyperparams)

    def definition_space(self):
        """
        :return: An AttentionQueryDefinitionSpace wrapping the definition space of the hyperparameters
        """
        return AttentionQueryDefinitionSpace(self.hyperparams.definition_space())


class AttentionQuery(Query):
    """
    Query that scores the unlabeled pool with the attention values captured from one
    inspected layer of the model and sorts the pool by that score.
    """

    def __init__(self, log_folder: str, params: Optional[AttentionQueryHyperParameterSet] = None):
        """
        :param log_folder: Folder under which the query log file path is built
        :param params: Hyperparameters of the query; a fresh default set is created when omitted
        """
        # A default written as ``AttentionQueryHyperParameterSet()`` in the signature is
        # evaluated once and then shared by every query; build it per call instead.
        if params is None:
            params = AttentionQueryHyperParameterSet()
        super().__init__(params)
        self.log_folder = log_folder
        self.logger = init_logger(self.__class__.__name__)

    @property
    def query_file(self):
        """Path of the CSV file for the query logs inside the log folder."""
        return os.path.join(self.log_folder, "attention_query.csv")

    @property
    def attention_layer(self):
        """Dotted name of the layer whose output is inspected to read the attention."""
        return "model.mpn_net.model.mpn_layers.layer_0.trans.alpha_dropout"

    def sort_unlabeled_pool_by_metric(self, model, datamodule, num_samples_to_evaluate: int, ascending: bool = False):
        """
        Sorts the first <num_samples_to_evaluate> samples of the unlabeled pool by attention
        :param model: The model whose attention is evaluated
        :param datamodule: The datamodule that owns the unlabeled pool and is sorted in place
        :param num_samples_to_evaluate: The number of unlabeled samples to evaluate
        :param ascending: If True, the attentions are negated before they are logged and
                          handed to the datamodule for sorting
        :return: DataFrame of the evaluated sample ids and their (possibly negated) attentions
        """
        self.logger.info("Sorting the unlabeled pool by attention!")
        attentions = self.evaluate_metric(model, datamodule, num_samples_to_evaluate)
        if ascending:
            attentions = -attentions
        # The logs are built before the pool is reordered so that the sample ids
        # still line up with the attention values.
        query_logs = self.get_query_logs(datamodule, attentions)

        datamodule.sort_unlabeled_indices_by_list(attentions)

        self.logger.info("Unlabeled pool sorted by attention!")
        return query_logs

    def get_query_logs(self, datamodule, attentions):
        """
        Pairs the evaluated sample ids with their attentions
        :param datamodule: The datamodule providing ``unlabeled_pool_indices``
        :param attentions: Attention value per evaluated sample, aligned with the start of the pool
        :return: DataFrame with columns "sample_id" and "attention", sorted by attention descending
        """
        samples_evaluated = len(attentions)
        query_logs = pd.DataFrame(
            {"sample_id": datamodule.unlabeled_pool_indices[:samples_evaluated],
             "attention": attentions.tolist()},
        )

        query_logs = query_logs.sort_values("attention", ascending=False)
        return query_logs

    def evaluate_metric(self, model, datamodule, num_samples_to_evaluate: int):
        """
        Estimate the attention of the first <num_samples_to_evaluate> elements of the unlabeled pool

        The samples are evaluated ``num_of_estimations`` times, each time in a new random order.
        The per-head attentions are averaged over the estimations and then summed over the heads.
        :param model: The model to evaluate; it is left in training mode afterwards
        :param datamodule: The datamodule providing the dataset, pool indices and loader settings
        :param num_samples_to_evaluate: The number of samples to evaluate from the unlabeled pool
        :return: Numpy array of attentions of the evaluated samples
        """
        num_heads = model.params.mpn_net_def.hyperparams.num_heads
        batch_size = datamodule.batch_size
        sample_attentions = torch.zeros((num_heads, num_samples_to_evaluate), device=model.device)

        # The pool does not change while estimating, so convert it to a tensor only once
        pool_indices = torch.tensor(datamodule.unlabeled_pool_indices)

        self.logger.info("Estimating sample attentions!")
        for estimation_idx in range(self.params.num_of_estimations):
            self.logger.debug(f"Sample attention estimation {estimation_idx} started!")

            # Random shuffling of indices
            shuffled_indices = torch.randperm(num_samples_to_evaluate)

            # NOTE(review): with drop_last=True a trailing partial batch is skipped, while the
            # accumulation below indexes all <num_samples_to_evaluate> positions. This appears
            # to assume num_samples_to_evaluate is a multiple of batch_size - confirm.
            loader = DataLoader(datamodule.unaugmented_dataset_train,
                                batch_size=batch_size,
                                # more convenient if we maintain the order of subset
                                sampler=SubsetSequentialSampler(pool_indices[shuffled_indices]),
                                num_workers=datamodule.num_workers,
                                drop_last=True,
                                pin_memory=True)

            model.eval()

            model.clean_inspected_layers()
            handle = model.inspect_layer_output(self.attention_layer)
            self.logger.debug("Attention inspection is registered!")

            self.logger.debug(f"Size of attention query loader: {len(loader)}")

            self.logger.debug("Evaluating samples...")
            try:
                with torch.no_grad():
                    for (inputs, _) in loader:
                        inputs = inputs.to(model.device)
                        model(inputs)
            finally:
                # Always restore training mode and drop the inspection handle, so a failing
                # forward pass does not leave the inspection registered on the model.
                model.train()
                handle.remove()

            attention = self._resize_attention(model.inspected_variables[self.attention_layer].detach(),
                                               model,
                                               batch_size)
            # Average over dim 1 and add the result at the positions of this round's shuffle
            sample_attentions[:, shuffled_indices] += attention.view(num_heads,
                                                                     batch_size,
                                                                     -1).mean(dim=1)
            model.clean_inspected_layers()

        self.logger.info("Sample attention estimation finished!")

        # Mean over the estimation rounds, then sum over the heads
        sample_attentions /= self.params.num_of_estimations
        sample_attentions = sample_attentions.sum(dim=0)

        return np.squeeze(sample_attentions.cpu().numpy())

    def _resize_attention(self, flat_attention, model, batch_size):
        """
        Reshape the captured attention according to the graph type of the model
        :param flat_attention: Attention tensor read from the inspected layer
        :param model: The model whose graph builder type and network parameters pick the shape
        :param batch_size: Batch size used during the evaluation
        :return: View of ``flat_attention`` with shape
                 (num_heads, batch_size, batch_size - 1, -1) for SSLDENSE graphs,
                 (num_heads, batch_size, batch_size, -1) for DENSE graphs and
                 (-1, batch_size, num_message_pass) otherwise
        """
        graph_type = model.params.mpn_net_def.hyperparams.graph_builder_def.type

        if graph_type == GraphType.SSLDENSE:
            attention_weights = flat_attention.view(
                model.model.mpn_net.params.num_heads,
                batch_size,
                batch_size - 1,
                -1)

        elif graph_type == GraphType.DENSE:
            attention_weights = flat_attention.view(
                model.model.mpn_net.params.num_heads,
                batch_size,
                batch_size,
                -1)
        else:
            attention_weights = flat_attention.view(-1,
                                                    batch_size,
                                                    model.model.mpn_net.params.num_message_pass)

        return attention_weights


class AttentionQueryHyperParameterSpace(HyperParameterSpace):
    """Hyperparameter space of the attention query; it only hands out copies of a default set."""

    def __init__(self, default_hyperparam_set: Optional[AttentionQueryHyperParameterSet] = None):
        """
        :param default_hyperparam_set: The set that ``suggest`` copies; a fresh default set is
                                       created when omitted
        """
        # A default written as ``AttentionQueryHyperParameterSet()`` in the signature is
        # evaluated once and then shared by every space; build it per call instead.
        if default_hyperparam_set is None:
            default_hyperparam_set = AttentionQueryHyperParameterSet()
        self.default_hyperparam_set = default_hyperparam_set

    @property
    def search_grid(self) -> Dict[str, Sequence[Any]]:
        """Empty grid: no hyperparameter of the attention query is searched."""
        return {}

    @property
    def search_space(self) -> Dict[str, Sequence[Any]]:
        """Not provided for this space."""
        raise NotImplementedError()

    def suggest(self, trial: optuna.Trial) -> AttentionQueryHyperParameterSet:
        """
        :param trial: The optuna trial; not consulted because nothing is searched
        :return: A deep copy of the default hyperparameter set
        """
        return copy.deepcopy(self.default_hyperparam_set)


class AttentionQueryDefinitionSpace(DefinitionSpace):
    """DefinitionSpace of the AttentionQuery"""

    def __init__(self, hyperparam_space: Optional[AttentionQueryHyperParameterSpace] = None):
        """
        :param hyperparam_space: The hyperparameter space of the query; a fresh default space is
                                 created when omitted
        """
        # A default written as ``AttentionQueryHyperParameterSpace()`` in the signature is
        # evaluated once and then shared by every definition space; build it per call instead.
        if hyperparam_space is None:
            hyperparam_space = AttentionQueryHyperParameterSpace()
        super().__init__(QueryType.AttentionQuery, hyperparam_space)

    def suggest(self, trial: optuna.Trial) -> AttentionQueryDefinition:
        """
        :param trial: The optuna trial passed on to the hyperparameter space
        :return: An AttentionQueryDefinition built from the suggested hyperparameters
        """
        return AttentionQueryDefinition(self.hyperparam_space.suggest(trial))
