import matlab.engine
import sys
sys.path.append('..')
import numpy as np 
import torch
from hyperbox import Hyperbox
from interval_analysis import HBoxIA
from relu_nets import ReLUNet
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from neural_nets import train
from neural_nets import data_loaders as dl
from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest
from utilities import Factory
import math
import neural_nets.data_loaders as dl 
import neural_nets.train as train
import pickle
# Method classes imported above, collected in one list.
# Not referenced by the functions in this script.
box_methods = [
    CLEVER,
    FastLip,
    LipLP,
    LipSDP,
    NaiveUB,
    RandomLB,
    LipProblem,
    SeqLip,
]



def dim_scale(k, val, dim):
    """Return ``val``, multiplied by sqrt(dim) for the SeqLip/LipSDP methods.

    Values for any other method name are returned unchanged.
    NOTE(review): presumably this brings those two methods' numbers onto the
    same scale as the other methods for comparison -- confirm.

    Args:
        k: method name.
        val: the value to (possibly) rescale.
        dim: dimension whose square root is used as the scale factor.
    """
    rescaled_methods = ['SeqLip', 'LipSDP']
    if k in rescaled_methods:
        return math.sqrt(dim) * val
    return val


def format_resultList(results, dim):
    """Print a per-method table of mean/stdev time, value and relative error.

    Rows are ordered by ascending mean value after ``dim_scale`` rescaling;
    ties keep alphabetical order of the method names. Relative errors are
    printed as percentages.

    Args:
        results: result collection exposing ``average_stdevs(attr)`` and
            ``get_rel_err(dim)``. Each is indexed below as a mapping from
            method name to a (mean, stdev) pair -- assumed, confirm against
            ``ResultList``.
        dim: dimension passed to ``get_rel_err`` and to the module-level
            ``dim_scale`` (which rescales SeqLip/LipSDP values).
    """
    times = results.average_stdevs('time')
    values = results.average_stdevs('value')
    rel_errs = results.get_rel_err(dim)
    keys = sorted(times)

    # Method column is as wide as the longest method name or the header label.
    name_width = max(len(name) for name in keys + ['Method'])

    columns = ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD']
    header = 'Method'.ljust(name_width) + ' ' + ' '.join(
        '|' + col.rjust(10) for col in columns)
    print(header + '\n' + '-' * len(header))

    # Use the shared module-level dim_scale so the rescaling rule lives in
    # exactly one place. sorted() is stable, so ties stay alphabetical.
    key_order = sorted(keys, key=lambda k: dim_scale(k, values[k][0], dim))
    for k in key_order:
        cells = [times[k][0],
                 times[k][1],
                 dim_scale(k, values[k][0], dim),
                 dim_scale(k, values[k][1], dim),
                 rel_errs[k][0] * 100,
                 rel_errs[k][1] * 100]
        print(k.ljust(name_width) + ' '
              + ' '.join('|{:10.4f}'.format(c) for c in cells))



def test_synthetic_relaxations(layer_sizes, k, mip_gaps):
    """Evaluate MIP (at several gaps) and LP methods on trained ReLU networks.

    A synthetic ``RandomDataset`` is built once. Then, for each of ``k``
    trials, a fresh ``ReLUNet`` is trained on it and evaluated with
    ``LipProblem`` once per MIP gap and once with ``LipLP``, all via
    ``do_unit_hypercube_eval()``.

    Args:
        layer_sizes: network layer widths. ``layer_sizes[0]`` is used as the
            data dimension; the last entry must be 2 because ``c_vector`` is
            hard-coded with two entries. Good parameters are
            [10, 20, 30, 20, 2].
        k: number of independent trials (networks). This is unrelated to the
            ``k=20`` passed to ``RandomKParameters`` below.
        mip_gaps: iterable of ``mip_gap`` values passed to the LipProblem
            experiments.

    Returns:
        A list with one dict per trial, mapping each MIP gap to its
        LipProblem evaluation and ``'LP'`` to the LipLP evaluation.

    Raises:
        ValueError: if ``layer_sizes[-1] != 2``.
    """
    # Fail fast: otherwise a wrong output width only surfaces after a
    # full 500-epoch training run, when c_vector is first used.
    if layer_sizes[-1] != 2:
        raise ValueError('layer_sizes must end in 2 (c_vector has 2 entries), '
                         'got {}'.format(layer_sizes[-1]))

    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02,
                                       dimension=layer_sizes[0])
    dataset = dl.RandomDataset(data_params, batch_size=128,
                               random_seed=1234)
    # Split ratio 1.0: all points go to the training set, which is also used
    # as the evaluation set below.
    train_set, _ = dataset.split_train_val(1.0)

    train_params = train.TrainParameters(train_set, train_set, 500,
                                         test_after_epoch=20)
    results = []
    c_vector = np.array([1.0, -1.0])
    for _ in range(k):
        net = ReLUNet(layer_sizes=layer_sizes)
        train.training_loop(net, train_params)
        trial_dict = {}
        for mip_gap in mip_gaps:
            exp = Experiment([LipProblem], network=net, c_vector=c_vector,
                             primal_norm='linf', verbose=True, num_threads=2,
                             mip_gap=mip_gap)
            trial_dict[mip_gap] = exp.do_unit_hypercube_eval()
        lp_exp = Experiment([LipLP], network=net, c_vector=c_vector,
                            primal_norm='linf', verbose=True, num_threads=2)
        trial_dict['LP'] = lp_exp.do_unit_hypercube_eval()

        results.append(trial_dict)
    return results


def test_random_relaxations(layer_sizes, k, mip_gaps):
    """Evaluate MIP (at several gaps) and LP methods on untrained ReLU networks.

    Each of the ``k`` trials builds a fresh, untrained ``ReLUNet`` and runs
    ``do_unit_hypercube_eval()`` with ``LipProblem`` once per entry of
    ``mip_gaps``, then once with ``LipLP``.

    Args:
        layer_sizes: network layer widths; must end in 2 (matches the
            two-entry ``c_vector``). Good parameters are [16, 16, 16, 2].
        k: number of trials (networks).
        mip_gaps: iterable of ``mip_gap`` values for the LipProblem runs.

    Returns:
        A list with one dict per trial: MIP gap -> LipProblem evaluation,
        plus ``'LP'`` -> LipLP evaluation.
    """
    assert layer_sizes[-1] == 2
    c_vector = np.array([1.0, -1.0])

    def run_one_trial():
        # One network per trial; all experiments in the trial share it.
        net = ReLUNet(layer_sizes=layer_sizes)
        trial = {
            gap: Experiment([LipProblem], network=net, c_vector=c_vector,
                            primal_norm='linf', verbose=True, num_threads=2,
                            mip_gap=gap).do_unit_hypercube_eval()
            for gap in mip_gaps
        }
        lp_eval = Experiment([LipLP], network=net, c_vector=c_vector,
                             primal_norm='linf', num_threads=2)
        trial['LP'] = lp_eval.do_unit_hypercube_eval()
        return trial

    return [run_one_trial() for _ in range(k)]



if __name__ == '__main__':
    import traceback

    RANDOM_LAYERS = [16, 16, 16, 2]
    SYNTHETIC_LAYERS = [10, 20, 30, 20, 2]
    MIP_GAPS = [1.00, 0.10, 0.01, 0.00]
    K = 100
    # Each entry: (experiment function, layer sizes, output pickle path).
    experiments = [
        (test_random_relaxations, RANDOM_LAYERS, 'exp_3_random.pkl'),
        (test_synthetic_relaxations, SYNTHETIC_LAYERS, 'exp_3_synthetic.pkl'),
    ]
    for fxn, layers, filename in experiments:
        # Best-effort: one failing experiment should not stop the others,
        # but the failure must be reported rather than silently discarded.
        # Catching Exception (not a bare except) lets Ctrl-C abort the run.
        try:
            results = fxn(layers, K, MIP_GAPS)
            with open(filename, 'wb') as f:
                pickle.dump(results, f)
        except Exception:
            print('Experiment {} ({}) failed:'.format(fxn.__name__, filename),
                  file=sys.stderr)
            traceback.print_exc()

