""" Script to build a bunch of networks, train each of them, and then
	build jobs to evaluate each of these

	#### PERFORMANCE SERIES FOR 2D SYNTHETIC DATA
"""
import matlab.engine
import numpy as np
import torch 
import sys 
sys.path.append('..')
from experiment import Experiment, MethodNest, Job
from hyperbox import Hyperbox 
from relu_nets import ReLUNet
from neural_nets import data_loaders as dl
from neural_nets import train
from lipMIP import LipProblem
from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip
from other_methods import LOCAL_METHODS, GLOBAL_METHODS
from utilities import Factory, DoEvery
import os 


# Save location handed to every Job built below: the 'jobs/scheduled' 
# directory that sits next to this file's parent directory, i.e.
# <dir containing this file>/../jobs/scheduled (resolved lexically).
SCHEDULE_DIR = os.path.join(
	os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
	'jobs',
	'scheduled',
)

def main():
	""" Builds one ReLUNet per architecture listed below, trains it on a 
		2-dimensional synthetic dataset, and for every trained network 
		creates a '_LOCAL' and a '_GLOBAL' Job (save_loc=SCHEDULE_DIR) 
		and calls .write() on each of them.
	"""
	# ================================================================
	# =           Experiment-wide constants                          =
	# ================================================================
	job_tag = 'perfSynth'
	dim = 2
	ball_radius = 0.2
	seed = 420
	architectures = [[2, 20, 20, 2],
					 [2, 20, 40, 60, 80, 100, 2]]

	# Keyword arguments forwarded to every Experiment constructed below
	exp_kwargs = {'c_vector': np.array([1.0, -1.0]), 
				  'primal_norm': 'linf'}

	# NOTE(review): these two bounds are not referenced anywhere else in 
	#               this function -- kept as-is, confirm before removing
	GLOBAL_LO = np.zeros(dim)
	GLOBAL_HI = np.ones(dim)

	unit_cube = Hyperbox.build_unit_hypercube(dim)
	ball_factory = Factory(Hyperbox.build_linf_ball, radius=ball_radius)


	def prefix_for(network):
		""" Job-name prefix for a network: '<job_tag>_SMALL' when 
			network.layer_sizes has exactly 4 entries, '<job_tag>_BIG' 
			otherwise
		"""
		size_tag = 'SMALL' if len(network.layer_sizes) == 4 else 'BIG'
		return '_'.join((job_tag, size_tag))

	# ================================================================
	# =           Data Parameters Setup                              =
	# ================================================================
	synth_params = dl.RandomKParameters(num_points=400, k=10, 
										radius=0.01, dimension=dim)
	synth_data = dl.RandomDataset(synth_params, batch_size=128, 
								  random_seed=seed)
	# Only the first element of the split is kept; the second is discarded
	train_set, _ = synth_data.split_train_val(1.0)

	# Argument bundle handed to the data-based method nest
	data_arg_bundle = {
		'data_type': 'synthetic', 
		'params': synth_params,
		'loader_kwargs': {'batch_size': 100, 'random_seed': seed},
		'ball_factory': ball_factory,
	}

	# ================================================================
	# =           Training Parameter Setup                           =
	# ================================================================
	# Loss functional = cross-entropy term + l2 weight penalty
	regularizers = [train.XEntropyReg(), 
					train.LpWeightReg(scalar=1e-4, lp='l2')]
	objective = train.LossFunctional(regularizers=regularizers)

	# train_set is deliberately passed for both of the first two arguments.
	# The positional 500 is presumably the number of epochs -- TODO confirm 
	# against train.TrainParameters
	train_params = train.TrainParameters(train_set, train_set, 500, 
										 loss_functional=objective,
										 test_after_epoch=50)

	# ================================================================
	# =           Build the methodNests                              =
	# ================================================================
	# --- nest around Experiment.do_random_evals
	random_nest = MethodNest(Experiment.do_random_evals, 
							 {'sample_domain': unit_cube, 
							  'ball_factory': ball_factory,
							  'num_random_points': 100})

	# --- nest around Experiment.do_data_evals
	data_nest = MethodNest(Experiment.do_data_evals, data_arg_bundle)

	# --- nest around Experiment.do_unit_hypercube_eval (no extra args)
	cube_nest = MethodNest(Experiment.do_unit_hypercube_eval)

	# The local job gets all three nests, the global job only the cube one
	local_nests = [random_nest, data_nest, cube_nest]
	global_nests = [cube_nest]


	def build_jobs(network, **experiment_kwargs):
		""" Creates the local and global Experiment for the given network, 
			wraps each in a Job named '<prefix>_LOCAL' / '<prefix>_GLOBAL',
			and writes the local job followed by the global job
		"""
		local_exp = Experiment([FastLip, LipLP, LipProblem], 
							   network=network, **experiment_kwargs)
		global_exp = Experiment(
			[LipProblem, FastLip, LipLP, SeqLip, LipSDP, NaiveUB], 
			network=network, **experiment_kwargs)

		prefix = prefix_for(network)
		# Alternative prefix kept for reference:
		#prefix = '%s_RELUS%02d' % (NAME, network.num_relus)

		# Both Job objects are constructed before either one is written
		jobs = [
			Job(prefix + "_LOCAL", local_exp, local_nests,
				save_loc=SCHEDULE_DIR),
			Job(prefix + "_GLOBAL", global_exp, global_nests,
				save_loc=SCHEDULE_DIR),
		]
		for job in jobs:
			job.write()

	# ==============================================================
	# =           Train the networks                                =
	# ==============================================================
	for layer_sizes in architectures:
		print('-' * 30, "Starting training:", layer_sizes)
		net = ReLUNet(layer_sizes=layer_sizes)
		train.training_loop(net, train_params)
		build_jobs(net, **exp_kwargs)


# Run the full train-and-schedule pipeline only when this file is executed 
# as a script, not when it is imported.
if __name__ == '__main__':
	main()