import torch
import numpy as np
import random
from alg.utils import SArunexp

#########################################################
# This code runs the meta-algorithm and SGD+test-set    #
# on *one* dataset and *one* pre-train fraction.        #
#########################################################

# To obtain the desired plots, rerun this script for every combination of
# dataset and pre-train fraction by changing the two parameters below.
DATASET_ID = 0  # Dataset number, in {0,...,99}
FRACTION_PRETRAIN = 0.1  # Data fraction to train initial hypothesis (or SGD), in [0,1)

# Number of datapoints in each dataset
N_POINTS_DATASET = 200  # in {200, 500}

# Data and model
# NOTE(review): MODEL_TYPE is not passed to SArunexp in this script; confirm
# whether anything else reads it before relying on it.
MODEL_TYPE = 'fcn'
NAME_DATA = 'sinc'

# Hyper params
TRAIN_EPOCHS = 1000
LEARNING_RATE = 0.1
MOMENTUM = 0.95
DROPOUT_PROB = 0
# Ensure no minibatches: the batch size exceeds the number of datapoints.
BATCH_SIZE = 2 * N_POINTS_DATASET

# Confidence & size of tube
DELTA = 0.035
C = .10

# Fail fast on the documented parameter ranges. Otherwise a mistyped value
# would only surface deep inside (or silently skew) a long experiment run.
if DATASET_ID not in range(100):
    raise ValueError(f'DATASET_ID must be in {{0,...,99}}, got {DATASET_ID!r}')
if not 0 <= FRACTION_PRETRAIN < 1:
    raise ValueError(f'FRACTION_PRETRAIN must be in [0, 1), got {FRACTION_PRETRAIN!r}')

# Fix device and random seed.
# torch's global RNG, Python's `random` and NumPy's global RNG all receive the
# same seed. A dedicated CPU torch.Generator `G` is seeded identically and is
# later passed to SArunexp.
DEVICE = torch.device('cpu')
random_seed = 10
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(random_seed)
del _seed_fn  # keep the module namespace free of the loop variable
G = torch.Generator(device='cpu')
G.manual_seed(random_seed)

# Jointly run the meta-algorithm (called SA) and the SGD + test-set baseline
# on the single dataset / pre-train fraction configured above.
# Arguments are passed positionally, in the exact order SArunexp expects.
_sa_args = (
    DATASET_ID, G, C, NAME_DATA, DELTA,
    LEARNING_RATE, MOMENTUM, BATCH_SIZE, TRAIN_EPOCHS, DROPOUT_PROB,
    N_POINTS_DATASET, FRACTION_PRETRAIN, DEVICE,
)
# NOTE(review): by their names these look like upper bounds (*_ub) and
# misclassification probabilities; confirm the meaning in alg.utils.
SGD_ub, sa_ub, SA_p_misclass, SGD_p_misclass = SArunexp(*_sa_args)
del _sa_args

print(SGD_ub, sa_ub, SA_p_misclass, SGD_p_misclass)
