import itertools, io
import os, yaml, datetime, glob

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf 
import optuna

def load_yaml(yaml_path):
    """Parse a YAML file with PyYAML's SafeLoader and return the result.

    Args:
        yaml_path: A str. Path to the YAML file.
    Returns:
        The parsed document (whatever object the YAML encodes, e.g. a dict).
    Raises:
        AssertionError: If `yaml_path` does not exist.
    """
    assert os.path.exists(yaml_path), "Yaml path does not exist: " + yaml_path
    with open(yaml_path, "r") as stream:
        # Returning inside the `with` still closes the file on exit.
        return yaml.load(stream, Loader=yaml.SafeLoader)


def concat_configs(paths):
    """Merge several YAML config files into one flat dict.

    Args:
        paths: A list of at least two paths to YAML files. Each file is
            expected to hold a mapping at its top level; an empty file is
            treated as an empty mapping.
    Returns:
        config: A dict holding the union of the top-level keys of all files,
            in file order.
    Raises:
        AssertionError: If fewer than two paths are given, or if a top-level
            key appears in more than one file.
    """
    assert len(paths) > 1
    config = {}
    for iter_path in paths:
        # An empty YAML document loads as None; without the fallback the
        # `.update()` / key iteration below would crash on it.
        add = load_yaml(iter_path) or {}
        for k in add:
            assert k not in config, "Key {} duplicated.".format(k)
        config.update(add)  # concat
    return config


def set_gpu_devices(gpu):
    """Make only one physical GPU visible to TensorFlow and enable memory growth on it.

    Args:
        gpu: An int. Index into the list of physical GPUs reported by
            `tf.config.experimental.list_physical_devices('GPU')`.
    Raises:
        AssertionError: If no GPU is available, or `gpu` is not a valid
            index into the available GPUs.

    NOTE(review): TensorFlow generally requires these device settings to be
    applied before the GPUs are initialized; confirm this runs early enough
    at the call site.
    """
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    num_devices = len(physical_devices)
    assert num_devices > 0, "Not enough GPU hardware devices available"
    # Validate the requested index as well, so an out-of-range `gpu` fails
    # with a clear message instead of a bare IndexError.
    assert -num_devices <= gpu < num_devices, \
        "GPU index {} out of range: {} GPU(s) available".format(gpu, num_devices)
    device = physical_devices[gpu]
    tf.config.experimental.set_visible_devices(device, 'GPU')
    tf.config.experimental.set_memory_growth(device, True)


def fix_random_seed(flag_seed, seed=None):
    """Optionally seed NumPy's global RNG and TensorFlow's global seed.

    Args:
        flag_seed: Truthy to apply `seed`; falsy to leave the RNGs untouched.
        seed: The seed passed to both `np.random.seed` and
            `tf.random.set_seed` when `flag_seed` is truthy.
    Returns:
        None. A one-line status message is printed either way.
    """
    if not flag_seed:
        print("Random seed not fixed.")
        return

    np.random.seed(seed)
    tf.random.set_seed(seed)
    print("Numpy and TensorFlow's random seeds fixed: seed=" + str(seed))


def config_checker(config):
    """Validate the fields of an experiment config dict.

    Args:
        config: A dict. Must contain "exp_phase", "weights_base" and
            "weights_top"; "pruner_index" is also read when
            exp_phase == "stat".
    Raises:
        AssertionError: If a field has an unsupported value.
        KeyError: If a required key is missing.
    """
    phase = config["exp_phase"]
    assert phase in ("try", "tuning", "stat")
    for weights_key in ("weights_base", "weights_top"):
        assert config[weights_key] in (None, "imagenet")

    # The "stat" phase must run without pruning.
    if phase == "stat":
        assert config["pruner_index"] == 0


def config_checker_rnn(config):
    """Validate the fields of an RNN experiment config dict.

    Args:
        config: A dict. Must contain "exp_phase"; "pruner_index" is also
            read when exp_phase == "stat".
    Raises:
        AssertionError: If a field has an unsupported value.
        KeyError: If a required key is missing.
    """
    phase = config["exp_phase"]
    assert phase in ("try", "tuning", "stat")

    # The "stat" phase must run without pruning.
    if phase == "stat":
        assert config["pruner_index"] == 0


def save_config_as_yaml(config, root_configs, subproject_name, exp_phase, comment, time_stamp):
    """Dump `config` to a timestamped YAML file, creating directories as needed.

    - Remark
    The file is written to
    'root_configs'/'subproject_name'_'exp_phase'/'comment'_'time_stamp'/config_savedYYYYmmddHHMMSS.yaml
    where YYYYmmddHHMMSS is the local time at the moment of saving.

    NOTE(review): two saves into the same directory within the same second
    get the same filename, and the later one overwrites the earlier.
    """
    dir_config = "{}/{}_{}/{}_{}".format(
        root_configs, subproject_name, exp_phase, comment, time_stamp)
    # exist_ok=True avoids the race where another process creates the
    # directory between an exists() check and makedirs().
    os.makedirs(dir_config, exist_ok=True)

    tmp_now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    path_config = dir_config + "/config_saved{}.yaml".format(tmp_now)
    with open(path_config, "w") as f:
        yaml.dump(config, f)


def path2folders(path):
    """Split a path into its non-empty components, in order.

    - Args
    path: A str. E.g., "/home/x/y/z" and "/raid/a/b/c.txt".
    - Returns
    folders: A list. E.g., ["home", "x", "y", "z"] and ["raid", "a", "b", "c.txt"].
        The root itself is not a component, so "/" and "" both give [].
        Trailing separators are ignored ("/home/x//" -> ["home", "x"]).
    """
    # Strip *all* trailing separators; stripping just one leaves "/a/b//" as
    # "/a/b/", whose first split has an empty tail and would yield [].
    head = path.rstrip("/" + os.sep)

    folders = []
    while True:
        head, tail = os.path.split(head)
        if not tail:
            break
        folders.append(tail)
        # Terminates: a non-empty tail is cut from the end of `head`, so it
        # strictly shrinks every iteration (no arbitrary depth cap needed).

    folders.reverse()
    return folders


def filename_rewriter_if_exists(filepath):
    """Return `filepath`, or a numbered variant of it if it already exists.

    If something already exists at `filepath`, "_<n>" is inserted before the
    extension, using the smallest n >= 1 for which nothing exists yet.

    - Args
    filepath: A str. E.g., "/home/x/y/z.txt" and "/data/x/y/z"
    - Returns
    filepath: A str. E.g., "/home/x/y/z_1.txt" and "/data/x/y/z_1"
        if there already exists the file (unchanged otherwise).
    - Raises
    ValueError: If the path exists and its file name contains more than one
        "." (e.g. "c.tar.gz"), since where the extension starts would be
        ambiguous.

    NOTE: the existence check is not atomic with whatever the caller later
    does with the returned path.
    """
    if not os.path.exists(filepath):
        return filepath

    # Double extensions not supported.
    _, filename = os.path.split(filepath)
    if filename.count(".") > 1:
        errmsg0 = "Double extensions (e.g., .tar.gz) not supported."
        errmsg1 = "\nGot filepath = {} .".format(filepath)
        raise ValueError(errmsg0 + errmsg1)

    # Rewrite filepath: pre + "_<cnt>" + ext for the first free cnt.
    from_ = filepath
    pre, ext = os.path.splitext(filepath)
    cnt = 0
    while os.path.exists(filepath):
        cnt += 1
        filepath = pre + "_{}".format(cnt) + ext

    print("File path exists.")
    print("File path has been overwritten\nfrom {}\nto {}".format(from_, filepath))

    return filepath


def permute_in_subindex_order(list_glob):
    """Order paths as base, base_1, base_2, ..., base_10, ... numerically.

    After a lexicographic sort, the first entry is kept in front as the
    "base" (presumably the path without a subindex, as produced by e.g.
    `filename_rewriter_if_exists`; verify against callers). The remaining
    entries are ordered by the integer after the last "_" in their name
    without extension, so "x_10" comes after "x_2".

    Args:
        list_glob: A list of path strings (e.g. the output of glob).
    Returns:
        A new list with the same elements in subindex order.
    Raises:
        AssertionError: If `list_glob` is not a list.
        ValueError: If a non-first entry has no integer after its last "_".
    """
    assert isinstance(list_glob, list)
    list_glob = sorted(list_glob)
    # Zero or one entries: nothing to reorder (and no list_glob[0] to take
    # when the list is empty).
    if len(list_glob) <= 1:
        return list_glob

    def _subindex(path):
        return int(os.path.splitext(path)[0].split("_")[-1])

    head, rest = list_glob[0], list_glob[1:]
    # sorted() is stable, so ties keep their lexicographic order.
    return [head] + sorted(rest, key=_subindex)


def restrict_classes(llrs, labels, list_classes):
    """Keep only the samples whose label is one of `list_classes`.

    Args:
        llrs: A Tensor with shape (batch, ...).
            E.g., (batch, duration, num classes, num classes).
        labels: A Tensor with shape (batch, ...).
            E.g., (batch, ).
        list_classes: A list (or other sequence) of integers specifying the
            classes to be extracted. E.g. list_classes = [0,2,9] for NMNIST.
            An empty sequence means "no restriction". Repeated ids are
            counted once.
    Returns:
        llrs_rest: A Tensor with shape (<= batch, *llrs.shape[1:]), rows in
            their original batch order. If no sample of the given classes
            is found, llrs_rest = None.
        lbls_rest: A Tensor with shape (<= batch, *labels.shape[1:]).
            If no sample of the given classes is found, lbls_rest = None.
        If list_classes is empty, (llrs, labels) is returned unchanged.
    """
    # len() instead of `== []` so empty tuples (or other sequences) also
    # short-circuit instead of reaching tf.concat([]).
    if len(list_classes) == 0:
        return llrs, labels

    ls_idx = []
    # dict.fromkeys drops repeated class ids while keeping order; a repeated
    # id would otherwise gather each of its samples more than once.
    for itr_cls in dict.fromkeys(list_classes):
        ls_idx.append(tf.reshape(tf.where(labels == itr_cls), [-1]))
    # Sorting restores the original batch order across classes.
    idx = tf.sort(tf.concat(ls_idx, axis=0))

    llrs_rest = tf.gather(llrs, idx, axis=0)
    lbls_rest = tf.gather(labels, idx, axis=0)

    llrs_rest = None if llrs_rest.shape[0] == 0 else llrs_rest
    lbls_rest = None if lbls_rest.shape[0] == 0 else lbls_rest

    return llrs_rest, lbls_rest


def extract_positive_row(llrs, labels):
    """ Extract y_i-th rows of LLR matrices.
    Args:
        llrs: (batch, duration, num classes, num classes)
        labels: (batch,) integer class labels.
    Returns:
        llrs_posrow: (batch, duration, num classes), where
            llrs_posrow[i, t, :] = llrs[i, t, labels[i], :].
    """
    num_classes = llrs.shape[2]

    # Build the one-hot mask in llrs' dtype so that non-float32 LLRs
    # (e.g. float64) can be multiplied without a dtype mismatch.
    labels_oh = tf.one_hot(labels, depth=num_classes, axis=1, dtype=llrs.dtype)
        # (batch, num cls)
    labels_oh = tf.reshape(labels_oh, [-1, 1, num_classes, 1])
        # (batch, 1, num cls, 1): broadcasts over the duration and column
        # axes, so no tf.tile over duration is needed (which also keeps this
        # working when duration is statically unknown).

    llrs_pos = llrs * labels_oh
        # (batch, duration, num cls, num cls); all rows but y_i zeroed
    llrs_posrow = tf.reduce_sum(llrs_pos, axis=2)
        # (batch, duration, num cls): = LLR_{:, :, y_i, :}

    return llrs_posrow


def add_max_to_diag(llrs):
    """Add the global max |LLR| to the diagonal of every LLR matrix.

    Args:
        llrs: (batch, duration, num classes, num classes)
    Returns:
        llrs_maxdiag: (batch, duration, num classes, num classes),
            max(|llrs|) is added to diag of llrs. The max is taken over the
            whole tensor (all batch, time and class entries).
    """
    num_classes = llrs.shape[2]

    llrs_max = tf.reduce_max(tf.abs(llrs))
        # max |LLRs| (scalar)
    # Identity in llrs' dtype: a hard-coded float32 diagonal would fail for
    # float64/float16 LLRs.
    tmp = tf.eye(num_classes, dtype=llrs.dtype) * llrs_max
    tmp = tf.reshape(tmp, [1, 1, num_classes, num_classes])
        # broadcasts over batch and duration
    llrs_maxdiag = llrs + tmp

    return llrs_maxdiag


def plot_heatmatrix(mx, figsize=(10,7), annot=True):
    """Show `mx` as a seaborn heatmap in a new matplotlib figure.

    Args:
        mx: A square matrix, passed to sns.heatmap as its data.
        figsize: A tuple of two positive integers; the figure size.
        annot: A bool. Whether to write each cell's value at its center.
    Returns:
        None. The plot is displayed via plt.show().
    """
    # Open a fresh figure first so the heatmap is drawn onto it.
    plt.figure(figsize=figsize)
    sns.heatmap(data=mx, annot=annot)
    plt.show()


# https://www.tensorflow.org/tensorboard/image_summaries
def plot_confusion_matrix(cm, class_names):
    """
    Returns a matplotlib figure containing the plotted confusion matrix.

    Args:
    cm (array, shape = [n, n]): a confusion matrix of integer classes
        (anything np.asarray accepts, e.g. a NumPy array or nested list).
    class_names (array, shape = [n]): String names of the integer classes

    Cell colors use the raw counts. Cell annotations are the row-normalized
    values (each row divided by its sum), rounded to 2 decimals. Rows that
    sum to zero are annotated with 0 instead of nan.
    """
    cm = np.asarray(cm)
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Row-normalize for the annotations. Empty rows would divide by zero
    # (nan + RuntimeWarning), so they are left at 0 via `where`.
    row_sums = cm.sum(axis=1, keepdims=True).astype('float')
    normalized = np.divide(cm.astype('float'), row_sums,
                           out=np.zeros(cm.shape, dtype='float'),
                           where=row_sums != 0)
    labels = np.around(normalized, decimals=2)

    # Use white text if squares are dark; otherwise black.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out only after the axis labels exist, so they are taken into
    # account and not clipped.
    plt.tight_layout()
    return figure


def plot_to_image(figure):
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call.

    Args:
        figure: A matplotlib Figure.
    Returns:
        A TF image tensor of shape (1, height, width, 4): the PNG decoded
        with 4 channels, with a leading batch dimension added.
    """
    # Save *this* figure to a PNG in memory. plt.savefig would save whatever
    # figure is current, which is not necessarily `figure`.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    # Convert PNG buffer to TF image (getvalue() ignores the stream position).
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image
