import numpy as np
import scipy.sparse.linalg
from sklearn.cluster import KMeans
import time


import sys
from functions_for_DDCSBM import *




class ReturnValue:
    '''Container for the results of a dynamic community detection run.

    Attributes :
        estimated_labels : per-frame estimated label vectors
        overlap : per-frame overlap (constructor argument `ov`)
        n_clusters : number of clusters used in the k-means step
        modularity : per-frame modularity of the estimated partition
        eigspace_time : time needed to compute the embedding
        kmeans_time : time needed to perform the k-means step
    '''

    def __init__(self, estimated_labels, ov, n_clusters, modularity, eigspace_time, kmeans_time):
        # Single ordered assignment: targets are bound left to right, so the
        # attributes are created in the order listed here.
        (
            self.estimated_labels,
            self.overlap,
            self.n_clusters,
            self.modularity,
            self.eigspace_time,
            self.kmeans_time,
        ) = (
            estimated_labels,
            ov,
            n_clusters,
            modularity,
            eigspace_time,
            kmeans_time,
        )

def dynamic_community_detection(AT, eta, *args, **kwargs):

    '''Function to perform community detection on a graph with n nodes and k communities on T frames according to Algorithm 1
    Use :
        cluster = dynamic_community_detection(AT, eta)

    Output :
        cluster.estimated_labels (set of arrays) : estimated_labels[t] is the vector containing the estimated labels of G_t
        cluster.overlap (array) : overlap[t] is the overlap at time t, if the exact solution is known (zeros otherwise)
        cluster.n_clusters (scalar) : number of clusters used in the k-means step
        cluster.modularity (array) : modularity[t] gives the modularity of the partition at time t
        cluster.eigspace_time (scalar) : time needed to compute the embedding X
        cluster.kmeans_time (scalar) : time needed to perform the k-means step

        If informative_eigs reports at most one informative eigenvector, no clustering
        is performed: the labels, overlap and modularity are left at zero and
        n_clusters is returned as it was passed in.

    Input :
        AT (set of sparse matrices) : AT[t] is the sparse adjacency matrix of G_t (assumed square, n x n)
        eta (scalar) : persistence in the community labels

    **kwargs
        n_clusters (scalar): if None it is estimated, otherwise the known value is used. By default None
        real_classes (set of arrays) : real_classes[t] is the label vector of G_t. By default None
                                       (the legacy placeholder [[None]] is also treated as "not provided")

    '''

    real_classes = kwargs.get('real_classes', None)
    n_clusters = kwargs.get('n_clusters', None)
    # Ground truth is available unless it is missing or is the legacy [[None]] placeholder.
    has_truth = real_classes is not None and real_classes[0][0] is not None

    start_time = time.time()
    T = len(AT)  # number of time frames
    # Read the size from the shape: no row indexing and no dense conversion needed.
    n = AT[0].shape[0]  # number of nodes
    alpha_c = find_transition(T, eta)  # critical value of the transition
    ones = np.ones(n)  # reused for every frame's degree vector
    c_v = [A.sum() / n for A in AT]
    c = np.mean(c_v)  # average degree
    c2_v = [np.mean(A.dot(ones) ** 2) for A in AT]  # per-frame mean squared degree
    phi = np.mean(c2_v) / c ** 2  # phi

    lambda_d = alpha_c / np.sqrt(c * phi)
    H = BH_matrix(AT, n, lambda_d, eta)  # build the Bethe-Hessian matrix
    info, X = informative_eigs(H)
    # Scale each of the first n*T rows of the embedding to unit Euclidean norm
    # (assumes X is 2-D, as the k-means step below requires).
    row_norms = np.sqrt(np.sum(X[:n * T] ** 2, axis=1, keepdims=True))
    X[:n * T] = X[:n * T] / row_norms
    eigspace_time = time.time() - start_time

    start_time = time.time()
    estimated_labels = [np.zeros(n) for _ in range(T)]
    ov = np.zeros(T)
    modularity = np.zeros(T)

    if info > 1:

        if n_clusters is None:
            n_clusters = estimate_k(X, info, n, AT[0])

        kmeans = KMeans(n_clusters=n_clusters)  # one estimator, refitted on every frame

        for t in range(T):
            frame = X[t * n:(t + 1) * n]  # rows of the embedding belonging to G_t
            labels = kmeans.fit(frame).predict(frame)
            estimated_labels[t] = labels
            if has_truth:
                # overlap receives its own integer copy of the labels
                ov[t] = overlap(real_classes[t], labels.astype(int))

            modularity[t] = compute_modularity(AT[t], labels)
    kmeans_time = time.time() - start_time

    return ReturnValue(estimated_labels, ov, n_clusters, modularity, eigspace_time, kmeans_time)

def dynamic_community_detection_JL(AT, eta, n_clusters, *args, **kwargs):

    '''Function to perform community detection on a graph with n nodes and k communities on T frames according to Algorithm 2
    Use :
        cluster = dynamic_community_detection_JL(AT, eta, n_clusters)

    Output :
        cluster.estimated_labels (set of arrays) : estimated_labels[t] is the vector containing the estimated labels of G_t
        cluster.overlap (array) : overlap[t] is the overlap at time t, if the exact solution is known (zeros otherwise)
        cluster.n_clusters (scalar) : number of clusters used in the k-means step
        cluster.modularity (array) : modularity[t] gives the modularity of the partition at time t
        cluster.eigspace_time (scalar) : time needed to compute the embedding X
        cluster.kmeans_time (scalar) : time needed to perform the k-means step

    Input :
        AT (set of sparse matrices) : AT[t] is the sparse adjacency matrix of G_t (assumed square, n x n)
        eta (scalar) : persistence in the community labels
        n_clusters (scalar) : number of communities k

    **kwargs
        real_classes (set of arrays) : real_classes[t] is the label vector of G_t. By default None
                                       (the legacy placeholder [[None]] is also treated as "not provided")

    '''

    real_classes = kwargs.get('real_classes', None)
    # Ground truth is available unless it is missing or is the legacy [[None]] placeholder.
    has_truth = real_classes is not None and real_classes[0][0] is not None

    start_time = time.time()
    T = len(AT)  # number of time frames
    # Read the size from the shape: no row indexing and no dense conversion needed.
    n = AT[0].shape[0]  # number of nodes
    alpha_c = find_transition(T, eta)  # critical value of the transition
    ones = np.ones(n)  # reused for every frame's degree vector
    c_v = [A.sum() / n for A in AT]
    c = np.mean(c_v)  # average degree
    c2_v = [np.mean(A.dot(ones) ** 2) for A in AT]  # per-frame mean squared degree
    phi = np.mean(c2_v) / c ** 2  # phi

    lambda_d = alpha_c / np.sqrt(c * phi)
    H = BH_matrix(AT, n, lambda_d, eta)  # build the Bethe-Hessian matrix
    X = JL_poly_approx_eigs(H)
    # Scale each of the first n*T rows of the embedding to unit Euclidean norm
    # (assumes X is 2-D, as the k-means step below requires).
    row_norms = np.sqrt(np.sum(X[:n * T] ** 2, axis=1, keepdims=True))
    X[:n * T] = X[:n * T] / row_norms
    eigspace_time = time.time() - start_time

    start_time = time.time()
    estimated_labels = [np.zeros(n) for _ in range(T)]
    ov = np.zeros(T)
    modularity = np.zeros(T)

    kmeans = KMeans(n_clusters=n_clusters)  # one estimator, refitted on every frame

    for t in range(T):
        frame = X[t * n:(t + 1) * n]  # rows of the embedding belonging to G_t
        labels = kmeans.fit(frame).predict(frame)
        estimated_labels[t] = labels
        if has_truth:
            # overlap receives its own integer copy of the labels
            ov[t] = overlap(real_classes[t], labels.astype(int))

        modularity[t] = compute_modularity(AT[t], labels)
    kmeans_time = time.time() - start_time

    return ReturnValue(estimated_labels, ov, n_clusters, modularity, eigspace_time, kmeans_time)
