import matlab.engine
import matlab.engine
import numpy as np
import torch
import sys
sys.path.append('..')
from experiment import Experiment, MethodNest, Job
from hyperbox import Hyperbox
from relu_nets import ReLUNet
from neural_nets import data_loaders as dl
from neural_nets import train
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from other_methods import LOCAL_METHODS, GLOBAL_METHODS, OTHER_METHODS
from utilities import Factory, DoEvery
import utilities as utils
import os
import time
import pickle 


def build_network(layers, train_params, data_seed=None):
    """Construct a ReLUNet with the given layer sizes and train it.

    Args:
        layers: layer-size specification passed straight to ``ReLUNet``.
        train_params: training configuration handed to
            ``train.training_loop``.
        data_seed: accepted for interface compatibility but currently
            unused by this function.

    Returns:
        The ``ReLUNet`` instance after ``train.training_loop`` has run on it.
    """
    net = ReLUNet(layers)
    train.training_loop(net, train_params)
    return net

def make_eij(i, j, num_classes):
    """Return the length-``num_classes`` vector with +1 at ``i`` and -1 at ``j``.

    All other entries are zero. The -1 is written last, so when ``i == j``
    that single entry ends up as -1.
    """
    vec = np.zeros(num_classes)
    # Write order matters only when i == j: the second assignment wins.
    for position, sign in ((i, 1.0), (j, -1.0)):
        vec[position] = sign
    return vec

def naive_multiclass(network, box):
    """Solve one LipProblem per (predicted class, other class) pair.

    The class ``i`` that ``network`` assigns to the center of ``box`` is
    paired with every other class ``j``; for each pair a ``LipProblem`` is
    solved with ``c_vector = e_i - e_j`` under the 'linf' primal norm.

    Args:
        network: network exposing ``layer_sizes`` and ``classify_np``.
        box: region exposing ``get_center``; passed on to ``LipProblem``.

    Returns:
        Tuple ``(timer.stop(), smallest value over all pairs)``.

    Raises:
        ValueError: from ``min`` if there is no class other than ``i``.
    """
    timer = utils.Timer()
    pair_values = {}
    n_classes = network.layer_sizes[-1]
    i = network.classify_np(box.get_center())

    other_classes = (j for j in range(n_classes) if j != i)
    for j in other_classes:
        problem = LipProblem(network, box, c_vector=make_eij(i, j, n_classes),
                             primal_norm='linf', verbose=False)
        pair_result = problem.compute_max_lipschitz()
        print((i,j), pair_result.compute_time)
        pair_values[(i, j)] = pair_result.value

    # Stop the timer before reducing, so the reported time covers the solves.
    elapsed = timer.stop()
    return (elapsed, min(pair_values.values()))


def do_data_box(network, datum, radius):
    """Run both multiclass approaches on an l_inf ball around ``datum``.

    First solves a single ``LipProblem`` configured with 'crossLipschitz',
    then runs ``naive_multiclass`` on the same box, and prints both results.

    Args:
        network: network under evaluation.
        datum: center of the ball handed to ``Hyperbox.build_linf_ball``.
        radius: radius of that ball.

    Returns:
        ``(xlip_tv, naive_tv)`` where ``xlip_tv`` is
        ``(compute_time, value)`` of the crossLipschitz solve and
        ``naive_tv`` is whatever ``naive_multiclass`` returns.
    """
    region = Hyperbox.build_linf_ball(datum, radius)

    cross_problem = LipProblem(network, region, 'crossLipschitz',
                               primal_norm='linf', verbose=True)
    cross_result = cross_problem.compute_max_lipschitz()
    xlip_tv = (cross_result.compute_time, cross_result.value)

    naive_tv = naive_multiclass(network, region)

    print('-' * 100)
    print(xlip_tv, naive_tv)
    return (xlip_tv, naive_tv)




if __name__ == '__main__':
    # Experiment configuration.
    DIMENSION = 4
    NUM_CLASSES = 20
    LAYERS = [DIMENSION, 40, 40, 40, NUM_CLASSES]
    NUM_TRIALS = 5      # networks trained per radius
    NUM_SAMPLES = 100   # data points evaluated per network

    # Synthetic dataset; the whole thing is used as the training split.
    data_params = dl.RandomKParameters(
        1337, 40, radius=0.01, dimension=DIMENSION, num_classes=NUM_CLASSES)
    dataset = dl.RandomDataset(data_params, random_seed=420)
    trainset, _ = dataset.split_train_val(1.0)
    # Inputs of the first training batch; samples are drawn from here by index.
    train_batch = trainset[0][0]

    # Loss: cross-entropy plus a weight regularizer.
    xentropy = train.XEntropyReg()
    l1_reg = train.LpWeightReg(scalar=5e-4)
    loss = train.LossFunctional(regularizers=[xentropy, l1_reg])
    # NOTE(review): the positional 1000 is presumably the epoch count -- confirm
    # against train.TrainParameters.
    train_params = train.TrainParameters(
        trainset, trainset, 1000, loss_functional=loss, test_after_epoch=20)

    radius_settings = [(0.1, 'rad01'), (0.2, 'rad02')]
    for rad, radname in radius_settings:
        results = {'crossLip': [], 'naive': []}
        for trial in range(NUM_TRIALS):
            # A fresh network is trained for every trial.
            network = build_network(LAYERS, train_params)
            for sample_idx in range(NUM_SAMPLES):
                xlip_tv, naive_tv = do_data_box(
                    network, train_batch[sample_idx], rad)
                results['crossLip'].append(xlip_tv)
                results['naive'].append(naive_tv)
        # One pickle per radius, written once all trials have finished.
        out_name = 'FIXED_experiment_multiclass_%s.pkl' % radname
        with open(out_name, 'wb') as f:
            pickle.dump(results, f)

