from fairsc import Experiment, discover_groups, low_rank_approx
from fairsc.generators import gen_sbm
from fairsc.algorithms import fair_sc, normal_sc
from fairsc.evaluations import compute_individual_balance, align_clusters
import numpy as np
import os
from shutil import rmtree
import networkx as nx
from networkx.algorithms.community import modularity


class SBMExperiment(Experiment):
    """Cluster one stochastic-block-model graph and record quality metrics.

    All experiment controls are passed as keyword arguments and kept in
    ``self.kwargs``. ``_run`` generates a graph with ``gen_sbm``, clusters it
    with either ``fair_sc`` or ``normal_sc`` and writes per-node and aggregate
    results into ``self.output`` (presumably created by the ``Experiment``
    base class -- it is not initialised here).
    """

    def __init__(self, exp_dir: str, seed: int, name: str, **kwargs):
        super().__init__(exp_dir, seed, name, **kwargs)
        # Keep our own reference to the controls; _run reads them from here.
        self.kwargs = kwargs

    def _run(self):
        """Generate the SBM graph, cluster it and populate ``self.output``.

        Required keys in ``self.kwargs``: ``use_fair_sc``,
        ``normalize_laplacian``, ``normalize_evec``, ``num_nodes``,
        ``num_clusters``, ``num_groups``, ``p_val``, ``q_val``, ``r_val``,
        ``s_val``, ``fair_in``, ``fair_out``, ``normalize_balance``.

        Optional keys (only consulted when ``use_fair_sc`` is true):
        ``fair_mat_post_process_op`` -- callable applied as
        ``op(fair_mat, **post_process_op_args)``, or ``None``/absent to skip;
        ``post_process_op_args`` -- keyword arguments for that callable
        (defaults to no arguments).

        Entries written to ``self.output``:
        * ``output[i]`` for every node ``i`` -- ``"cluster,balance,ground_truth"``;
        * ``'AvgBalance'`` and ``'NumMistakes'`` -- stringified values;
        * ``'Modularity'`` and ``'Ratio-Cut'`` -- float values.
        """
        kw = self.kwargs

        # Get control variables
        use_fair_sc = kw['use_fair_sc']
        normalize_laplacian = kw['normalize_laplacian']
        normalize_evec = kw['normalize_evec']
        num_nodes = kw['num_nodes']
        num_clusters = kw['num_clusters']
        num_groups = kw['num_groups']
        p_val = kw['p_val']
        q_val = kw['q_val']
        r_val = kw['r_val']
        s_val = kw['s_val']
        fair_in = kw['fair_in']
        fair_out = kw['fair_out']
        normalize_balance = kw['normalize_balance']

        # Get the graph
        adj_mat, fair_mat, ground_truth, _ = gen_sbm(num_nodes, num_clusters, num_groups, p_val, q_val, r_val, s_val,
                                                     fair_in, fair_out)
        # Balances are always measured against the unprocessed fairness matrix,
        # even when clustering uses a post-processed one.
        original_fair_mat = fair_mat.copy()

        # Perform post processing on fairness matrix (optional keys: a missing
        # op is treated the same as an explicit None).
        post_process_op = kw.get('fair_mat_post_process_op')
        if use_fair_sc and post_process_op is not None:
            post_process_op_args = kw.get('post_process_op_args') or {}
            fair_mat = post_process_op(fair_mat, **post_process_op_args)

        # Run clustering
        if use_fair_sc:
            clusters = fair_sc(adj_mat, fair_mat, num_clusters, normalize_laplacian, normalize_evec)
        else:
            clusters = normal_sc(adj_mat, num_clusters, normalize_laplacian, normalize_evec)

        # Compute balances and mistakes
        balances, avg_balance = compute_individual_balance(clusters, original_fair_mat, normalize_balance)
        num_mistakes, _, _ = align_clusters(ground_truth, clusters)

        # Save the output
        n = adj_mat.shape[0]
        for i in range(n):
            self.output[i] = ','.join(str(v) for v in (clusters[i], balances[i], ground_truth[i]))
        self.output['AvgBalance'] = str(avg_balance)
        self.output['NumMistakes'] = str(num_mistakes)

        # Build the graph once; it is shared by the modularity and ratio-cut
        # computations (previously it was rebuilt once per cluster).
        graph = nx.from_numpy_array(adj_mat)

        # Group node indices by cluster label, preserving first-seen order.
        members_by_label = {}
        for node, label in enumerate(clusters.tolist()):
            members_by_label.setdefault(label, []).append(node)
        partition = [set(members) for members in members_by_label.values()]

        # Compute modularity
        self.output['Modularity'] = modularity(graph, partition)

        # Compute ratio-cut: sum over clusters of cut(C, V \ C) / |C|.
        # Every community is non-empty by construction, so no zero division.
        all_nodes = set(range(n))
        ratio_cut = 0.0
        for community in partition:
            ratio_cut += nx.algorithms.cuts.cut_size(graph, community, all_nodes - community) / len(community)
        self.output['Ratio-Cut'] = ratio_cut


# Prepare configurations.
# Every run starts from ``config_common``; the entry of ``configs`` selected by
# ``use_config`` then overrides / adds keys:
#   0 - standard spectral clustering (use_fair_sc=False -> normal_sc)
#   1 - fair spectral clustering on the raw fairness matrix
#   2 - fair SC with the fairness matrix passed through discover_groups
#   3 - fair SC with the fairness matrix passed through low_rank_approx (rank 5)
config_common = {
    'use_fair_sc': True,
    'normalize_laplacian': True,
    'normalize_evec': False,
    'num_clusters': 5,
    'num_nodes': 500,
    'num_groups': 5,
    'p_val': 0.4,
    'q_val': 0.3,
    'r_val': 0.2,
    's_val': 0.1,
    'fair_in': 0.8,
    'fair_out': 0.2,
    'normalize_balance': True
}
configs = [
    {
        'use_fair_sc': False
    },
    {
        'fair_mat_post_process_op': None
    },
    {
        'fair_mat_post_process_op': discover_groups,
        'post_process_op_name': 'discover_groups',
        'post_process_op_args': {'num_groups': 5}
    },
    {
        'fair_mat_post_process_op': low_rank_approx,
        'post_process_op_name': 'low_rank_approx',
        'post_process_op_args': {'rank': 5}
    }
]
use_config = 0
n_sims = 10
# One seed per simulation, drawn from NumPy's global RNG. That RNG is not
# seeded anywhere in this script, so the seeds differ between invocations.
seeds = [np.random.randint(1000000) for _ in range(n_sims)]
base_dir = './Results'
name = 'default'

# Create base directory if necessary (exist_ok avoids a check-then-create race).
os.makedirs(base_dir, exist_ok=True)

# Start from an empty experiment directory: previous results stored under the
# same ``name`` are deleted.
exp_dir = os.path.join(base_dir, name)
if os.path.exists(exp_dir):
    rmtree(exp_dir)
os.mkdir(exp_dir)

# Combine the configuration: keys of the selected entry override the defaults.
curr_config = {**config_common, **configs[use_config]}

# Start the experiments
for sim, seed in enumerate(seeds):
    experiment = SBMExperiment(os.path.join(exp_dir, 'sim-' + str(sim)), seed, 'SBM', **curr_config)
    experiment.run()
    print('Done:', sim + 1, 'of', n_sims)
