import matlab.engine
import sys
sys.path.append('..')
import numpy as np 
import torch
from hyperbox import Hyperbox
from interval_analysis import HBoxIA
from relu_nets import ReLUNet
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from neural_nets import train
from neural_nets import data_loaders as dl
from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest
from utilities import Factory
import math
import neural_nets.data_loaders as dl 
import neural_nets.train as train
import pickle

# Every estimation method imported above, in one list.
box_methods = [
    CLEVER,
    FastLip,
    LipLP,
    LipSDP,
    NaiveUB,
    RandomLB,
    LipProblem,
    SeqLip,
]

# Methods handed to the per-point experiments below
# (`do_random_evals` / `do_data_evals`, each with a ball factory).
local_methods = [
    LipProblem,
    LipLP,
    FastLip,
    RandomLB,
    CLEVER,
]

# Methods handed to the `do_unit_hypercube_eval` experiments below.
global_methods = [
    SeqLip,
    LipSDP,
    NaiveUB,
]

# Forwarded as `num_threads` to `Experiment`.
NUM_THREADS = 2

def dim_scale(k, val, dim):
    """Rescale `val` by sqrt(dim) for the 'SeqLip' and 'LipSDP' methods.

    Args:
        k: method name used to decide whether scaling applies.
        val: value to (possibly) rescale.
        dim: dimension whose square root is the scale factor.

    Returns:
        `math.sqrt(dim) * val` when `k` is 'SeqLip' or 'LipSDP';
        otherwise `val` unchanged.
    """
    # NOTE(review): presumably converts these two methods' bounds to the norm
    # used by the other methods -- confirm against the method implementations.
    if k in ['SeqLip', 'LipSDP']:
        return math.sqrt(dim) * val
    return val



def test_random(layer_sizes, k_trials, num_random, radius):
    """Run the local and global methods on freshly constructed networks.

    For each of `k_trials` iterations a new `ReLUNet` is built (no training
    call is made here), the `local_methods` are evaluated at `num_random`
    random points and the `global_methods` over the unit hypercube.

    Args:
        layer_sizes: layer widths passed to `ReLUNet`; the last entry must be
            2 and the first is the dimension of the sampling hypercube.
        k_trials: number of networks to build and evaluate.
        num_random: forwarded as `num_random_points` to `do_random_evals`.
        radius: forwarded as `radius` to `Hyperbox.build_linf_ball`.

    Returns:
        Tuple `(local_results, global_results)`, each a list with one entry
        per trial.
    """
    # Good parameters are [16, 16, 16, 2]
    assert layer_sizes[-1] == 2

    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])
    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)
    c_vector = np.array([1.0, -1.0])

    local_results, global_results = [], []
    for _ in range(k_trials):
        # Both experiments of a trial share the same network and settings.
        shared = dict(network=ReLUNet(layer_sizes=layer_sizes),
                      c_vector=c_vector, primal_norm='linf')

        local_exp = Experiment(local_methods, verbose=True,
                               num_threads=NUM_THREADS, **shared)
        local_run = local_exp.do_random_evals(num_random_points=num_random,
                                              sample_domain=sample_domain,
                                              ball_factory=ball_factory)
        local_results.append(local_run)

        global_exp = Experiment(global_methods, **shared)
        global_run = global_exp.do_unit_hypercube_eval()
        global_results.append(global_run)

    return local_results, global_results


def test_synthetic(layer_sizes, k_trials, num_random, radius, train_params=None):
    """Train `k_trials` networks and run the local and global methods on each.

    Args:
        layer_sizes: layer widths passed to `ReLUNet`; `layer_sizes[0]` is the
            dimension of the sampling hypercube.
        k_trials: number of networks to train and evaluate.
        num_random: forwarded as `num_random_points` to `do_random_evals`.
        radius: forwarded as `radius` to `Hyperbox.build_linf_ball`.
        train_params: training configuration handed to
            `train.training_loop`. When omitted, a module-level
            `train_params` is used if one exists (the original body read
            that global name, which this file never defines).

    Returns:
        Tuple `(local_results, global_results)`, each a list with one entry
        per trial.

    Raises:
        ValueError: if no training parameters are available. Previously this
            surfaced as a NameError at the first training call.
    """
    if train_params is None:
        # Backward compatibility: honour a `train_params` global injected by
        # a caller, since that is the name the original body looked up.
        train_params = globals().get('train_params')
    if train_params is None:
        raise ValueError(
            'test_synthetic needs training parameters: pass `train_params` '
            '(e.g. a train.TrainParameters instance)')

    # NOTE(review): c_vector has two entries, so layer_sizes[-1] is presumably
    # expected to be 2 as in test_random -- confirm before enforcing it here.
    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])
    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)
    c_vector = np.array([1.0, -1.0])
    local_results = []
    global_results = []
    for _ in range(k_trials):
        net = ReLUNet(layer_sizes=layer_sizes)
        train.training_loop(net, train_params)

        local_exp = Experiment(local_methods, network=net, c_vector=c_vector,
                               primal_norm='linf', verbose=True,
                               num_threads=NUM_THREADS)
        local_results.append(local_exp.do_random_evals(
            num_random_points=num_random, sample_domain=sample_domain,
            ball_factory=ball_factory))

        global_exp = Experiment(global_methods, network=net, c_vector=c_vector,
                                primal_norm='linf')
        global_results.append(global_exp.do_unit_hypercube_eval())
    return local_results, global_results


def test_mnist_bin(layer_sizes, k_trials, num_random, radius, use_cuda=True):
    """Train binary MNIST (digits 1 vs 7) networks and evaluate the methods.

    For each trial a new `ReLUNet` is trained, the `local_methods` are
    evaluated at `num_random` validation images, and the `global_methods` are
    evaluated over the unit hypercube. Both experiments use
    `primal_norm='l1'`.

    Args:
        layer_sizes: layer widths passed to `ReLUNet`.
        k_trials: number of networks to train and evaluate.
        num_random: batch size of the validation sample checked per trial.
        radius: forwarded as `radius` to `Hyperbox.build_linf_ball`.
        use_cuda: forwarded to the training data loaders and to
            `train.training_loop`. Defaults to True (the previously
            hard-coded value); pass False to run on a machine without CUDA.

    Returns:
        Tuple `(data_results, random_results, global_results)`.
        `random_results` is always an empty list because the random-point
        evaluation is disabled (see the note in the loop); it is kept so the
        return shape stays stable.
    """
    DIGITS = [1, 7]
    # Good params here are [784, 20, 20, 2]
    data_results = []
    random_results = []
    global_results = []
    train_data = dl.load_mnist_data('train', digits=DIGITS, use_cuda=use_cuda)
    val_data = dl.load_mnist_data('val', digits=DIGITS, use_cuda=use_cuda)

    xentropyReg = train.XEntropyReg()
    l1_reg = train.LpWeightReg(scalar=1e-5, lp='l1')
    loss = train.LossFunctional(regularizers=[xentropyReg, l1_reg])
    train_params = train.TrainParameters(train_data, val_data, 10, loss_functional=loss)
    cvec = np.array([1.0, -1.0])
    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)
    for _ in range(k_trials):
        # Train the net
        net = ReLUNet(layer_sizes)
        train.training_loop(net, train_params, use_cuda=use_cuda)

        # Do the data experiments. The evaluation batch is always loaded
        # without CUDA and converted to a numpy array.
        local_exp = Experiment(local_methods, network=net, c_vector=cvec,
                               primal_norm='l1', verbose=True,
                               num_threads=NUM_THREADS)
        data_to_check = dl.load_mnist_data('val', digits=DIGITS, use_cuda=False,
                                           shuffle=True, batch_size=num_random)
        data_to_check = next(iter(data_to_check))[0].cpu().numpy()
        data_results.append(local_exp.do_data_evals(data_to_check, ball_factory))

        # NOTE: the random-point experiments are disabled. To re-enable:
        #   sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])
        #   random_results.append(
        #       local_exp.do_random_evals(num_random, sample_domain, ball_factory))

        # Do the global experiments
        global_exp = Experiment(global_methods, network=net, c_vector=cvec,
                                primal_norm='l1', verbose=True,
                                num_threads=NUM_THREADS)
        global_results.append(global_exp.do_unit_hypercube_eval())
    return data_results, random_results, global_results


def format_local_globals(result_list_list, global_result, dim):
    """Print a per-method table of time, value and error statistics.

    Args:
        result_list_list: one object per trial, each exposing a `.results`
            sequence of local results; all of them are flattened into a
            single `ResultList`.
        global_result: global results, wrapped in a `ResultList`. Entry `i`
            is paired positionally with entry `i` of `result_list_list`.
        dim: forwarded to `ResultList.get_rel_err` and used by `dim_scale`
            to rescale the 'SeqLip' / 'LipSDP' values.

    Returns:
        None. The table is written to stdout, rows ordered by ascending
        (scaled) value statistics.
    """
    # Read `.results` once per trial so a one-shot iterable also works.
    per_trial = [trial.results for trial in result_list_list]
    num_per = [len(results) for results in per_trial]

    local_rl = ResultList([res for results in per_trial for res in results])
    global_rl = ResultList(global_result)
    local_times = local_rl.average_stdevs('time')
    local_vals = local_rl.average_stdevs('value')
    global_times = global_rl.average_stdevs('time')
    global_vals = global_rl.average_stdevs('value')
    local_errs = local_rl.get_rel_err(dim)

    # Values: first two statistics per method, rescaled where dim_scale
    # applies. Global entries override local ones on a key clash.
    all_values = {}
    for val_dict in (local_vals, global_vals):
        for k, v in val_dict.items():
            all_values[k] = (dim_scale(k, v[0], dim), dim_scale(k, v[1], dim))

    # Times: same precedence as the values.
    all_times = {**local_times, **global_times}

    # Reference answers, in the same flattened order as local_rl.results.
    right_answers = [res.values('LipProblem') for res in local_rl.results]

    # Each global result is compared against every reference answer of the
    # trial it belongs to; num_per says how many answers each trial owns.
    rel_err_dict = {}
    for k in global_vals:
        ratios = []
        start = 0
        for num, result in zip(num_per, global_rl.results):
            for right_answer in right_answers[start:start + num]:
                ratios.append(dim_scale(k, result.values(k), dim) / right_answer)
            start += num
        rel_err_dict[k] = ratios

    global_errs = {}
    for k, ratios in rel_err_dict.items():
        ratio_arr = np.array(ratios)
        global_errs[k] = (ratio_arr.mean(), ratio_arr.std(), len(ratios))

    # Errors: local entries override global ones on a key clash.
    all_errs = {**global_errs, **local_errs}

    # Key-order
    key_order = sorted(all_values, key=all_values.get)
    max_len_k = max(len(name) for name in key_order + ['method'])

    columns = ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD']
    header = ('Method'.ljust(max_len_k) + ' '
              + ' '.join('|' + col.rjust(10) for col in columns))
    print(header + '\n' + '-'*len(header))

    for k in key_order:
        stats = (all_times[k][0], all_times[k][1],
                 all_values[k][0], all_values[k][1],
                 all_errs[k][0], all_errs[k][1])
        cells = [k.ljust(max_len_k)]
        cells.extend('|{:10.4f}'.format(stat) for stat in stats)
        print(' '.join(cells))
        



if __name__ == '__main__':
    # Script entry point: run each enabled experiment for each radius and
    # pickle its results to a file named after the experiment and radius.
    import traceback

    RANDOM_LAYERS = [16, 16, 16, 2]
    SYNTHETIC_LAYERS = [10, 20, 30, 20, 2]
    MNIST_BIN_LAYERS = [784, 20, 20, 20, 2]
    NUM_RANDOM = 20
    K = 20
    for rad, radstr in [(0.1, 'radius01'), (0.2, 'radius02')]:
        experiments = [
            #(test_random, RANDOM_LAYERS, 'exp_2_random_%s.pkl' % radstr),
            #(test_synthetic, SYNTHETIC_LAYERS, 'exp_2_synthetic_%s.pkl' % radstr),
            (test_mnist_bin, MNIST_BIN_LAYERS, 'l1_exp_2_MNIST_bin_%s.pkl' % radstr),
        ]
        for fxn, layers, filename in experiments:
            try:
                results = fxn(layers, K, NUM_RANDOM, rad)
            except Exception:
                # Best-effort: one failing experiment must not abort the
                # remaining ones, but the failure is reported rather than
                # hidden. Catching Exception (not a bare except) lets
                # Ctrl-C / SystemExit stop the run.
                traceback.print_exc()
                continue
            with open(filename, 'wb') as f:
                pickle.dump(results, f)

