import os
import json
from abc import ABC

import pytorch_lightning as pl

from src.utils.misc import *
from src.dl.optimizers.lookahead import Lookahead

# Stage prefixes: prepended to every metric name produced by a step, and used by
# SingleTaskLearner.common_step to check that the prefix matches the training mode.
TRAINING_KEY, VALIDATION_KEY, TEST_KEY = "training/", "validation/", "testing/"


@quick_register
class SingleTaskLearner(pl.LightningModule, ABC):
    """Lightning module that trains, validates and tests one model on one task.

    Every stage funnels through :meth:`common_step`, which runs the model, adds the
    optional regularization term to ``loss_fn(preds, ys)`` and collects metrics from
    ``logging_functions``. Metric names are prefixed with the stage key
    (``TRAINING_KEY``, ``VALIDATION_KEY`` or ``TEST_KEY``).

    Args:
        model: Module called as ``model(inputs, **model_kwargs)``; it must return a
            ``(predictions, model_logs)`` pair.
        train_loader: Loader returned by :meth:`train_dataloader`.
        valid_loader: Loader returned by :meth:`val_dataloader`.
        test_loader: Loader returned by :meth:`test_dataloader`.
        optimizer: Factory called with ``model.parameters()`` that returns the optimizer.
        lr_scheduler: Factory called with the optimizer that returns the LR scheduler.
            The scheduler is stepped once after every optimizer step.
        loss_fn: Callable ``loss_fn(preds, ys)`` returning the task loss.
        logging_functions: Iterable of callables
            ``fn(metric_logs, learner, **tracked_vars) -> metric_logs``.
        regularization_fn: Optional callable whose result is added to the loss; called as
            ``fn(preds, ys, xs=..., model_logs=..., system=learner)``.
        dirs_dict: Optional dict of directories. ``dirs_dict["ckpt_dir_rel"]`` is where
            ``test_results.json`` is written after testing.
        use_lookahead: If True, wrap the optimizer in ``Lookahead(k=5, alpha=0.5)``.
        **kwargs: Extra options, stored in ``self.initial_kwargs``. Recognised keys:
            ``model_kwargs_getter`` -- callable ``fn(learner) -> dict`` of extra keyword
            arguments for the model's forward pass;
            ``input_concat_multiplier`` -- int, number of copies of each batch to
            concatenate along dim 0 before the forward pass.
    """

    def __init__(self, model, train_loader, valid_loader, test_loader, optimizer, lr_scheduler,
                 loss_fn, logging_functions, regularization_fn=None, dirs_dict=None, use_lookahead=False, **kwargs):
        super().__init__()
        self.model = model
        self.train_loader, self.valid_loader, self.test_loader = train_loader, valid_loader, test_loader
        self.optimizer = optimizer  # Factory; instantiated in configure_optimizers.
        self.use_lookahead = use_lookahead
        self.lr_scheduler = lr_scheduler  # Factory; instantiated in configure_optimizers.
        self.loss_fn = loss_fn
        self.regularization_fn = regularization_fn
        self.logging_functions = logging_functions
        self.dirs_dict = dirs_dict
        self.initial_kwargs = kwargs

        # Some models need to be called with extra keyword arguments that depend on the training
        # state (e.g. DEQ pretraining). The getter receives the learner itself and returns them.
        self.model_kwargs_getter = kwargs.get("model_kwargs_getter", lambda tr_state: dict())

    def forward(self, inputs, **kwargs):
        """Run the wrapped model on ``inputs``.

        Keyword arguments from ``model_kwargs_getter(self)`` are merged with ``kwargs``;
        explicit ``kwargs`` win on key clashes.

        Returns:
            The ``(outs, model_dict)`` pair produced by the model.
        """
        # Build a fresh dict so the dict returned by the getter is never mutated in place.
        model_kwargs = {**self.model_kwargs_getter(self), **kwargs}

        outs, model_dict = self.model(inputs, **model_kwargs)
        return outs, model_dict

    # ____ Training. ____
    def training_step(self, data_batch, batch_idx, optimizer_idx=0, *args, **kwargs):
        """Compute the training loss for one batch and log its metrics.

        The LR scheduler is intentionally not stepped here: it is registered with
        ``interval="step"`` in :meth:`configure_optimizers`, so Lightning steps it after
        each optimizer step (stepping it here as well would double-step it at every
        epoch end, where Lightning's default epoch-interval step also fires).
        """
        loss, metric_logs = self.common_step(data_batch, batch_idx, optimizer_idx=optimizer_idx,
                                             prepend_key=TRAINING_KEY)

        self.logger.log_metrics(metric_logs, step=self.global_step)

        return {"loss": loss}

    # ____ Validation. ____
    def validation_step(self, data_batch, batch_nb, *args, prepend_key=VALIDATION_KEY, **kwargs):
        """Return the prefixed metric dict for one evaluation batch.

        ``prepend_key`` is keyword-only so a positional ``dataloader_idx`` (passed by
        Lightning when several loaders are used) lands in ``args`` instead of replacing
        the metric prefix.
        """
        _, metric_logs = self.common_step(data_batch, batch_nb, optimizer_idx=0, prepend_key=prepend_key)

        return metric_logs

    def validation_epoch_end(self, outputs):
        """Aggregate the per-batch metric dicts with ``average_evaluation_results``, log and return them."""
        averaged_metrics = average_evaluation_results(outputs)

        self.logger.log_metrics(averaged_metrics, step=self.global_step)

        return averaged_metrics

    # ____ Testing. ____
    def test_step(self, *args, prepend_key=TEST_KEY, **kwargs):
        """Same as :meth:`validation_step`, but metrics are prefixed with ``TEST_KEY`` by default."""
        return self.validation_step(*args, prepend_key=prepend_key, **kwargs)

    def test_epoch_end(self, outputs):
        """Aggregate and log the test metrics, save them as JSON and return them.

        The averaged metrics are written to ``<dirs_dict["ckpt_dir_rel"]>/test_results.json``;
        the directory is created if it does not exist yet. Nothing is written when no
        ``dirs_dict`` was given.

        Returns:
            ``dict(progress_bar=metrics, log=metrics)``.
        """
        # Aggregation and logging are identical to the end of a validation epoch.
        averaged_metrics = self.validation_epoch_end(outputs=outputs)

        if self.dirs_dict is not None:
            results_dir = self.dirs_dict["ckpt_dir_rel"]
            if results_dir:
                # The checkpoint dir may not exist yet (e.g. a test-only run).
                os.makedirs(results_dir, exist_ok=True)
            test_results_filepath = os.path.join(results_dir, "test_results.json")
            with open(test_results_filepath, "w") as f:
                json.dump(obj=averaged_metrics, fp=f, indent=4)

        # Return the results both to the progress bar and the logger.
        return dict(progress_bar=averaged_metrics, log=averaged_metrics)

    # ____ Common step. ____
    def common_step(self, data_batch, batch_idx, prepend_key="", *args, **kwargs):
        """Run the forward pass and loss for one batch and build its metric dict.

        Args:
            data_batch: ``(xs, ys)`` pair.
            batch_idx: Index of the batch; forwarded to the logging functions.
            prepend_key: Stage prefix. Must contain ``TRAINING_KEY`` in training mode and
                ``VALIDATION_KEY`` or ``TEST_KEY`` otherwise.
            **kwargs: Extra values forwarded to the logging functions.

        Returns:
            ``(total_loss, metric_logs)``; ``total_loss`` includes the regularization term.
        """
        if self.training:
            key_matches_mode = TRAINING_KEY in prepend_key
        else:
            key_matches_mode = VALIDATION_KEY in prepend_key or TEST_KEY in prepend_key
        assert key_matches_mode, \
            f"prepend_key {prepend_key!r} does not match the module mode (training={self.training})"

        # ___ Unpack the data batch. ___
        xs, ys = data_batch

        # ___ If asked, concatenate copies of the batch along dim 0. ___
        if "input_concat_multiplier" in self.initial_kwargs:
            multiplier = self.initial_kwargs["input_concat_multiplier"]
            assert isinstance(multiplier, int)
            xs = torch.cat([xs] * multiplier, dim=0)
            ys = torch.cat([ys] * multiplier, dim=0)

        # ___ Forward. ___
        preds, model_logs = self.forward(xs)

        # ___ Compute losses. ___
        regularization_loss = 0.
        if self.regularization_fn is not None:
            regularization_loss = self.regularization_fn(preds, ys, xs=xs, model_logs=model_logs, system=self)
        total_loss = self.loss_fn(preds, ys) + regularization_loss

        # ___ Log relevant data. ___
        tracked_vars = dict(xs=xs, ys=ys, preds=preds, model_logs=model_logs,
                            total_loss=total_loss, batch_idx=batch_idx, prepend_key=prepend_key,
                            global_step=self.global_step, epoch=self.current_epoch, **kwargs)
        metric_logs = self.log_results(**tracked_vars)

        return total_loss, metric_logs

    # ____ Optimizers. ____
    def configure_optimizers(self):
        """Build the optimizer (optionally wrapped in Lookahead) and its LR scheduler.

        The scheduler is registered with ``interval="step"`` so Lightning steps it once
        after every optimizer step; the default epoch interval would step it only (or,
        combined with manual stepping, additionally) at epoch ends.
        """
        optimizer = self.optimizer(self.model.parameters())
        if self.use_lookahead:
            optimizer = Lookahead(optimizer, k=5, alpha=0.5)
        scheduler = self.lr_scheduler(optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    # ____ Data related. ____
    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader

    def test_dataloader(self):
        return self.test_loader

    # ____ Logging. ____
    def log_results(self, **kwargs):
        """Run every logging function, then prefix the resulting metric names.

        Each logging function receives the current metric dict, the learner and all
        ``kwargs``, and returns the updated metric dict. ``kwargs`` must contain
        ``prepend_key``.
        """
        prepend_key = kwargs["prepend_key"]

        metric_logs = dict()
        for logging_fn in self.logging_functions:
            metric_logs = logging_fn(metric_logs, self, **kwargs)

        # `dictinary` (sic) is the keyword name used by prepend_string_to_dict_keys.
        return prepend_string_to_dict_keys(prepend_key=prepend_key, dictinary=metric_logs)

    def reconcile_input_and_model_types(self, tensor):
        """Return ``tensor`` with the dtype and device of the module's first parameter.

        The input is returned unchanged when both already match.
        """
        first_model_param = next(self.parameters())
        if (tensor.dtype, tensor.device) != (first_model_param.dtype, first_model_param.device):
            # `.to` honours the exact device (including the CUDA index); `type_as` only matches
            # the type string (e.g. torch.cuda.FloatTensor) and can leave the tensor on the
            # wrong GPU.
            return tensor.to(device=first_model_param.device, dtype=first_model_param.dtype)
        return tensor

