import numpy as np
from itertools import permutations


def get_membership_matrix(indices: np.ndarray) -> np.ndarray:
    """
    Build a one-hot membership matrix from integer labels.

    Row ``i`` of the result is all zeros except for a one in column ``indices[i]``. The number of columns is
    ``max(indices) + 1``, so labels that never occur still get an (all-zero) column.

    :param indices: (num_nodes,) Membership indices for nodes. Whole numbers stored in a non-integer dtype
                    (e.g. ``2.0``) are accepted and cast to integers.
    :return membership_matrix: (num_nodes, num_indices) Membership matrix. Empty input gives a matrix with zero
                               columns.
    :raises ValueError: If a non-integer dtype holds values that are not whole numbers
    """
    indices = np.asarray(indices)

    # np.max has no answer for an empty array; there are no labels, hence no columns.
    if indices.size == 0:
        return np.zeros(indices.shape + (0,))

    # Float labels can neither size np.eye nor index into it, so convert them up front.
    if indices.dtype.kind not in 'iu':
        as_int = indices.astype(int)
        if not np.array_equal(as_int, indices):
            raise ValueError('Membership indices must be whole numbers.')
        indices = as_int

    num_indices = int(np.max(indices)) + 1
    return np.eye(num_indices)[indices, :]


def compute_group_balance(clusters: np.ndarray, groups: np.ndarray, normalize: bool = False) -> (np.ndarray, float):
    """
    Measure how evenly the protected groups are represented inside each cluster.

    For a cluster ``c`` the balance is the smallest value of ``count[c, g1] / (1e-6 + count[c, g2])`` over all
    ordered pairs of groups (a group paired with itself included), where ``count[c, g]`` is the number of nodes of
    group ``g`` assigned to cluster ``c``. With ``normalize`` set, each ratio is additionally multiplied by
    ``size[g2] / (1e-6 + size[g1])``, where ``size[g]`` is the total number of nodes in group ``g``.

    :param clusters: (num_nodes,) Predicted cluster identities
    :param groups: (num_nodes,) Protected group memberships
    :param normalize: Whether to rescale the ratios by the overall group sizes
    :return balances: (num_clusters,) Balance of each cluster
    :return avg_balance: Mean of ``balances``
    """
    # One-hot encodings of shape (num_nodes, num_clusters) and (num_nodes, num_groups)
    cluster_onehot = get_membership_matrix(clusters)
    group_onehot = get_membership_matrix(groups)

    # per_cluster[c, g] = number of nodes of group g inside cluster c
    per_cluster = np.matmul(cluster_onehot.T, group_onehot)

    # ratios[c, g1, g2] = per_cluster[c, g1] / (1e-6 + per_cluster[c, g2])
    ratios = per_cluster[:, :, np.newaxis] / (1e-6 + per_cluster[:, np.newaxis, :])

    if normalize:
        # Number of nodes in each group
        totals = group_onehot.sum(axis=0).reshape((-1,))
        # rescale[g1, g2] = totals[g2] / (1e-6 + totals[g1])
        rescale = totals[np.newaxis, :] / (1e-6 + totals[:, np.newaxis])
        ratios = ratios * rescale[np.newaxis, :, :]

    # Smallest ratio per cluster; fmin skips NaN entries and the result is inf when nothing is smaller than inf.
    balances = np.fmin.reduce(ratios, axis=(1, 2), initial=np.inf)

    return balances, balances.mean()


def compute_individual_balance(clusters: np.ndarray, fair_mat: np.ndarray, normalize: bool = False) -> \
        (np.ndarray, float):
    """
    Measure, for each node, how evenly its fairness-graph weight is spread across the clusters.

    With ``count = fair_mat @ membership`` (so ``count[i, c]`` sums row ``i`` of ``fair_mat`` over the nodes of
    cluster ``c``), the balance of node ``i`` is the smallest value of ``count[i, c1] / (1e-6 + count[i, c2])`` over
    all ordered pairs of clusters (a cluster paired with itself included). With ``normalize`` set, each ratio is
    additionally multiplied by ``size[c2] / (1e-6 + size[c1])``, where ``size[c]`` is the number of nodes in
    cluster ``c``.

    :param clusters: (num_nodes,) Predicted clusters
    :param fair_mat: (num_nodes, num_nodes) Fairness graph under which balance must be computed
    :param normalize: Whether to rescale the ratios by the cluster sizes
    :return balances: (num_nodes,) Balance for each individual
    :return avg_balance: Mean of ``balances``
    """
    # One-hot encoding of shape (num_nodes, num_clusters)
    onehot = get_membership_matrix(clusters)

    # per_node[i, c] = total fair_mat weight between node i and the members of cluster c
    per_node = np.asarray(np.matmul(fair_mat, onehot))

    # ratios[i, c1, c2] = per_node[i, c1] / (1e-6 + per_node[i, c2])
    ratios = per_node[:, :, np.newaxis] / (1e-6 + per_node[:, np.newaxis, :])

    if normalize:
        # Number of nodes in each cluster
        sizes = onehot.sum(axis=0).reshape((-1,))
        # rescale[c1, c2] = sizes[c2] / (1e-6 + sizes[c1])
        rescale = sizes[np.newaxis, :] / (1e-6 + sizes[:, np.newaxis])
        ratios = ratios * rescale[np.newaxis, :, :]

    # Smallest ratio per node; fmin skips NaN entries and the result is inf when nothing is smaller than inf.
    balances = np.fmin.reduce(ratios, axis=(1, 2), initial=np.inf)

    return balances, balances.mean()


def reflow_clusters(clusters: np.ndarray):
    """
    Renumber cluster labels so that they form the contiguous range 0, 1, 2, ...

    New ids are handed out in order of first appearance: the label of the first node becomes 0, the next label not
    seen before becomes 1, and so on. Nodes that shared a label still share one afterwards.

    :param clusters: (num_nodes,) Cluster assignment
    :return reflow: (num_nodes,) Renumbered cluster assignment, as a float array
    """
    relabeled = np.zeros(clusters.shape)
    first_seen = {}
    for position, label in enumerate(clusters):
        # setdefault keeps an existing id, or stores the next unused one (= number of labels seen so far)
        relabeled[position] = first_seen.setdefault(label, len(first_seen))
    return relabeled


def align_clusters(true_clusters: np.ndarray, pred_clusters: np.ndarray) -> (int, np.ndarray, np.ndarray):
    """
    Relabel the predicted clusters so that they disagree with the ground truth as little as possible.

    Every one-to-one assignment of predicted clusters to true cluster ids is tried, and the one with the fewest
    mismatching nodes is kept. The search is exhaustive, so its cost grows factorially with the number of clusters.

    :param true_clusters: (num_nodes,) Ground truth clusters
    :param pred_clusters: (num_nodes,) Predicted clusters. Labels may be arbitrary; they need not be 0..k-1.
    :return num_mistakes: Number of mistakes incurred by the best alignment
    :return reflow_true: (num_nodes,) True clusters reflowed (renumbered by ``reflow_clusters``)
    :return aligned_pred: (num_nodes,) Aligned predicted clusters, expressed in the ids used by ``reflow_true``
    :raises AssertionError: If the two assignments do not contain the same number of distinct clusters
    """
    # Reflow true clusters; these ids are what the prediction gets aligned to
    reflow_true = reflow_clusters(true_clusters)
    num_clusters_true = int(np.max(reflow_true) + 1)

    # Map each predicted label to its rank among the sorted distinct labels. The ranks are always valid positions
    # into a permutation tuple, and they coincide with the labels themselves when those are already 0..k-1.
    pred_labels, pred_positions = np.unique(pred_clusters, return_inverse=True)
    pred_positions = pred_positions.reshape(pred_clusters.shape)
    num_clusters_pred = pred_labels.shape[0]

    assert num_clusters_true == num_clusters_pred, 'Required num_clusters_true == num_clusters_pred.'

    # Find alignment with minimum error. Ties between equally good permutations go to whichever one the set
    # yields first, because only a strictly smaller mistake count replaces the current best.
    perms = set(permutations(range(num_clusters_true)))
    aligned_clusters = np.zeros(pred_clusters.shape)
    min_mistakes = float('inf')
    for perm in perms:
        # perm[p] is the true-cluster id assigned to the predicted cluster at position p
        temp_clusters = np.asarray(perm, dtype=float)[pred_positions]

        mistakes = (temp_clusters != reflow_true).sum()
        if mistakes < min_mistakes:
            min_mistakes = mistakes
            aligned_clusters = temp_clusters

    return int(min_mistakes), reflow_true, aligned_clusters
