"""Tree traversal/embedding util functions."""

import networkx as nx
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage, to_tree


def reverse(mat):
    """Return the element-wise reciprocal of a square matrix with a zero diagonal.

    Args:
        mat: array-like whose dimensions all have the same length (required
            by ``np.diag_indices_from``). The input is never modified.

    Returns:
        A new ndarray holding ``1 / mat`` with every diagonal entry set to 0.
        Off-diagonal zeros in ``mat`` become ``inf``.

    Raises:
        ValueError: if ``mat`` has fewer than two dimensions or its
            dimensions differ in length.
    """
    values = np.asarray(mat)
    # The diagonal is typically zero and is overwritten right below, so the
    # divide-by-zero warning it would trigger is pure noise; silence only
    # that warning category and only for this division.
    with np.errstate(divide="ignore"):
        mat_rev = 1. / values  # allocates a fresh array, input stays untouched
    mat_rev[np.diag_indices_from(mat_rev)] = 0.0
    return mat_rev


def get_leaf_descendants(tree, leaves):
    """Map every node of ``tree`` to the leaves found among its descendants.

    Args:
        tree: graph exposing ``nodes()`` and accepted by ``nx.descendants``.
        leaves: collection of (hashable) nodes regarded as leaves.

    Returns:
        dict mapping each node to a list: ``[node]`` when the node itself is
        in ``leaves``, otherwise the members of ``leaves`` that appear in
        ``nx.descendants(tree, node)``.
    """
    # Build the lookup once: membership tests run for every descendant of
    # every node, so scanning a list each time would be needlessly quadratic.
    leaf_set = set(leaves)
    desc = {}
    for node in tree.nodes():
        if node in leaf_set:
            desc[node] = [node]
        else:
            desc[node] = [x for x in nx.descendants(tree, node) if x in leaf_set]
    return desc


def get_leaves_root(tree):
    """Return ``(leaves, root)`` for a directed tree.

    Leaves are the nodes with no neighbors in ``tree``; the root is the first
    node with no neighbors in the reversed view of ``tree``.

    Raises:
        IndexError: if no node qualifies as root.
    """

    def _isolated_in(graph):
        # Nodes of ``graph`` whose neighbor iterator is empty, in node order.
        return [node for node in graph.nodes() if not list(graph.neighbors(node))]

    leaves = _isolated_in(tree)
    parentless = _isolated_in(nx.reverse_view(tree))
    return leaves, parentless[0]


def get_lcas(tree, root):
    """Return a symmetric lowest-common-ancestor lookup for ``tree``.

    The pairs produced by ``nx.tree_all_pairs_lowest_common_ancestor`` are
    stored under both ``(a, b)`` and ``(b, a)`` so callers can query either
    ordering.
    """
    forward = dict(nx.tree_all_pairs_lowest_common_ancestor(tree, root=root))
    symmetric = dict(forward)
    # Mirror every pair; original keys keep their position in the dict.
    symmetric.update({(b, a): ancestor for (a, b), ancestor in forward.items()})
    return symmetric


def tree_to_average_linkage(tree, similarities):
    """Reconstruct an average-linkage matrix from a predicted clustering tree.

    Nodes are visited in DFS post-order. The first child seen for a parent is
    remembered; when a later child of the same parent shows up, the two are
    merged and one row is emitted.

    Args:
        tree: directed networkx tree with parent -> child edges. Node ids are
            used directly as indices into the distance matrix, so leaves are
            presumably integer sample indices - confirm against the caller.
        similarities: square similarity matrix; it is turned into distances
            with ``reverse`` (element-wise reciprocal, zero diagonal).

    Returns:
        ``np.ndarray`` with one row ``[first_child, later_child, mean_distance,
        merged_size]`` per merge, in post-order. ``mean_distance`` is the mean
        of all pairwise distances between the members of the two merged
        clusters. An input producing no merges yields an empty array.
    """
    distances = reverse(similarities)
    tree_reversed = nx.reverse_view(tree)
    first_child = {}  # parent -> first of its children encountered
    count = {}        # node -> number of members in its cluster
    clusters = {}     # node -> list of members in its cluster
    linked_mat = []
    # Iterate directly instead of repeatedly popping from the list head,
    # which costs O(n) per pop.
    for n1 in nx.dfs_postorder_nodes(tree):
        parents = list(tree_reversed.neighbors(n1))
        if not parents:
            # Root: nothing to merge with.
            continue
        parent = parents[0]
        if parent not in first_child:
            # First child seen for this parent: park it until a sibling arrives.
            first_child[parent] = n1
            count.setdefault(n1, 1)
            clusters.setdefault(n1, [n1])
            continue
        n2 = first_child[parent]
        # A node without a recorded count has not been built by a merge and
        # counts as a single member.
        total = count.get(n1, 1) + count[n2]
        members1 = clusters.setdefault(n1, [n1])
        members2 = clusters[n2]
        dist = [distances[i, j] for i in members1 for j in members2]
        linked_mat.append([n2, n1, np.mean(dist), total])
        count[parent] = total
        clusters[parent] = members1 + members2
    return np.array(linked_mat)


def model2linkage(model):
    """Convert a fitted sklearn hierarchical-clustering model to a scipy-style linkage.

    Typical use::

        agg = AgglomerativeClustering(linkage=method)
        clustering = agg.fit(x)
        linked = tree_utils.model2linkage(clustering)
        tree = to_nx_tree(linked)

    Args:
        model: object exposing ``children_``, a sequence with one entry per
            merge listing the ids of the merged children. Ids below
            ``len(children_) + 1`` are treated as single samples; larger ids
            refer to clusters created by earlier merges.

    Returns:
        Float ``np.ndarray`` with one row ``[child_a, child_b, height, size]``
        per merge. ``height`` is simply the 1-based merge index (not a real
        distance) and ``size`` is the number of samples in the new cluster.
        An empty ``children_`` yields an array of shape ``(0, 4)``.

    Raises:
        KeyError: if a child id refers to a cluster that has not been formed
            by an earlier merge.
    """
    children = model.children_
    if len(children) == 0:
        # Keep the 4-column linkage layout even when there is nothing to merge.
        return np.empty((0, 4), dtype=float)
    n_leaves = len(children) + 1
    sizes = {}  # id of a merged cluster -> number of samples it contains
    rows = []
    for step, merged in enumerate(children):
        size = sum(1 if child < n_leaves else sizes[child] for child in merged)
        # The cluster created at merge ``step`` receives id ``n_leaves + step``.
        sizes[n_leaves + step] = size
        rows.append([*merged, step + 1, size])
    return np.array(rows).astype(float)


def to_nx_tree(linked):
    """Build a directed networkx graph (parent -> child) from a linkage matrix.

    The linkage is expanded with scipy's ``to_tree``; every cluster node then
    contributes one edge to its left child and one to its right child, when
    present. Graph nodes are the cluster ids.
    """
    _, cluster_nodes = to_tree(linked, rd=True)
    graph = nx.DiGraph()
    for parent in cluster_nodes:
        # Left before right, so edges are inserted in the same order per node.
        for child in (parent.get_left(), parent.get_right()):
            if child:
                graph.add_edge(parent.id, child.id)
    return graph


def random_tree(n_leaves):
    """Build a random binary tree over ``n_leaves`` leaves as a ``nx.DiGraph``.

    Leaves are numbered ``0 .. n_leaves - 1``. While more than one subtree
    root remains, two of them are drawn uniformly (via ``np.random.choice``)
    and attached to a freshly numbered parent node, which then joins the pool.
    Edges point from parent to child.
    """
    graph = nx.DiGraph()
    graph.add_nodes_from(range(n_leaves))
    pending = list(graph.nodes())  # roots of the subtrees still to be merged
    next_id = len(pending)
    while len(pending) > 1:
        lo, hi = sorted(np.random.choice(list(range(len(pending))), 2, replace=False))
        # Remove the higher index first so the lower one stays valid.
        second = pending.pop(hi)
        first = pending.pop(lo)
        graph.add_edge(next_id, first)
        graph.add_edge(next_id, second)
        pending.append(next_id)
        next_id += 1
    return graph


def random_linkage(n_leaves):
    """Return the single-linkage matrix of ``n_leaves`` random 2-D points.

    Points are drawn with ``np.random.random``, so the result depends on
    numpy's global random state.
    """
    n_dims = 2
    points = np.random.random((n_leaves, n_dims))
    return linkage(points, method='single')
