import time

import numpy as np
import sklearn
import sklearn.manifold
import sklearn.cluster
from sklearn.metrics import normalized_mutual_info_score
from sklearn.metrics.cluster import _supervised
from scipy.optimize import linear_sum_assignment
import scipy.sparse
import scipy.sparse.csgraph
import scipy.sparse.linalg
import torch


def feature_detection(A, label):
    """ Mean fraction of each row's l_1 norm that lies outside the row's own cluster.

    For row i the in-cluster share is
        sum_{j : label[j] == label[i]} |A[i, j]|  /  sum_j |A[i, j]|
    and the returned value is 1 minus the average of that share over all rows
    (0 means every row puts all of its mass inside its own cluster).

    A: 2-D matrix indexable as ``A[i, boolean_mask]`` (e.g. a numpy array).
    label: array of cluster labels, one per row/column of A.

    A row whose l_1 norm is zero gives 0/0, which makes the result NaN.
    """
    n_rows = A.shape[0]
    # TODO: get rid of for loop
    in_cluster_shares = [
        np.abs(A[row, label == label[row]]).sum() / np.abs(A[row, :]).sum()
        for row in range(n_rows)
    ]
    # sum() adds left to right starting from 0, matching a running total
    return 1 - sum(in_cluster_shares) / n_rows

def nmi(label, pred_label):
    """ Normalized mutual information between true and predicted labels.

    Thin wrapper over sklearn's ``normalized_mutual_info_score`` with its
    default averaging method.
    """
    score = normalized_mutual_info_score(labels_true=label, labels_pred=pred_label)
    return score


def clustering_accuracy(label, pred_label):
    """ Accuracy of ``pred_label`` under the best one-to-one cluster/class matching.

    Builds the class-vs-cluster contingency table and solves the assignment
    problem on it; the matched counts divided by the number of samples is the
    accuracy. Adapted from https://github.com/ChongYou/subspace-clustering
    """
    true_lbl, found_lbl = _supervised.check_clusterings(label, pred_label)
    counts = _supervised.contingency_matrix(true_lbl, found_lbl)
    # linear_sum_assignment minimizes total cost, so negate to maximize matches
    rows, cols = linear_sum_assignment(-counts)
    n_correct = counts[rows, cols].sum()
    return n_correct / len(true_lbl)

def sparsity(A, zero_cutoff=1e-8):
    """ Average number of nonzeros per row.

    An entry counts as nonzero when its absolute value exceeds ``zero_cutoff``.

    Args:
        A: 2-D numpy array, torch tensor, or any scipy sparse matrix/array
            (CSR, CSC, COO, ...).
        zero_cutoff: magnitude at or below which an entry is treated as zero.

    Returns:
        float: (number of entries with |a_ij| > zero_cutoff) / A.shape[0].

    Raises:
        ValueError: if ``A`` is not one of the supported types.
    """
    if isinstance(A, np.ndarray):
        return np.sum(np.abs(A) > zero_cutoff) / A.shape[0]
    elif isinstance(A, torch.Tensor):
        return torch.sum(torch.abs(A) > zero_cutoff).item() / A.shape[0]
    elif scipy.sparse.issparse(A):
        # Accept every sparse format, not just csr_matrix. Explicitly stored
        # entries with magnitude <= zero_cutoff compare False and are not counted.
        return (abs(A) > zero_cutoff).sum() / A.shape[0]
    else:
        raise ValueError('invalid type')
    

def basic_metrics(A, label, verbose=True):
    """ Summary statistics of affinity matrix ``A`` relative to ground-truth ``label``.

    Returns the tuple ``(nnz, fd_error, components, wrong_edge)``:
        nnz        -- average nonzeros per row (``sparsity``),
        fd_error   -- ``feature_detection`` error,
        components -- number of connected components of the graph of ``A``,
        wrong_edge -- ``percent_wrong_edge`` of ``A``.
    When ``verbose`` is true the values are also printed.
    """
    # tuple elements are evaluated left to right, in the same order as before
    stats = (
        sparsity(A),
        feature_detection(A, label),
        scipy.sparse.csgraph.connected_components(A, return_labels=False),
        percent_wrong_edge(A, label),
    )
    if not verbose:
        return stats
    nnz, fd_error, components, wrong_edge = stats
    print(f"NNZ/ row: {nnz:.2f}   ||| Feat detect: {fd_error:.5f} ")
    print(f"Num comp: {components}       ||| Pct wrong edges: {wrong_edge:.2f}")
    return stats
    

def _spectral_embedding(lap, k, solver_type, tol):
    """ Eigenvectors of normalized Laplacian ``lap`` for its k smallest eigenvalues."""
    if solver_type == 'shift_invert':
        # shift-invert around a point just above 0, where the wanted eigenvalues are
        _, embedding = scipy.sparse.linalg.eigsh(lap, k=k, sigma=1e-6, which='LM', tol=tol)
    elif solver_type == 'la':
        # the largest algebraic eigenvalues of -lap are the smallest of lap
        _, embedding = scipy.sparse.linalg.eigsh(-lap, k=k,
                                                 sigma=None, which='LA', tol=tol)
    elif solver_type == 'lm':
        # normalized-Laplacian eigenvalues lie in [0, 2], so 2I - lap has
        # nonnegative spectrum and its largest-magnitude eigenvalues are the
        # smallest eigenvalues of lap
        shifted = 2 * scipy.sparse.identity(lap.shape[0]) - lap
        _, embedding = scipy.sparse.linalg.eigsh(shifted, k=k,
                                                 sigma=None, which='LM', tol=tol)
    else:
        raise ValueError('invalid solver')
    return embedding


def _row_normalize(embedding):
    """ Scale each row to unit l2 norm; all-zero rows are left as zeros."""
    norms = np.linalg.norm(embedding, axis=1, keepdims=True)
    # avoid 0/0 -> NaN rows, which would make KMeans fail
    norms[norms == 0] = 1
    return embedding / norms


def spectral_clustering_metrics(A, nclass, label, verbose=True, n_init=10, normalize_embed=True, solver_type='lm', extra_dim=0, tol=0):
    """ Spectral clustering of affinity ``A``, scored against ground truth ``label``.

    Takes the eigenvectors of the ``nclass + extra_dim`` smallest eigenvalues
    of the normalized Laplacian of ``A`` as an embedding, optionally
    row-normalizes it, then fits KMeans with ``nclass`` clusters ``n_init``
    separate times and scores every fit.

    Args:
        A: affinity matrix (dense array or scipy sparse matrix).
        nclass: number of clusters.
        label: ground-truth labels, one per row of ``A``.
        verbose: print summary statistics.
        n_init: number of separate KMeans fits to score (each fit itself uses
            KMeans' own ``n_init=10`` restarts).
        normalize_embed: scale embedding rows to unit l2 norm before KMeans.
        solver_type: 'lm' (largest magnitude of 2I - L), 'la' (largest
            algebraic of -L) or 'shift_invert' (shift-invert near 0).
        extra_dim: number of eigenvectors to use beyond ``nclass``.
        tol: ``eigsh`` tolerance (0 means machine precision). Raised to at
            least 1e-4 when ``A`` has more connected components than ``nclass``.

    Returns:
        (acc_lst, nmi_lst, fd_error, nnz): per-fit accuracies and NMI scores,
        the feature-detection error and the average nonzeros per row of ``A``.
        Results are returned even for oversegmented graphs, where they may be
        unstable.

    Raises:
        ValueError: for an unknown ``solver_type``.
    """
    lap = scipy.sparse.csgraph.laplacian(A, normed=True)
    nnz, fd_error, components, wrong_edge = basic_metrics(A, label, verbose=False)
    if components > nclass:
        print('---Oversegmented graph, setting higher eigensolver tolerance (unstable results)---')
        # More components than clusters means a repeated zero eigenvalue, which
        # hurts eigensolver convergence. Loosen tol, but never tighten a looser
        # value supplied by the caller.
        tol = max(tol, 1e-4)

    start_time = time.perf_counter()
    embedding = _spectral_embedding(lap, nclass + extra_dim, solver_type, tol)
    eigsolver_elapsed = time.perf_counter() - start_time

    if normalize_embed:
        embedding = _row_normalize(embedding)
    cluster_model = sklearn.cluster.KMeans(n_clusters=nclass, n_init=10)
    acc_lst = []
    nmi_lst = []
    for _ in range(n_init):
        cluster_model.fit(embedding)
        pred_label = cluster_model.labels_
        acc_lst.append(clustering_accuracy(label, pred_label))
        nmi_lst.append(nmi(label, pred_label))

    if verbose:
        print(f'Acc mean: {np.mean(acc_lst):.3f}   ||| stdev: {np.std(acc_lst):.4f}')
        print(f'NMI mean: {np.mean(nmi_lst):.3f}   ||| stdev: {np.std(nmi_lst):.4f}')
        print(f"NNZ/ row: {nnz:.2f}   ||| Feat detect: {fd_error:.5f} ")
        print(f"Num comp: {components}       ||| Pct wrong edges: {wrong_edge:.2f}")
        print(f'Eigensolver time (s): {eigsolver_elapsed}')

    return acc_lst, nmi_lst, fd_error, nnz

def connectivity_lst(A, label):
    """ Per-class algebraic connectivity of the affinity graph ``A``.

    For each class in ``np.unique(label)`` (ascending order), takes the
    subgraph induced by that class's nodes and records the second-smallest
    eigenvalue of its normalized Laplacian. A value of 0 means the class
    subgraph is disconnected. For a single-node class the only eigenvalue is
    recorded.

    Args:
        A: affinity matrix (dense array or scipy sparse matrix) that supports
            ``A[np.ix_(rows, cols)]`` indexing.
        label: array of class labels, one per row of ``A``.

    Returns:
        list: one connectivity value per class.
    """
    lst = []
    for class_num in np.unique(label):
        class_inds = np.where(label == class_num)[0]
        A_class = A[np.ix_(class_inds, class_inds)]
        lap_class = scipy.sparse.csgraph.laplacian(A_class, normed=True)
        if class_inds.size <= 2:
            # eigsh needs k=2 < size; for tiny classes solve the full dense problem
            # (eigsh raises on sparse input of this size)
            if scipy.sparse.issparse(lap_class):
                dense_lap = lap_class.toarray()
            else:
                dense_lap = np.asarray(lap_class)
            vals = np.linalg.eigvalsh(dense_lap)[:2]  # ascending order
        else:
            # the two largest eigenvalues of -L, negated, are the two smallest of L
            vals = -scipy.sparse.linalg.eigsh(-lap_class, k=2, which='LA')[0]
        conn_val = np.max(vals)
        lst.append(conn_val)
    return lst

def percent_wrong_edge(A, label):
    """ Percentage of nonzero entries of ``A`` that join nodes with different labels.

    Args:
        A: 2-D numpy array, scipy sparse matrix, or torch tensor; every index
            pair returned by ``nonzero()`` is treated as an edge.
        label: per-node labels, indexable by the row/column index arrays.

    Returns:
        Percentage in [0, 100]: a numpy scalar for numpy labels, a 0-d tensor
        for torch labels. If ``A`` has no nonzero entries the mean is taken
        over an empty array, giving NaN.
    """
    if isinstance(A, torch.Tensor):
        # torch's nonzero() returns a single (nnz, 2) tensor by default;
        # as_tuple=True gives the (row, col) pair that numpy/scipy return
        row, col = A.nonzero(as_tuple=True)
    else:
        row, col = A.nonzero()
    mismatched = label[row] != label[col]
    if isinstance(mismatched, torch.Tensor):
        # torch cannot average a bool tensor
        mismatched = mismatched.float()
    return mismatched.mean() * 100

def print_affinity_info(A):
    """ Print row/column-sum ranges, average nonzeros per row and the number of
    connected components of affinity matrix ``A``.

    All quantities are computed before anything is printed.
    """
    axis_sums = {'rowsum': A.sum(axis=1), 'colsum': A.sum(axis=0)}
    avg_nnz = sparsity(A)
    n_components = scipy.sparse.csgraph.connected_components(A, return_labels=False)
    for kind, sums in axis_sums.items():
        print(f'MinMax {kind}: {sums.min()} | {sums.max()}')
    print(f'Sparsity: {avg_nnz:.3f}')
    print(f'Num components: {n_components}')