# -*- coding: utf-8 -*-
# Description: Demo of Bayes-regularized training (model from src.CIFAR10_model; default dataset CIFAR10)

import argparse
import torch
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

from src import datasets
from src import train_eval
from src import CIFAR10_model as my_model
import pickle

#### args
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Bayes-regularized training demo.

    Every flag, help text and default matches the original script, so the
    command-line interface is unchanged.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-HT", "--high_level", help="High Level Tasks: standard standard_dp_bn l2 spectral Bayes Bayes_dp", type=str, default='Bayes')
    parser.add_argument("-TT", "--train_type", help="Training type: Standard Bayes Spectral .", type=str, default='B')
    parser.add_argument("-re", "--resnet", help="ResNet (1) or not (0).", type=int, default=0)

    ##### dataset
    parser.add_argument("-d", "--dataset", help="Dataset for experiment.", type=str, default='CIFAR10')
    parser.add_argument("-dp", "--dataset_path", help="Path for dataset.", type=str, default='./data')
    parser.add_argument("-dir", "--output_dir", help="Path for output directory.", type=str, default='./results/')
    parser.add_argument("-b", "--batch_size", help="Batch size.", type=int, default=256)
    parser.add_argument("-tA", "--trainAll", help="Train on entire (1) or partial (0) dataset.", type=int, default=0)
    parser.add_argument("-bn", "--BatchNorm", help="Use batch normalization (1) or not (0).", type=int, default=0)
    ## only works if args.trainAll == 0
    parser.add_argument("-s", "--trainSize_seed", help="Random seed for num of training data.", type=int, default=2)
    parser.add_argument("-tS", "--trainSize", help="Num of training data.", type=int, default=50000)
    ##

    parser.add_argument("-t", "--tau", help="Coefficient for the L2-regularization", type=float, default=0.0)
    parser.add_argument("-e", "--epoch", help="Number of training epochs.", type=int, default=300)
    parser.add_argument("-r", "--lrate", help="Learning rate.", type=float, default=5e-4)
    parser.add_argument("-dpR", "--dp_rate", help="Dropout Rate. 0.0 if no dropout rate.", type=float, default=0.0)

    ##### Bayes Regularization
    parser.add_argument("-rT", "--regu_type", help="Regularization type: ALL, CONV, FC, or LAST", type=str, default='LAST')
    parser.add_argument("-p", "--rho", help="Coefficient for the trace-regularization", type=float, default=1e-4)
    parser.add_argument("-ol", "--outer_loop", help="Number of outer loops for coordinate descent.", type=int, default=30)
    parser.add_argument("-lT", "--lower_threshold", help="Lower bound of the smallest singular value.", type=float, default=1e-3)
    parser.add_argument("-uT", "--upper_threshold", help="Upper bound of the largest singular value.", type=float, default=1e3)

    return parser


def main(argv=None):
    """Parse arguments, build and train a ``BayesNet``, and print the result.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as the original
            script did.

    Returns:
        The value returned by ``train_eval.bayes_training_evaluation``. It is
        printed as-is and is presumably an accuracy (confirm against
        ``src.train_eval``).
    """
    # Compile and configure all the model parameters.
    args = build_arg_parser().parse_args(argv)

    train_loader, test_loader = datasets.get_Dataset(args)

    print('Start Bayes Training')
    my_net = my_model.BayesNet(args)

    # Hand Adam only the trainable (requires_grad) parameters; tau is passed
    # as Adam's weight_decay, i.e. the L2 coefficient from the -t flag.
    trainable_params = (p for p in my_net.parameters() if p.requires_grad)
    optimizer = optim.Adam(trainable_params, lr=args.lrate, weight_decay=args.tau)

    acc = train_eval.bayes_training_evaluation(my_net, optimizer, train_loader, test_loader, args)
    print(acc)
    return acc


# Guard the entry point so that importing this module does not parse a
# foreign sys.argv or launch a full training run.
if __name__ == "__main__":
    main()