import argparse
import json
import os

import numpy as np

from dataset_loader import Mydatasets
from datasets_directory.dataset_loader import MyLogisticRegression
from opt_algs import private_newton, CompareAlgs
from myutils import eps_to_zcdp


def _mean_and_stderr(runs, num_rep):
    """Return the axis-0 mean and standard error of ``runs`` as JSON-ready lists.

    ``runs`` is converted with ``np.array``; the standard error is the
    (population) standard deviation along axis 0 divided by ``sqrt(num_rep)``.
    """
    arr = np.array(runs)
    mean = np.mean(arr, axis=0).tolist()
    stderr = (np.std(arr, axis=0) / np.sqrt(num_rep)).tolist()
    return mean, stderr


def helper_fun(datasetname, pb, num_rep, privacy_type="DP"):
    """Run the private Newton method repeatedly and save averaged metrics as JSON.

    The budget ``pb["total"]`` is treated as a DP epsilon with
    ``delta = (1 / num_samples) ** 2``. It is converted with ``eps_to_zcdp``
    into an equivalent rho, which is what ``CompareAlgs`` receives in
    ``pb["total"]``. Loss, gradient norm, accuracy and wall-clock values are
    gathered over ``num_rep`` repetitions. Their axis-0 mean and
    std / sqrt(num_rep) are written to
    ``results/full-batch/newton_<dataset>_<eps>_<privacy_type>_<num_iteration>.txt``.

    Args:
        datasetname: name of a method on ``Mydatasets`` returning ``(X, y, w_opt)``.
        pb: parameter dict; ``"total"`` (epsilon) and ``"num_iteration"`` are
            read here and the whole dict is forwarded to ``CompareAlgs``. The
            caller's dict is not modified; a copy with rho in ``"total"`` is used.
        num_rep: number of repetitions to average over; must be >= 1.
        privacy_type: tag inserted into the output filename.

    Raises:
        ValueError: if ``num_rep`` is less than 1.
    """
    if num_rep < 1:
        raise ValueError("num_rep must be at least 1, got " + str(num_rep))

    datasets = Mydatasets()
    X, y, w_opt = getattr(datasets, datasetname)()
    dataset = X, y
    priv_param = pb["total"]  # epsilon as supplied; also used in the filename
    num_samples = len(y)
    delta = (1.0 / num_samples) ** 2
    # Work on a copy so the caller's dict keeps its original epsilon.
    pb = dict(pb)
    pb["total"] = eps_to_zcdp(priv_param, delta)
    print("privacy constraint is DP!" + ' equaivalent rho: ' + str(pb["total"]))

    lr = MyLogisticRegression(X, y, reg=1e-8)
    c = CompareAlgs(lr, dataset, w_opt, iters=pb["num_iteration"], pb=pb)

    # JSON metric prefix -> CompareAlgs accessor. Each accessor is assumed to
    # return {alg_name: list_of_runs} -- TODO confirm against CompareAlgs.
    collectors = {
        "loss": c.loss_vals,
        "gradnorm": c.gradnorm_vals,
        "acc": c.accuracy_vals,
        "clock_time": c.wall_clock_alg,
    }
    totals = {}
    for rep in range(num_rep):
        print(str(rep + 1) + " expriment out of " + str(num_rep))
        c.add_algo(private_newton, "private-newton")
        for metric, collect in collectors.items():
            current = collect()
            if metric not in totals:
                # Copy the lists so later extends never mutate data that may
                # still be referenced inside CompareAlgs.
                totals[metric] = {name: list(runs) for name, runs in current.items()}
            else:
                for name in totals[metric]:
                    totals[metric][name].extend(current[name])

    result = {
        'num-samples': num_samples,
        'acc-best': c.accuracy_np().tolist(),
    }
    for alg in totals["loss"]:
        result[alg] = {}
        for metric in collectors:
            avg, std = _mean_and_stderr(totals[metric][alg], num_rep)
            result[alg][metric + "_avg"] = avg
            result[alg][metric + "_std"] = std

    out_dir = "results/full-batch"
    os.makedirs(out_dir, exist_ok=True)  # open() fails if the folder is missing
    out_path = (out_dir + "/newton_" + datasetname + "_" + str(priv_param) + "_"
                + privacy_type + "_" + str(pb["num_iteration"]) + ".txt")
    with open(out_path, 'w') as f:
        json.dump(result, f)



def main():
    """Parse command-line arguments and run the private Newton experiment.

    Options:
        --datasetname  dataset loader name (required)
        --total        total privacy budget epsilon (required, float)
        --numiter      number of optimizer iterations (required, int)
        --numrep       repetitions to average over (optional, default 8, >= 1)

    Missing or malformed options produce an argparse usage error (exit code 2)
    instead of a TypeError from converting ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Run the private Newton method and save averaged metrics.")
    parser.add_argument("--datasetname", required=True,
                        help="name of the dataset to load")
    parser.add_argument("--total", type=float, required=True,
                        help="total privacy budget (epsilon)")
    parser.add_argument("--numiter", type=int, required=True,
                        help="number of iterations")
    parser.add_argument("--numrep", type=int, default=8,
                        help="number of repetitions for averaging over the "
                             "randomness (default: 8)")
    args = parser.parse_args()
    if args.numrep < 1:
        parser.error("--numrep must be at least 1")

    pb = {
        "total": args.total,  # Total privacy budget
        "num_iteration": args.numiter,
        "grad_frac": 0.7,  # forwarded to CompareAlgs via pb; meaning defined there
    }
    helper_fun(args.datasetname, pb, num_rep=args.numrep)


if __name__ == "__main__":  # only run the experiment when executed as a script
    main()
