import matlab.engine
import sys
sys.path.append('..')
import numpy as np 
import torch
from hyperbox import Hyperbox
from interval_analysis import HBoxIA
from relu_nets import ReLUNet
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from neural_nets import train
from neural_nets import data_loaders as dl
from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest
from utilities import Factory
import math
import neural_nets.data_loaders as dl 
import neural_nets.train as train
import pickle
# Methods handed to every Experiment built in this script
# (used by test_random, test_synthetic and test_mnist_bin below).
box_methods = [CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, LipProblem, SeqLip]



def dim_scale(k, val, dim):
    """Return ``val`` rescaled by sqrt(dim) for SeqLip/LipSDP, unchanged otherwise.

    Args:
        k: method name (e.g. a key of a ResultList summary).
        val: the value reported for that method.
        dim: dimension used in the sqrt(dim) factor.

    NOTE(review): presumably this converts those two methods' bounds to the
    norm used by the other methods — confirm the intended norm conversion.
    """
    rescaled_methods = ('SeqLip', 'LipSDP')
    if k in rescaled_methods:
        return math.sqrt(dim) * val
    return val


def format_resultList(results, dim):
    """Print a per-method summary table for ``results``.

    Columns are: mean and stdev of ``'time'``, mean and stdev of ``'value'``
    (rescaled with the module-level ``dim_scale`` so SeqLip/LipSDP get the
    sqrt(dim) factor), and mean and stdev of ``results.get_rel_err(dim)``
    multiplied by 100. Rows are ordered by increasing rescaled mean value;
    ties keep alphabetical order.

    Args:
        results: ResultList-like object providing ``average_stdevs(attr)`` and
            ``get_rel_err(dim)``. Judging by the indexing below, each returns a
            mapping of method name -> (mean, stdev) — TODO confirm.
        dim: dimension forwarded to ``get_rel_err`` and ``dim_scale``.
    """
    times = results.average_stdevs('time')
    values = results.average_stdevs('value')
    rel_errs = results.get_rel_err(dim)
    keys = sorted(times.keys())
    max_len_k = max(len(name) for name in keys + ['method'])

    def pad(s):
        # Left-align s in the method-name column.
        return s + ' ' * (max_len_k - len(s))

    def header_pad(s):
        # Right-align s in a 10-wide column prefixed by a '|' separator.
        return '|' + ' ' * (10 - len(s)) + s

    def scaled_value(k, idx):
        # Reuse the shared helper instead of a private copy of its rule.
        return dim_scale(k, values[k][idx], dim)

    columns = ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD']
    header = pad('Method') + ' ' + ' '.join(header_pad(col) for col in columns)
    print(header + '\n' + '-' * len(header))

    # sorted() is stable and ``keys`` is already alphabetical, so methods
    # with equal scaled values remain in alphabetical order.
    key_order = sorted(keys, key=lambda k: scaled_value(k, 0))
    for k in key_order:
        cells = [times[k][0],
                 times[k][1],
                 scaled_value(k, 0),
                 scaled_value(k, 1),
                 rel_errs[k][0] * 100,
                 rel_errs[k][1] * 100]
        print(' '.join([pad(k)] + ['|{:10.4f}'.format(c) for c in cells]))


def test_random(layer_sizes, k):
    # Good parameters are [16, 16, 16, 2]
    assert layer_sizes[-1] == 2
    c_vector = np.array([1.0, -1.0])
    results = []
    for _ in range(k):
        random_net = ReLUNet(layer_sizes=layer_sizes)
        exp = Experiment(box_methods, network=random_net, c_vector=c_vector, primal_norm='l1', 
                         verbose=True, num_threads=2) 
        methods = MethodNest(Experiment.do_unit_hypercube_eval)
        results.append(methods(exp))
    return ResultList(results)


def test_synthetic(layer_sizes, k):
    # Good parameters are [10, 20, 30, 20, 2] 
    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02, 
                         dimension=layer_sizes[0])
    dataset = dl.RandomDataset(data_params, batch_size=128, 
                               random_seed=1234)
    train_set, _ = dataset.split_train_val(1.0)

    train_params = train.TrainParameters(train_set, train_set, 500,
                                         test_after_epoch=20)
    results = []
    c_vector = np.array([1.0, -1.0])
    for _ in range(k):
        net = ReLUNet(layer_sizes=layer_sizes)
        train.training_loop(net, train_params)
        exp = Experiment(box_methods, network=net, c_vector=c_vector, primal_norm='l1')
        results.append(exp.do_unit_hypercube_eval())
    return ResultList(results)

def test_mnist_bin(layer_sizes, k):
    DIGITS = [1,7]
    # Good params here are [784, 20, 20, 2]
    results = []
    train_data = dl.load_mnist_data('train', digits=DIGITS, use_cuda=True)
    val_data = dl.load_mnist_data('val', digits=DIGITS, use_cuda=True)
    xentropyReg = train.XEntropyReg()
    l1_reg = train.LpWeightReg(scalar=1e-4, lp='l1')
    loss = train.LossFunctional(regularizers=[xentropyReg, l1_reg])
    train_params = train.TrainParameters(train_data, val_data, 10, loss_functional=loss)
    cvec = np.array([1.0,-1.0])
    for i in range(k):
        net = ReLUNet(layer_sizes)
        train.training_loop(net, train_params, use_cuda=True)
        exp = Experiment(box_methods, network=net, c_vector=cvec, primal_norm='l1', verbose=True, num_threads=4) 
        results.append(MethodNest(Experiment.do_unit_hypercube_eval)(exp))
    return ResultList(results)





if __name__ == '__main__':
    import traceback

    RANDOM_LAYERS = [16, 16, 16, 2]
    SYNTHETIC_LAYERS = [10, 20, 30, 20, 2]
    MNIST_BIN_LAYERS = [784, 40, 40, 2]

    K = 20  # number of trials passed to each experiment function
    # Each entry: (experiment function, layer sizes, output pickle path).
    # The MNIST run is currently disabled; uncomment its entry to include it.
    experiments = [
        (test_random, RANDOM_LAYERS, 'l1_exp_1_random.pkl'),
        (test_synthetic, SYNTHETIC_LAYERS, 'l1_exp_1_synthetic.pkl'),
        # (test_mnist_bin, MNIST_BIN_LAYERS, 'exp_1_MNIST_bin.pkl'),
    ]
    for fxn, layers, filename in experiments:
        # Best-effort: one failing experiment should not stop the others, but
        # the failure must be reported instead of silently discarded.
        # Catching Exception (not a bare except) still lets Ctrl-C abort.
        try:
            results = fxn(layers, K)
            with open(filename, 'wb') as f:
                pickle.dump(results, f)
        except Exception:
            print('Experiment %s failed; %s may not have been written.'
                  % (fxn.__name__, filename), file=sys.stderr)
            traceback.print_exc()

