import os,sys,os.path
import torch
from torch import nn
from torch.nn.modules.linear import Linear
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision.models as models
import numpy as np
from tensorboardX import SummaryWriter
import pickle
from tqdm import tqdm
import copy
import gc
import random
import time
import argparse

from models import FMNIST_CNN, CNN_CIFAR_dropout,CNN_CIFAR100_dropout
# from option import args_parser
from utils import Accuracy,average_weights,LocalUpdate, get_gradients_fc, get_gradients
from sampling import LocalDataset, LocalDataloaders,  partition_data_various_alpha

from clustering_utils import cluster_sampling
from DivFL_utils import submod_sampling
from HiCS_utils import HiCS_sampling, magnitude_gradient

parser = argparse.ArgumentParser()

# Data-specific parameters.
# `choices` rejects unsupported values at parse time; without it an unknown
# dataset/algorithm is only discovered much later, as a NameError (no model /
# no `sampled_clients` is ever assigned for it).
parser.add_argument('--dataset', default='CIFAR10',
                    choices=['CIFAR10', 'CIFAR100', 'FMNIST'],
                    help='CIFAR10, CIFAR100, FMNIST')
# Training-specific parameters
parser.add_argument('--batch_size', type=int, default=64,
                    help='minibatch size')
parser.add_argument('--num_epochs', type=int, default=200,
                    help='number of epochs')
parser.add_argument('--num_clients', type=int, default=20,
                    help='number of local models')

parser.add_argument('--sampling_rate', type=float, default=0.1,
                    help='frac of local models to update')
parser.add_argument('--local_ep', type=int, default=2,
                    help='iterations of local updating')
# Comma-separated list, e.g. "0.1,0.5"; it is split and converted to floats
# before data partitioning.
parser.add_argument('--alphas', type=str, default="0.1",
                    help='alpha for non-iid distribution')
parser.add_argument('--T', type=float, default=0.0025,
                    help='scaling parameter T')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed for generating datasets')
parser.add_argument('--alg', default='random',
                    choices=['random', 'pow-d', 'CS', 'DivFL', 'HiCS'],
                    help='random, pow-d, CS, DivFL, HiCS')

#-----------------------------------------------------------------------------------------------------------
# Optimization

parser.add_argument('--lr', type=float, default=0.001, help='learning rate')

parser.add_argument('--momentum', type=float, default=0.9, help='Optimizer momentum value')

parser.add_argument('--weight_decay', type=float, default=1e-4, help='Optimizer weight decay value')

args = parser.parse_args()
print(args)


print(torch.__version__)
# Print numpy arrays in full (e.g. the per-client weight vector) without truncation.
np.set_printoptions(threshold=np.inf)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device.type)

# Select the global model architecture and number of classes for the dataset.
if args.dataset == 'FMNIST':
    global_model = FMNIST_CNN()
    num_class = 10
elif args.dataset == 'CIFAR10':
    global_model = CNN_CIFAR_dropout()
    num_class = 10
elif args.dataset == 'CIFAR100':
    global_model = CNN_CIFAR100_dropout()
    num_class = 100
else:
    # Without this, an unknown dataset leaves `global_model` undefined and the
    # script dies with a confusing NameError on the next line.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected CIFAR10, CIFAR100 or FMNIST")
print('# model parameters:', sum(param.numel() for param in global_model.parameters()))
global_model.to(device)
# Xavier-uniform initialisation (ReLU gain) for every conv and linear weight.
for m in global_model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

# `--alphas` is a comma-separated string; turn it into a list of floats.
alphas = [float(token) for token in args.alphas.split(',')]

# Split the training data across clients and build one loader per client.
train_dataset, testset, dict_users, _ = partition_data_various_alpha(
    n_users=args.num_clients,
    alphas=alphas,
    rand_seed=args.seed,
    dataset=args.dataset,
)
Loaders_train = LocalDataloaders(train_dataset, dict_users, args.batch_size, ShuffleorNot=True)

# Build and print a per-client label histogram by walking each client's loader once.
Counts = []
print('label distribution for each client: ')
for idx in range(args.num_clients):
    print(idx, end=": ")
    counts = [0] * num_class
    for X, y in Loaders_train[idx]:
        for label in np.array(y):
            counts[int(label)] += 1
    Counts.append(counts)
    print(counts)
    
# One LocalUpdate wrapper per client, each bound to that client's training loader.
LocalModels = [
    LocalUpdate(args, global_model, Loaders_train[idx], idx=idx, device=device)
    for idx in range(args.num_clients)
]

# training
# The test loader is not shuffled: the reported accuracy is a mean of
# per-batch accuracies, so shuffling would change which samples land in the
# final (smaller) batch and make the metric non-reproducible across runs.
loader_test = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2)

# Client weights proportional to local dataset size; they are passed to the
# random / pow-d / clustered samplers.
n_samples = np.array([len(client.trainloader.dataset) for client in LocalModels])
weights = n_samples / np.sum(n_samples)
print("Clients' weights:", weights)
K = args.num_clients

# Per-round histories: global test accuracy, mean and std of client losses.
Global_acc = []
Losses = []
Losses_std = []
global_weights = global_model.state_dict()

# Number of clients selected per round (at least one).
n_sampled = max(int(args.sampling_rate * args.num_clients), 1)
count_clients = [0] * args.num_clients  # how many times each client was selected

# Initial per-client gradient representations for the gradient-based samplers.
# Every client is represented by the global model itself here, so all entries
# start out identical (presumably zero updates; confirm in get_gradients*).
if args.alg in ('CS', 'DivFL'):
    gradients = get_gradients("clustered_2", global_model, [global_model] * K)

if args.alg == 'HiCS':
    gradients = get_gradients_fc("clustered_2", global_model, [global_model] * K)
    magnitudes = magnitude_gradient(gradients)

for epoch in tqdm(range(args.num_epochs)):
    # Snapshot of the global model before this round; the CS / DivFL / HiCS
    # branches compute client gradients against it after local training.
    previous_global_model = copy.deepcopy(global_model)
    start_time = time.time()

    print(f'\n | Global Training Round : {epoch+1} |\n')
    global_model.train()
    # Re-seed numpy each round so client selection is reproducible.
    np.random.seed(epoch)

    # ---- Client selection ------------------------------------------------
    print(args.alg)
    if args.alg == 'random':
        print("random sampling")
        sampled_clients = np.random.choice(K, size=n_sampled, replace=False, p=weights)

    elif args.alg == 'pow-d':
        print("pow-d sampling")
        if epoch == 0:
            sampled_clients = np.random.choice(K, size=n_sampled, replace=False, p=weights)
        else:
            # Draw a candidate set (2*n_sampled, capped at K because sampling
            # without replacement cannot exceed the population), then keep
            # the n_sampled candidates with the highest loss from last round.
            n_candidates = min(2 * n_sampled, K)
            power_of_choice = np.random.choice(K, size=n_candidates, replace=False, p=weights)
            sampled_clients = power_of_choice[np.argsort(np.array(local_loss)[power_of_choice])[-n_sampled:]]

    elif args.alg == 'DivFL':
        print("DivFL sampling")
        sampled_clients = submod_sampling(epoch, gradients, n_sampled, args.num_clients, stochastic=True)

    elif args.alg == 'CS':
        print("clustered sampling")
        random_pool = list(range(K))
        if epoch < int(args.num_clients/n_sampled):
            # Warm-up: visit clients in index order so their gradients get
            # filled in before clustering-based selection starts.
            sampled_clients = random_pool[epoch*n_sampled:(epoch + 1)*n_sampled]
        else:
            sampled_clients = cluster_sampling(gradients, n_sampled, weights, "cosine")

    elif args.alg == 'HiCS':
        print("HiCS sampling")
        random_pool = list(range(K))
        if epoch < int(args.num_clients/n_sampled):
            # Same index-order warm-up as CS.
            sampled_clients = random_pool[epoch*n_sampled:(epoch + 1)*n_sampled]
        else:
            # `n_samples` (per-client dataset sizes) is loop-invariant and
            # computed once before training.
            magnitudes = magnitude_gradient(gradients)
            sampled_clients = HiCS_sampling(gradients, magnitudes, args.T, "cosine", n_samples, n_sampled, args.num_epochs, epoch)

    else:
        raise ValueError(f"Unknown sampling algorithm {args.alg!r}; expected random, pow-d, CS, DivFL or HiCS")

    print("selection in epoch: ", epoch)
    print(sampled_clients)

    # ---- Local training on the selected clients --------------------------
    local_weights = []
    local_loss = []
    clients_models = []
    sampled_clients_for_grad = []
    for idx in sampled_clients:
        count_clients[idx] += 1
        LocalModels[idx].load_model(global_weights)

        w = LocalModels[idx].update_weights_prox(global_round=epoch, mu=0.1)
        local_weights.append(copy.deepcopy(w))
        clients_models.append(copy.deepcopy(LocalModels[idx].model))
        sampled_clients_for_grad.append(idx)

    # ---- Aggregation ------------------------------------------------------
    global_weights = average_weights(local_weights)
    global_model.load_state_dict(global_weights)

    # Training loss of the new global model on every client (also feeds the
    # pow-d selection in the next round).
    for idx in range(args.num_clients):
        LocalModels[idx].load_model(global_weights)
        local_loss.append(LocalModels[idx].train_loss())

    average_loss = np.sum(np.array(local_loss))/K
    variance = np.sum((np.array(local_loss) - average_loss)**2)/K
    std = np.sqrt(variance)
    Losses_std.append(std)
    Losses.append(average_loss)

    # ---- Refresh stored gradients of the clients that trained this round -
    if args.alg in ('CS', 'DivFL'):
        gradients_i = get_gradients("clustered_2", previous_global_model, clients_models)
        for idx, gradient in zip(sampled_clients_for_grad, gradients_i):
            gradients[idx] = gradient

    if args.alg == 'HiCS':
        # NOTE(review): HiCS only refreshes gradients during the warm-up
        # rounds, so `magnitudes` is recomputed from unchanged gradients
        # afterwards — confirm this is intended.
        if epoch < int(args.num_clients/n_sampled):
            gradients_i = get_gradients_fc("clustered_2", previous_global_model, clients_models)
            for idx, gradient in zip(sampled_clients_for_grad, gradients_i):
                gradients[idx] = gradient

    # ---- Global test evaluation -------------------------------------------
    accuracy = 0
    cnt = 0  # number of test batches
    global_model.eval()
    # Evaluation needs no autograd graph; building one wastes memory and time.
    with torch.no_grad():
        for X, y in loader_test:
            X = X.to(device)
            y = y.double().to(device)
            p = global_model(X)
            y_pred = p.argmax(1).double()
            accuracy += Accuracy(y, y_pred)
            cnt += 1
    print("accuracy of global test:", accuracy/cnt)
    print("average local loss:", average_loss)
    Global_acc.append(accuracy/cnt)
    end_time = time.time()
    print('training time:', end_time - start_time)
    