import matlab.engine
import sys
sys.path.append('..')
import numpy as np 
import torch
from hyperbox import Hyperbox
from interval_analysis import HBoxIA
from relu_nets import ReLUNet
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from neural_nets import train
from neural_nets import data_loaders as dl
from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest
from utilities import Factory
import math
import neural_nets.data_loaders as dl 
import neural_nets.train as train

# Methods (the other_methods baselines plus lipMIP's LipProblem) handed to
# every Experiment constructed by the driver functions in this script.
box_methods = [CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, LipProblem, SeqLip]



def dim_scale(k, val, dim):
    """Rescale a method's reported value so it can be compared with the rest.

    Values reported by the 'SeqLip' and 'LipSDP' methods are multiplied by
    ``sqrt(dim)``. Every other method's value is returned untouched.
    (NOTE(review): presumably this converts an l2-based bound to the linf
    setting used elsewhere in this script -- confirm.)

    Args:
        k: method name, as used for the keys of the result dictionaries.
        val: numeric value (mean or standard deviation) to rescale.
        dim: input dimension used for the ``sqrt(dim)`` factor.

    Returns:
        ``math.sqrt(dim) * val`` for 'SeqLip'/'LipSDP', otherwise ``val``.
    """
    needs_rescale = k in ('SeqLip', 'LipSDP')
    return math.sqrt(dim) * val if needs_rescale else val


def format_resultList(results, dim):
    """Print a per-method table of timing and value statistics.

    Parameters
    ----------
    results :
        Object with an ``average_stdevs(field)`` method (e.g. the
        ``ResultList`` returned by ``test_random``). For ``'time'`` and
        ``'value'`` it must return a dict keyed by method name whose entries
        are indexable pairs, printed as the mean and STD columns. Every key
        of the ``'time'`` dict must also be present in the ``'value'`` dict.
    dim : int
        Forwarded to the module-level ``dim_scale`` so that 'SeqLip' and
        'LipSDP' values (both mean and STD) are multiplied by ``sqrt(dim)``.

    Rows are sorted by ascending scaled mean value. Ties keep alphabetical
    order because ``sorted`` is stable.
    """
    times = results.average_stdevs('time')
    values = results.average_stdevs('value')
    keys = sorted(times)
    name_width = max(len(name) for name in keys + ['method'])

    def pad(s):
        # Left-align the method-name column.
        return s.ljust(name_width)

    def header_pad(s):
        # Right-align a 10-character numeric column header.
        return '|' + s.rjust(10)

    header = pad('Method') + ' ' + ' '.join(
        header_pad(title) for title in ['Time', 'Time STD', 'Value', 'Value STD'])
    print(header + '\n' + '-' * len(header))

    # Use the shared module-level dim_scale (no private copy) and scale each
    # value exactly once so that the sort order and printed numbers agree.
    scaled = {k: (dim_scale(k, values[k][0], dim), dim_scale(k, values[k][1], dim))
              for k in keys}
    for k in sorted(keys, key=lambda name: scaled[name][0]):
        value_mean, value_std = scaled[k]
        elements = [pad(k),
                    '|{:10.4f}'.format(times[k][0]),
                    '|{:10.4f}'.format(times[k][1]),
                    '|{:10.4f}'.format(value_mean),
                    '|{:10.4f}'.format(value_std)]
        print(' '.join(elements))


def test_random(layer_sizes, k):
    """Evaluate every method in ``box_methods`` on ``k`` untrained ReLU nets.

    Each trial builds a fresh ``ReLUNet`` (no training is performed) and runs
    ``Experiment.do_unit_hypercube_eval`` on it via ``MethodNest``, using the
    linf primal norm, verbose output and 2 threads.

    Args:
        layer_sizes: ReLUNet layer widths. The last width must be 2 to match
            the two-entry ``c_vector``.
        k: number of independent networks to evaluate.

    Returns:
        A ``ResultList`` holding one evaluation result per network.
    """
    assert layer_sizes[-1] == 2
    c_vector = np.array([1.0, -1.0])

    def run_trial():
        # Same per-trial order as before: net, experiment, wrapper, call.
        net = ReLUNet(layer_sizes=layer_sizes)
        experiment = Experiment(box_methods, network=net, c_vector=c_vector,
                                primal_norm='linf', verbose=True, num_threads=2)
        evaluate = MethodNest(Experiment.do_unit_hypercube_eval)
        return evaluate(experiment)

    return ResultList([run_trial() for _ in range(k)])


def test_synthetic(layer_sizes, k):
    """Train ``k`` nets on a synthetic RandomK dataset and evaluate each one.

    Parameters
    ----------
    layer_sizes : list of int
        ReLUNet layer widths. ``layer_sizes[0]`` is also used as the data
        dimension. The hard-coded two-entry ``c_vector`` presumably requires a
        final width of 2 (not enforced here -- TODO confirm).
    k : int
        Number of independently constructed networks to train and evaluate.
        This is unrelated to the ``k=20`` passed to ``RandomKParameters``.

    Returns
    -------
    list
        One ``do_unit_hypercube_eval()`` result per network whose evaluation
        did not raise. Note this is a plain list, not a ``ResultList``.
    """
    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02,
                                       dimension=layer_sizes[0])
    dataset = dl.RandomDataset(data_params, batch_size=128, random_seed=1234)
    # NOTE(review): split_train_val(1.0) presumably keeps every point in the
    # train split; the same set is then used for both training and testing.
    train_set, _ = dataset.split_train_val(1.0)

    train_params = train.TrainParameters(train_set, train_set, 500,
                                         test_after_epoch=20)
    results = []
    c_vector = np.array([1.0, -1.0])
    for trial in range(k):
        net = ReLUNet(layer_sizes=layer_sizes)
        train.training_loop(net, train_params)
        exp = Experiment(box_methods, network=net, c_vector=c_vector, primal_norm='linf')
        try:
            results.append(exp.do_unit_hypercube_eval())
        except Exception as err:
            # Best-effort sweep: skip a failing network, but say so, and let
            # KeyboardInterrupt/SystemExit propagate (a bare except hid them).
            print(f'test_synthetic: skipping trial {trial}: {err!r}', file=sys.stderr)
    return results

def test_mnist_bin(layer_sizes, digits, k):
    """Train ``k`` nets on a binary MNIST task and evaluate each one.

    Parameters
    ----------
    layer_sizes : list of int
        ReLUNet layer widths. The final width should be 2 to match the
        two-entry ``c_vector`` below.
    digits :
        Forwarded to ``dl.load_mnist_data`` to select the digit classes
        (presumably a pair of digits for the binary task -- confirm).
    k : int
        Number of networks to train (10 epochs each, on CUDA) and evaluate.

    Returns
    -------
    list
        One ``MethodNest(Experiment.do_unit_hypercube_eval)`` result per
        network whose evaluation did not raise.
    """
    train_data = dl.load_mnist_data('train', digits=digits, use_cuda=True)
    val_data = dl.load_mnist_data('val', digits=digits, use_cuda=True)
    xentropy_reg = train.XEntropyReg()
    # To add an L1 weight penalty, append
    # train.LpWeightReg(scalar=3e-3, lp='l1') to the regularizers list.
    loss = train.LossFunctional(regularizers=[xentropy_reg])
    train_params = train.TrainParameters(train_data, val_data, 10, loss_functional=loss)
    # Same two-output c_vector as test_random / test_synthetic (assumes a
    # 2-logit output layer).
    c_vector = np.array([1.0, -1.0])
    results = []
    for trial in range(k):
        net = ReLUNet(layer_sizes=layer_sizes)
        train.training_loop(net, train_params, use_cuda=True)
        exp = Experiment(box_methods, network=net, c_vector=c_vector,
                         primal_norm='linf', verbose=True)
        try:
            results.append(MethodNest(Experiment.do_unit_hypercube_eval)(exp))
        except Exception as err:
            # Best-effort sweep: skip a failing network, but report it and let
            # KeyboardInterrupt/SystemExit propagate.
            print(f'test_mnist_bin: skipping trial {trial}: {err!r}', file=sys.stderr)
    return results




if __name__ == '__main__':
    # Smoke run: evaluate two untrained nets and print the summary table.
    # Guarded so that importing this module does not launch an experiment.
    layer_sizes = [2, 4, 8, 2]
    outputs = test_random(layer_sizes, 2)
    # format_resultList's dim is the network's input dimension.
    format_resultList(outputs, layer_sizes[0])
