from collections import OrderedDict
from typing import Callable, Dict, Sequence, Any

import optuna
import torch
from torch import nn

from adl4cv.classification.loss.loss_calculator import LossEvaluatorHyperParameterSet
from adl4cv.classification.model.classification_module import ClassificationModuleHyperParameterSet, \
    ClassificationModule, ClassificationModuleHyperParameterSpace
from adl4cv.classification.model.convolutional.resnet import ResNet18Definition
from adl4cv.classification.model.graph.message_passing_net import MessagePassingNetDefinition
from adl4cv.classification.model.models import ClassificationModuleDefinition, ModelType
from adl4cv.classification.optimizer.optimizers import OptimizerDefinition
from adl4cv.classification.optimizer.schedulers import SchedulerDefinition
from adl4cv.parameters.params import DefinitionSpace


class GeneralNetHyperParameterSet(ClassificationModuleHyperParameterSet):
    """HyperParameterSet of the GeneralNet"""

    def __init__(self,
                 iid_net_def: ClassificationModuleDefinition = ResNet18Definition(),
                 mpn_net_def: ClassificationModuleDefinition = MessagePassingNetDefinition(),
                 optimizer_definition: OptimizerDefinition = None,
                 scheduler_definition: SchedulerDefinition = None,
                 loss_calc_params: Dict[str, LossEvaluatorHyperParameterSet] = None,
                 **kwargs):
        """
        Creates new HyperParameterSet
        :param iid_net_def: The definition of the feature extraction part
        :param mpn_net_def: The definition of the message passing network part
        :param optimizer_definition: Forwarded to the base class
        :param scheduler_definition: Forwarded to the base class
        :param loss_calc_params: Forwarded to the base class; a new empty dict is used when omitted
        :param kwargs: Forwarded to the base class
        :func:`~ClassificationModuleHyperParameterSet.__init__`
        """
        # A literal ``{}`` default would be a single dict shared by every instance,
        # so a mutation through one set would leak into all others.
        if loss_calc_params is None:
            loss_calc_params = {}
        super().__init__(optimizer_definition, scheduler_definition, loss_calc_params, **kwargs)
        # NOTE(review): the default sub-network definitions are created once at import
        # time and shared between sets. They are not replaced by a None sentinel because
        # GeneralNetHyperParameterSpace treats a None definition as "no sub-space".
        self.iid_net_def = iid_net_def
        self.mpn_net_def = mpn_net_def

    def definition_space(self):
        """
        Creates the hyperparameter space spanned around this set
        :return: A GeneralNetHyperParameterSpace using this set as default
        """
        return GeneralNetHyperParameterSpace(self)


class GeneralNetDefinition(ClassificationModuleDefinition):
    """Definition of the GeneralNet"""

    def __init__(self, hyperparams: GeneralNetHyperParameterSet = None):
        """
        Creates a new GeneralNet definition
        :param hyperparams: The hyperparameters of the network; a fresh default
            GeneralNetHyperParameterSet is created when omitted
        """
        # Build the default lazily: an instance in the signature would be created once
        # at import time and shared (and mutable) across all default definitions.
        if hyperparams is None:
            hyperparams = GeneralNetHyperParameterSet()
        super().__init__(ModelType.GeneralNet, hyperparams)

    @property
    def _instantiate_func(self) -> Callable:
        """The class instantiated from this definition"""
        return GeneralNet

    def definition_space(self):
        """
        Creates the definition space spanned around this definition
        :return: A GeneralNetDefinitionSpace built from the hyperparameters' space
        """
        return GeneralNetDefinitionSpace(self.hyperparams.definition_space())


class GeneralNet(ClassificationModule):
    """
    GeneralNet: an IID network followed by a message passing network.

    The wrapped model is an ``nn.Sequential`` with the named children ``iid_net`` and
    ``mpn_net``; the forward methods call these children explicitly rather than
    relying on the Sequential's own forward.
    """
    def __init__(self, params: GeneralNetHyperParameterSet = None):
        """
        Creates a new GeneralNet
        :param params: The hyperparameters of the network; a fresh default
            GeneralNetHyperParameterSet is created when omitted
        """
        # Avoid one import-time default instance shared by every network.
        if params is None:
            params = GeneralNetHyperParameterSet()
        # NOTE(review): not read in this file; assigned before the base initializer,
        # presumably so it already exists during base-class setup — confirm.
        self.edge_index = None
        super().__init__(params)

    def define_model(self) -> torch.nn.Module:
        """
        Instantiates both sub-networks from their definitions
        :return: nn.Sequential with children ``iid_net`` and ``mpn_net``
        """
        iid_net = self.params.iid_net_def.instantiate()
        mpn_net = self.params.mpn_net_def.instantiate()

        return nn.Sequential(OrderedDict([
            ('iid_net', iid_net),
            ('mpn_net', mpn_net)
        ]))

    def initialize_model(self):
        """No extra initialization; the sub-networks are used as instantiated."""
        pass

    def forward(self, x: torch.Tensor):
        """
        Runs the forward pass on the data
        :param x: The input to be forwarded
        :return: The output of the model
        """
        x, _ = self.forward_with_intermediate(x)
        return x

    def forward_with_intermediate(self, x: torch.Tensor):
        """
        Runs the forward pass on the data
        :param x: The input to be forwarded
        :return: Tuple of the model output and the intermediate value returned by
            the IID network's ``features``
        """
        x, intermediate = self.model.iid_net.features(x)
        self.model.iid_net.head(x)  # Run also the other head to have IID layer outputs
        x = self.model.mpn_net(x, intermediate)
        return x, intermediate

    def features(self, x):
        """
        Runs only the feature extraction of the IID network
        :param x: The input to be forwarded
        :return: The features (the intermediate value is discarded)
        """
        x, _ = self.model.iid_net.features(x)
        return x

    @property
    def feature_model(self):
        """The feature model of the IID network"""
        return self.model.iid_net.feature_model

    def forward_iid(self, x: torch.Tensor):
        """
        Runs only the IID network
        :param x: The input to be forwarded
        :return: The output of the IID network
        """
        x = self.model.iid_net.forward(x)
        return x

    def forward_with_iid(self, x: torch.Tensor):
        """
        Runs the full forward pass and also returns the IID head output
        :param x: The input to be forwarded
        :return: Tuple of the model output and the IID head logits
        """
        # ``features`` returns (features, intermediate); both are needed by mpn_net,
        # exactly as in ``forward_with_intermediate``.
        x, intermediate = self.model.iid_net.features(x)
        iid_logits = self.model.iid_net.head(x)
        x = self.model.mpn_net(x, intermediate)
        return x, iid_logits


class GeneralNetHyperParameterSpace(ClassificationModuleHyperParameterSpace):
    """HyperParameterSpace of the GeneralNet"""

    def __init__(self,
                 default_hyperparam_set: GeneralNetHyperParameterSet = None,
                 ):
        """
        Creates a new HyperParameterSpace
        :param default_hyperparam_set: The default set of the space; a fresh default
            GeneralNetHyperParameterSet is created when omitted
        """
        # Avoid one import-time default instance shared by every space.
        if default_hyperparam_set is None:
            default_hyperparam_set = GeneralNetHyperParameterSet()
        super().__init__(default_hyperparam_set)
        # A sub-network without a definition contributes no sub-space.
        self.iid_net_space = self._definition_space_of(default_hyperparam_set.iid_net_def)
        self.mpn_net_space = self._definition_space_of(default_hyperparam_set.mpn_net_def)

    @staticmethod
    def _definition_space_of(definition):
        """Returns the definition space of ``definition``, or None if it is None."""
        return definition.definition_space() if definition is not None else None

    def _merge_with_sub_spaces(self, attribute: str, own: Dict[str, Sequence[Any]]) -> Dict[str, Sequence[Any]]:
        """
        Merges ``attribute`` of the existing sub-spaces (IID first, then MPN) with ``own``
        :param attribute: Name of the dict-valued property to read from each sub-space
        :param own: This space's own entries; they take precedence on key clashes
        :return: The merged dict
        """
        merged = {}
        for sub_space in (self.iid_net_space, self.mpn_net_space):
            if sub_space is not None:
                merged.update(getattr(sub_space, attribute))
        merged.update(own)
        return merged

    @property
    def search_grid(self) -> Dict[str, Sequence[Any]]:
        """The search grid of both sub-networks merged with the base class grid"""
        return self._merge_with_sub_spaces('search_grid', super().search_grid)

    @property
    def search_space(self) -> Dict[str, Sequence[Any]]:
        """The search space of both sub-networks merged with the base class space"""
        # Previously merged ``search_grid`` here (copy-paste), so the space silently
        # equalled the grid.
        return self._merge_with_sub_spaces('search_space', super().search_space)

    def suggest(self, trial: optuna.Trial) -> GeneralNetHyperParameterSet:
        """
        Suggest new HyperParameterSet for a trial
        :param trial: The optuna trial to sample from
        :return: Suggested HyperParameterSet
        """
        hyperparams = super().suggest(trial=trial)
        hyperparams.iid_net_def = self._suggest_iid_net_params(trial)
        hyperparams.mpn_net_def = self._suggest_mpn_net_params(trial)

        return hyperparams

    def _suggest_iid_net_params(self, trial: optuna.Trial) -> ClassificationModuleDefinition:
        """
        Suggest new feature extraction definition for a trial
        :return: Suggested Definition of the feature extraction part, or the default
            definition if there is no sub-space
        """
        if self.iid_net_space is None:
            return self.default_hyperparam_set.iid_net_def
        return self.iid_net_space.suggest(trial)

    def _suggest_mpn_net_params(self, trial: optuna.Trial) -> ClassificationModuleDefinition:
        """
        Suggest new message passing network definition for a trial
        :return: Suggested Definition of the message passing part, or the default
            definition if there is no sub-space
        """
        if self.mpn_net_space is None:
            return self.default_hyperparam_set.mpn_net_def
        return self.mpn_net_space.suggest(trial)


class GeneralNetDefinitionSpace(DefinitionSpace):
    """DefinitionSpace of the GeneralNet"""

    def __init__(self, hyperparam_space: GeneralNetHyperParameterSpace = None):
        """
        Creates a new DefinitionSpace
        :param hyperparam_space: The hyperparameter space to sample from; a fresh
            default GeneralNetHyperParameterSpace is created when omitted
        """
        # Build the default lazily instead of once at import time, where it would be
        # shared by every default-constructed definition space.
        if hyperparam_space is None:
            hyperparam_space = GeneralNetHyperParameterSpace()
        super().__init__(ModelType.GeneralNet, hyperparam_space)

    def suggest(self, trial: optuna.Trial) -> GeneralNetDefinition:
        """
        Suggest a new GeneralNet definition for a trial
        :param trial: The optuna trial to sample from
        :return: Definition wrapping the suggested hyperparameters
        """
        return GeneralNetDefinition(self.hyperparam_space.suggest(trial))
