import torch
import random
import numpy as np
from pnn.utils import runexp


#########################################################
# This script runs PAC-Bayes training (prior and posterior) on *one* dataset
# with *one* prior-training data fraction (PERC_PRIOR).
#########################################################
# --- Experiment selection -------------------------------------------------
# Index of the dataset to run on, in {0, ..., 59}.
DATASET_ID = 0
# Fraction of the data used to train the prior, in [0, 1).
PERC_PRIOR = 0.1
# Number of Monte Carlo samples. The full experiments use 10_000, which takes
# considerably longer to run.
MC_SAMPLES = 1_000

# --- Objective and training procedure --------------------------------------
# Distribution family used for the prior and the posterior.
PRIOR_DIST = 'gaussian'
# Training bound to optimize: 'invkl', 'quad', ...
OBJECTIVE = 'invkl'
# 'conditional' (conditional 0-1 loss) or 'original' (cross entropy).
TRAIN_METHOD = 'original'
# 'det' trains a plain NN with SGD; 'pb' trains a probabilistic NN with the
# PAC-Bayes objective.
PRIOR_TRAIN_OPT = 'det'

# --- Network architecture and input data -----------------------------------
MODEL_TYPE = 'fcn'
NAME_DATA = 'binarymnist'

# --- Data sizes ------------------------------------------------------------
# Size of the full dataset.
# NOTE(review): not passed to runexp in this file; confirm whether it is used.
DATASET_SIZE = 60_000
# Number of points taken from the dataset for training.
USED_DATASET_SIZE = 1_000

# --- Prior hyper-parameters ------------------------------------------------
PRIOR_EPOCHS = 100
LEARNING_RATE_PRIOR = 0.01
MOMENTUM_PRIOR = 0.95
# Presumably the scale (std) of the Gaussian prior -- confirm in runexp.
SIGMAPRIOR = 0.01

# --- Posterior-training hyper-parameters -----------------------------------
BATCH_SIZE = 250
TRAIN_EPOCHS = 100
LEARNING_RATE = 0.001
MOMENTUM = 0.9
DROPOUT_PROB = 0.2

# Device and random initialization.
DEVICE = torch.device("cpu")
RANDOM_SEED = 10

# Seed every RNG source the experiment may draw from (PyTorch CPU and all
# CUDA devices, Python's `random`, NumPy's global RNG) for reproducibility.
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# cuDNN reproducibility: request deterministic kernels AND disable the
# benchmark autotuner, which could otherwise select different algorithms
# from run to run (both settings are part of PyTorch's reproducibility recipe).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Dedicated, explicitly seeded CPU generator.
# NOTE(review): `g` is not passed to runexp or used anywhere else in this
# file, so it currently has no effect; wire it into data loading or drop it.
g = torch.Generator(device='cpu').manual_seed(RANDOM_SEED)

# Confidence parameters, passed to runexp as `delta` and `delta_test`.
# The overall confidence parameter of the final bound is their sum,
# DELTA + DELTA_TEST: the bound is meant to hold with probability at least
# 1 - (DELTA + DELTA_TEST).
DELTA, DELTA_TEST = 0.025, 0.01


# Train the PAC-Bayes prior and posterior; runexp returns the risk
# certificate and the 0-1 loss on the test set.
risk_certificate, test_loss01 = runexp(
    DATASET_ID,
    SIGMAPRIOR,
    LEARNING_RATE,
    MOMENTUM,
    prior_train_opt=PRIOR_TRAIN_OPT,
    train_method=TRAIN_METHOD,
    model_type=MODEL_TYPE,
    objective=OBJECTIVE,
    prior_dist=PRIOR_DIST,
    learning_rate_prior=LEARNING_RATE_PRIOR,
    momentum_prior=MOMENTUM_PRIOR,
    delta=DELTA,
    delta_test=DELTA_TEST,
    mc_samples=MC_SAMPLES,
    train_epochs=TRAIN_EPOCHS,
    prior_epochs=PRIOR_EPOCHS,
    batch_size=BATCH_SIZE,
    device=DEVICE,
    used_dataset_size=USED_DATASET_SIZE,
    perc_prior=PERC_PRIOR,
    verbose=False,
    verbose_test=False,
    dropout_prob=DROPOUT_PROB,
    name_data=NAME_DATA,
)

# Print the risk certificate and the test-set misclassification (0-1 loss)
# as two plain floats on a single line.
print(*(float(value) for value in (risk_certificate, test_loss01)))

