from spaghettini import quick_register
import os
import oyaml as yaml
from pprint import pprint
import numpy as np
import random as random
import torch
from datetime import datetime
from argparse import Namespace
from typing import MutableMapping, Callable

import wandb

from pytorch_lightning import seed_everything

# Evaluated once at import time; ``to_cuda`` below consults this flag.
USE_GPU = bool(torch.cuda.is_available())


def to_cuda(xs):
    """Move ``xs`` to the GPU when ``USE_GPU`` is set; otherwise return it untouched.

    Only an exact ``list`` or ``tuple`` (not a subclass such as a namedtuple) is
    treated as a collection: each element is moved and the result is always a
    ``list``. Anything else is treated as a single object with ``.cuda()``.
    """
    def _move(obj):
        return obj.cuda() if USE_GPU else obj

    if type(xs) in (list, tuple):
        return [_move(element) for element in xs]
    return _move(xs)


def transfer_to_average_dict(averaged_values_dict, curr_output_dict):
    """Append each value of ``curr_output_dict`` to the per-key list in ``averaged_values_dict``.

    ``averaged_values_dict`` is mutated in place. A key seen for the first time
    gets a new one-element list; an existing key has its value appended.
    Returns ``None``.
    """
    for key, value in curr_output_dict.items():
        averaged_values_dict.setdefault(key, []).append(value)


def average_values_in_list_of_list_of_dicts(list_of_list_of_dicts):
    """Average each metric across a nested list of per-step output dicts.

    Args:
        list_of_list_of_dicts: Iterable of lists, each holding dicts that map
            metric name -> value. All dicts are pooled together regardless of
            which inner list they came from.

    Returns:
        dict mapping each metric name to ``np.array(values).mean()`` over every
        value logged under it. Metrics for which that computation raises (e.g.
        ``None`` values) are skipped with a printed notice.
    """
    averaged_values_dict = dict()
    for list_of_dicts in list_of_list_of_dicts:
        for output_dict in list_of_dicts:
            for k, v in output_dict.items():
                averaged_values_dict.setdefault(k, []).append(v)

    averaged_scalar_metrics_dict = dict()
    for k, v in averaged_values_dict.items():
        try:
            averaged_scalar_metrics_dict[k] = np.array(v).mean()
        except Exception:
            # Best-effort skip of metrics NumPy cannot average. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt / SystemExit propagate.
            print("Skipping any non-scalar metric that was logged. ")

    return averaged_scalar_metrics_dict


def average_values_in_list_of_dicts(list_of_dicts):
    """Average each metric across a list of per-step output dicts.

    Args:
        list_of_dicts: Iterable of dicts mapping metric name -> value.

    Returns:
        dict mapping each metric name to ``np.array(values).mean()`` over every
        value logged under it. Metrics for which that computation raises (e.g.
        ``None`` values) are skipped with a printed notice.
    """
    averaged_values_dict = dict()
    for curr_output_dict in list_of_dicts:
        for k, v in curr_output_dict.items():
            averaged_values_dict.setdefault(k, []).append(v)

    averaged_scalar_metrics_dict = dict()
    for k, v in averaged_values_dict.items():
        try:
            averaged_scalar_metrics_dict[k] = np.array(v).mean()
        except Exception:
            # Best-effort skip of metrics NumPy cannot average. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt / SystemExit propagate.
            print("Skipping any non-scalar metric that was logged. ")

    return averaged_scalar_metrics_dict


def average_evaluation_results(results):
    """Average evaluation metrics collected either flat or nested.

    Args:
        results: Non-empty list whose elements are either dicts (metric name ->
            value) or lists of such dicts. Only the first element is inspected
            to decide which layout is assumed.

    Returns:
        dict of averaged metrics, as produced by
        ``average_values_in_list_of_dicts`` or
        ``average_values_in_list_of_list_of_dicts``.

    Raises:
        AssertionError: if ``results`` is not a non-empty list (checked with
            ``assert``, so skipped under ``python -O``).
        ValueError: if the first element is neither a dict nor a list.
    """
    assert isinstance(results, list), f"Expected a list of results, got {type(results).__name__}."
    assert len(results) > 0, "Expected at least one evaluation result."
    first = results[0]
    if isinstance(first, dict):
        return average_values_in_list_of_dicts(results)
    if isinstance(first, list):
        return average_values_in_list_of_list_of_dicts(results)
    message = (
        "The evaluation results are not in a recognizable format. "
        f"Expected dicts or lists of dicts, got {type(first).__name__}."
    )
    raise ValueError(message)


def prepend_string_to_dict_keys(prepend_key, dictinary):
    """Return a new dict whose keys are ``f"{prepend_key}{key}"``; values are unchanged.

    The misspelled parameter name ``dictinary`` is kept for keyword callers.
    """
    renamed = {}
    for key, value in dictinary.items():
        renamed[f"{prepend_key}{key}"] = value
    return renamed


def postpend_string_to_dict_keys(postpend_key, dictionary):
    """Return a new dict whose keys are ``f"{key}{postpend_key}"``; values are unchanged."""
    renamed = {}
    for key, value in dictionary.items():
        renamed[f"{key}{postpend_key}"] = value
    return renamed


def print_experiment_config(path="."):
    """Pretty-print the ``template.yaml`` config found in directory ``path``.

    Args:
        path: Directory containing ``template.yaml``. Defaults to the current
            working directory.
    """
    yaml_path = os.path.join(path, "template.yaml")
    # Context manager closes the handle deterministically; the previous bare
    # ``open()`` left it to the garbage collector.
    with open(yaml_path) as config_file:
        config_dict = yaml.safe_load(config_file)
    pprint(config_dict)


def sendline_and_get_response(s, line):
    """Send ``line`` on session ``s``, wait for the prompt, print and return the reply.

    ``s`` must expose ``sendline``, ``prompt`` and a bytes ``before`` attribute
    (presumably a pexpect ``pxssh`` session — confirm against callers).

    Returns:
        str: the UTF-8 decoded output captured before the prompt. Previously the
        reply was only printed, despite the function's name; returning it is
        backward-compatible for callers that ignore the result.
    """
    s.sendline(line)
    s.prompt()
    reply = s.before.decode("utf-8")
    pprint(reply)
    return reply


def getnow(return_int=False):
    """Return the current local time formatted as ``YYYY_MM_DD_HH_MM_SS``.

    With ``return_int=True`` the same digits are returned as an ``int``, e.g.
    ``20240131235959``; ``int()`` accepts the underscores as digit separators.
    """
    stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    return int(stamp) if return_int else stamp


def get_num_params_of_pytorch_model(module):
    """Return the number of trainable (``requires_grad``) scalar parameters in ``module``.

    Uses ``Tensor.numel`` so the result is always a Python ``int``. The previous
    ``np.prod(p.size())`` yielded the float ``1.0`` for 0-d parameters and NumPy
    integers otherwise.
    """
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def get_num_of_allocated_tensors():
    """Count objects tracked by Python's garbage collector that look like tensors.

    An object is counted if it is a tensor, or if it has a ``.data`` attribute
    that is a tensor.

    Returns:
        int: the number of matching objects.
    """
    import gc
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                count += 1
        except Exception:
            # Best-effort scan: arbitrary objects may raise from attribute access
            # (``hasattr`` only suppresses AttributeError). A bare ``except``
            # here would also swallow KeyboardInterrupt during a long scan.
            pass
    return count


def set_seed(seed=None):
    """Seed Python's ``random``, NumPy, torch and Lightning's ``seed_everything``.

    Args:
        seed: Integer seed. When ``None``, the seed is derived from the current
            local time (``getnow(return_int=True)``, one-second resolution)
            reduced modulo ``2 ** 32``.
    """
    if seed is None:
        seed = getnow(return_int=True) % 2 ** 32
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Lightning's helper runs last, after the explicit per-library seeding.
    seed_everything(seed=seed)


@quick_register
def seed_workers(worker_id):
    """DataLoader ``worker_init_fn`` that reseeds every RNG in a worker process.

    Used to make sure Pytorch dataloaders don't return identical random numbers
    amongst different workers.

    Args:
        worker_id: Index of the worker, required by the ``worker_init_fn``
            signature. It is not used directly, because inside a worker
            ``torch.initial_seed()`` already differs per worker.
    """
    # Previously this called ``set_seed()`` with no argument, which derives the
    # seed from the wall clock at one-second resolution. Workers spawned within
    # the same second therefore all got the SAME seed, and ``torch.manual_seed``
    # then also overwrote torch's own distinct per-worker seed. PyTorch seeds
    # each worker with ``base_seed + worker_id``; reducing it modulo 2**32 keeps
    # it valid for NumPy.
    set_seed(torch.initial_seed() % 2 ** 32)


@quick_register
def freeze_thaw(epoch, freeze, thaw):
    """Return ``0.`` while ``freeze <= epoch < thaw`` (frozen) and ``1.`` otherwise."""
    frozen = freeze <= epoch < thaw
    return 0. if frozen else 1.


def set_hyperparams(config_path, logger):
    """Load the YAML config at ``config_path``, strip ``{...}`` key indicators, and log it.

    Args:
        config_path: Path to a YAML file.
        logger: Object exposing ``log_hyperparams`` (presumably a Lightning
            logger — confirm against callers).
    """
    with open(config_path, 'r') as config_file:
        hyperparams = yaml.safe_load(config_file)
    logger.log_hyperparams(sanitize_hyperparameter_indicators(hyperparams))


def sanitize_hyperparameter_indicators(config_dict):
    """Strip ``{...}`` hyperparameter indicators from config keys, recursively.

    A string key containing both ``{`` and ``}`` is truncated at its last ``{``,
    so ``"lr{hp}"`` becomes ``"lr"``. Nested dicts are sanitized recursively.
    Other values, including lists, are passed through untouched. If two keys
    sanitize to the same name, the later one wins.

    Args:
        config_dict: Mapping, e.g. as loaded from a YAML config.

    Returns:
        A new dict with sanitized keys; ``config_dict`` is not modified.
    """
    new_dict = dict()
    for k, v in config_dict.items():
        if isinstance(v, dict):
            v = sanitize_hyperparameter_indicators(v)
        # YAML allows non-string keys (ints, bools, None). The substring test
        # would raise TypeError on them, so they are kept as-is.
        if isinstance(k, str) and "{" in k and "}" in k:
            k = k[:k.rfind("{")]
        new_dict[k] = v
    return new_dict


def set_hyperparams_pure_wandb(config_path):
    """Load a YAML config, flatten it to ``/``-joined keys, and push it into ``wandb.config``.

    Callables in the flattened config are replaced by their return value or
    name (see ``_sanitize_callable_params``). Existing values may be overwritten
    (``allow_val_change=True``).
    """
    with open(config_path, 'r') as config_file:
        raw_config = yaml.safe_load(config_file)
    flat_config = _sanitize_callable_params(_flatten_dict(_convert_params(raw_config)))
    wandb.config.update(flat_config, allow_val_change=True)


def _convert_params(params):
    # Taken from pytorch lightning codebase.
    # in case converting from namespace
    if isinstance(params, Namespace):
        params = vars(params)

    if params is None:
        params = {}

    return params


def _flatten_dict(params, delimiter: str = '/'):
    # Taken from pytorch lightning codebase.

    def _dict_generator(input_dict, prefixes=None):
        prefixes = prefixes[:] if prefixes else []
        if isinstance(input_dict, MutableMapping):
            for key, value in input_dict.items():
                key = str(key)
                if isinstance(value, (MutableMapping, Namespace)):
                    value = vars(value) if isinstance(value, Namespace) else value
                    for d in _dict_generator(value, prefixes + [key]):
                        yield d
                else:
                    yield prefixes + [key, value if value is not None else str(None)]
        else:
            yield prefixes + [input_dict if input_dict is None else str(input_dict)]

    return {delimiter.join(keys): val for *keys, val in _dict_generator(params)}


def _sanitize_callable_params(params):
    # Taken from pytorch lightning codebase.
    def _sanitize_callable(val):
        # Give them one chance to return a value. Don't go rabbit hole of recursive call.
        if isinstance(val, Callable):
            try:
                _val = val()
                if isinstance(_val, Callable):
                    return val.__name__
                return _val
            except Exception:
                return getattr(val, "__name__", None)
        return val

    return {key: _sanitize_callable(val) for key, val in params.items()}


def enlarge_matplotlib_defaults(plt_object):
    """Set the default font size to 15 via ``plt_object.rc`` (presumably ``matplotlib.pyplot``)."""
    font_size = 15
    plt_object.rc('font', size=font_size)


def get_relative_path_from_absolute_path(abs_path):
    """Assuming that all code is under src/, obtain the relative path from the absolute path.

    Returns the suffix of ``abs_path`` starting at the first occurrence of
    ``"src"``. If ``"src"`` does not occur, ``abs_path`` is returned unchanged.

    NOTE: this is a plain substring match, so a directory like ``srcuser``
    also matches.
    """
    src_idx = abs_path.find("src")
    # str.find returns -1 when absent, and slicing from -1 would return only
    # the last character of the path.
    if src_idx == -1:
        return abs_path
    return abs_path[src_idx:]


def stdlog(abs_path, log):
    """Print ``log`` to stdout between ``>>>> (path)`` and ``<<<<`` marker lines.

    The path is ``abs_path`` shortened by ``get_relative_path_from_absolute_path``.
    """
    header = f">>>> ({get_relative_path_from_absolute_path(abs_path)})"
    print("\n".join([header, f"{log}", "<<<<"]))


def update_matplotlib_style(plt):
    """Apply the 'fivethirtyeight' style sheet via ``plt.style.use`` (presumably ``matplotlib.pyplot``)."""
    style_name = 'fivethirtyeight'
    plt.style.use(style_name)


def is_scalar(x):
    """Return True iff x holds exactly one number or boolean.

    Accepted:
      * Python ``bool``, ``int`` and ``float``;
      * NumPy integer/floating/boolean scalars (``np.int64``, ``np.float32``,
        ``np.bool_``, ...);
      * NumPy arrays and torch tensors with exactly one element, i.e. those that
        become 0-dimensional when squeezed (shape ``()``, ``(1,)``, ``(1, 1)``, ...).
    Everything else (strings, ``None``, lists, multi-element or empty arrays)
    is rejected.
    """
    # ``bool`` subclasses ``int``, so booleans are accepted here too.
    if isinstance(x, (int, float)):
        return True

    # NumPy scalars such as np.int64 / np.float32 / np.bool_ are not Python
    # int/float subclasses. They were previously rejected even though the
    # equivalent 0-d array was accepted.
    if isinstance(x, (np.integer, np.floating, np.bool_)):
        return True

    # Squeezing drops every size-1 axis, so the result is 0-d exactly when
    # there is a single element.
    if isinstance(x, np.ndarray):
        return x.size == 1
    if isinstance(x, torch.Tensor):
        return x.numel() == 1

    return False


@quick_register
def int_powers_of_k(k, num=100, multiplier=1.):
    """Return the sorted, de-duplicated values ``int(multiplier * k**i)`` for ``i`` in ``range(num)``.

    When ``k`` is an ``int`` and ``multiplier`` is integral (e.g. the default
    ``1.``), the values are computed with exact integer arithmetic. Going
    through ``float`` loses precision once ``k**i`` exceeds 2**53 (for instance
    ``int(1. * 3**40) != 3**40``) and raises OverflowError past the float range.
    Otherwise the float formula is used as before.
    """
    exact = isinstance(k, int) and (
        isinstance(multiplier, int)
        or (isinstance(multiplier, float) and multiplier.is_integer())
    )
    if exact:
        int_multiplier = int(multiplier)
        powers = {int_multiplier * k ** i for i in range(num)}
    else:
        powers = {int(multiplier * k ** i) for i in range(num)}
    return sorted(powers)


def power_method(f0, z0, n_iters=200):
    """Estimate the spectral radius of J = d f0 / d z0 using the power method.

    Only vector-Jacobian products ``v^T J`` (via ``torch.autograd.grad``) are
    used, so J is never formed explicitly. The first dimension of ``z0`` is a
    batch dimension: one estimate is produced per batch element.

    Args:
        f0 (torch.Tensor): Output of the function f (whose J is to be analyzed).
            Must be connected in the autograd graph to ``z0``.
        z0 (torch.Tensor): Input to the function f.
        n_iters (int, optional): Number of power method iterations (>= 1).
            Defaults to 200. The graph is released after the last iteration.

    Returns:
        dict: ``evector``, the final normalized iterate (same shape as ``z0``),
        and ``evalue``, the absolute value of the Rayleigh-quotient estimate of
        the largest (abs.) eigenvalue, shape ``(batch, 1)``.

    Raises:
        ValueError: if ``n_iters < 1``. Previously this case raised
            UnboundLocalError because no estimate was ever computed.
    """
    if n_iters < 1:
        raise ValueError(f"n_iters must be >= 1, got {n_iters}.")
    evector = torch.randn_like(z0)
    bsz = evector.shape[0]
    for i in range(n_iters):
        # Keep the graph alive for every iteration except the last one.
        vTJ = torch.autograd.grad(f0, z0, evector, retain_graph=(i < n_iters - 1), create_graph=False)[0]
        vTJ_flat = vTJ.reshape(bsz, -1)
        evector_flat = evector.reshape(bsz, -1)
        # Per-sample Rayleigh quotient (v^T J v) / (v^T v).
        evalue = (vTJ_flat * evector_flat).sum(1, keepdim=True) / (evector_flat * evector_flat).sum(1, keepdim=True)
        # Next iterate: v^T J normalized per batch element.
        evector = (vTJ_flat / vTJ_flat.norm(dim=1, keepdim=True)).reshape_as(z0)
    return dict(evector=evector, evalue=torch.abs(evalue))