# Example: python3 thedirection.py --dir=wide_thedirection_withbn4 --dataset=CIFAR10 --data_path=data --model=WideResNet28x10 --distances=30 --distances_scale=1 --cuda_visible_devices=4 --sgdresume=./wide_solid_withbn_4/checkpoint-50.pt --swaresume=./wide_solid_withbn_4/checkpoint-50.pt
import argparse
import os
import sys
import time
import torch
import torch.nn.functional as F
import torchvision
import models
import utils
import tabulate
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib as mpl

def parameters_to_vector(parameters):
    r"""Flatten a sequence of parameter tensors into a single 1-D tensor.

    Each tensor is checked with ``_check_param_device`` so that all of them
    live on the same device, then flattened with ``view(-1)``; the pieces
    are concatenated in iteration order.

    Arguments:
        parameters (Iterable[Tensor]): the tensors to flatten, e.g.
            ``model.parameters()``.
    Returns:
        Tensor: the concatenation of every flattened tensor.
    Raises:
        TypeError: if two tensors sit on different devices.
    """
    seen_device = None
    flattened = []
    for tensor in parameters:
        seen_device = _check_param_device(tensor, seen_device)
        flattened.append(tensor.view(-1))
    return torch.cat(flattened)


def vector_to_parameters(vec, parameters):
    r"""Scatter a flat vector back into a sequence of parameter tensors.

    Consecutive slices of ``vec`` (sized by each tensor's ``numel()``) are
    reshaped to the tensor's shape and installed via ``.data``, so each
    parameter ends up sharing storage with its slice of ``vec``.

    Arguments:
        vec (Tensor): flat vector holding every parameter value in order.
        parameters (Iterable[Tensor]): tensors to overwrite, e.g.
            ``model.parameters()``.
    Raises:
        TypeError: if ``vec`` is not a tensor, or if two parameters sit on
            different devices.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))

    seen_device = None
    offset = 0
    for tensor in parameters:
        seen_device = _check_param_device(tensor, seen_device)
        end = offset + tensor.numel()
        tensor.data = vec[offset:end].view_as(tensor).data
        offset = end


def _check_param_device(param, old_param_device):
    r"""This helper function is to check if the parameters are located
    in the same device. Currently, the conversion between model parameters
    and single vector form is not supported for multiple allocations,
    e.g. parameters in different GPUs, or mixture of CPU/GPU.
    Arguments:
        param ([Tensor]): a Tensor of a parameter of a model
        old_param_device (int): the device where the first parameter of a
                                model is allocated.
    Returns:
        old_param_device (int): report device for the first time
    """

    # Meet the first parameter
    if old_param_device is None:
        old_param_device = param.get_device() if param.is_cuda else -1
    else:
        warn = False
        if param.is_cuda:  # Check if in same GPU
            warn = (param.get_device() != old_param_device)
        else:  # Check if in CPU
            warn = (old_param_device != -1)
        if warn:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
    return old_param_device

# Command-line interface. GPU selection is taken from --cuda_visible_devices.
parser = argparse.ArgumentParser(description='SGD/SWA training')
parser.add_argument('--cuda_visible_devices', type=str, default='0', help='cuda_visible_devices (default: GPU0)')
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')

parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')

# Random-ray exploration: --ray_time rays are drawn, and each is sampled at
# --distances points; point d lies at distance d * --distances_scale.
parser.add_argument('--ray_time', type=int, default=5, metavar='N',
                    help='number of random rays (default: 5)')
parser.add_argument('--distances', type=int, default=15, metavar='N',
                    help='number of points evaluated along each ray (default: 15)')
parser.add_argument('--distances_scale', type=float, default=1.0, metavar='S',
                    help='spacing between consecutive points on a ray (default: 1.0)')
# parser.add_argument('--explore_times', type=int, default=50, metavar='N', help='explore times (default: 50)')

parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')

parser.add_argument('--sgdresume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume sgd training from (default: None)')
parser.add_argument('--swaresume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume swa training from (default: None)')


# Training options, disabled because this script only evaluates checkpoints.
# parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
# parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
# parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
# parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.01)')
# parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

# parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
# parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
# parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
# parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
#                     help='SWA model collection frequency/cycle length in epochs (default: 1)')

parser.add_argument('--seed', type=int, default=0, metavar='S', help='random seed (default: 0)')

args = parser.parse_args()
# Set before any CUDA work below so the device selection takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices


print('Preparing directory %s' % args.dir)
os.makedirs(args.dir, exist_ok=True)
# Record the exact invocation alongside the results.
with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
    f.write(' '.join(sys.argv))
    f.write('\n')
# NOTE(review): nothing writes to this handle; it is closed at the end of the script.
f_out = open(os.path.join(args.dir, 'output_record.txt'), 'w')
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path))
ds = getattr(torchvision.datasets, args.dataset)
path = os.path.join(args.data_path, args.dataset.lower())
train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)
loaders = {
    'train': torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    ),
    'test': torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
}

# Current torchvision datasets expose labels as `targets`; older releases
# used `train_labels`. Support both names.
train_labels = getattr(train_set, 'targets', None)
if train_labels is None:
    train_labels = train_set.train_labels
# int() so a tensor-backed label store still yields a plain Python int.
num_classes = int(max(train_labels)) + 1

print('Preparing model')
# Four instances of the same architecture:
# - model_sgd / model_swa: receive the checkpoint weights;
# - model_randomray: holds the random direction;
# - model_temp: scratch copy evaluated at each perturbed point.
model_sgd = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model_swa = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model_randomray = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model_temp = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model_sgd.cuda()
model_swa.cuda()
model_randomray.cuda()
model_temp.cuda()


def schedule(epoch):
    """Return the learning rate for ``epoch`` under a piecewise schedule.

    The rate behaves in three phases over a horizon:
    - it stays at ``args.lr_init`` for the first 50% of the horizon;
    - it decays linearly between 50% and 90% of the horizon;
    - it is held at ``args.lr_init * final_ratio`` afterwards.

    With ``args.swa`` the horizon is ``args.swa_start`` and the final ratio
    is ``args.swa_lr / args.lr_init``. Otherwise the horizon is
    ``args.epochs`` and the final ratio is 0.01.

    NOTE(review): this script's parser does not define ``swa``,
    ``swa_start``, ``epochs``, ``lr_init`` or ``swa_lr``, so calling this
    would raise AttributeError. Nothing in this file calls it.
    """
    if args.swa:
        progress = epoch / args.swa_start
        final_ratio = args.swa_lr / args.lr_init
    else:
        progress = epoch / args.epochs
        final_ratio = 0.01

    if progress <= 0.5:
        return args.lr_init * 1.0
    if progress <= 0.9:
        decay = (1.0 - final_ratio) * (progress - 0.5) / 0.4
        return args.lr_init * (1.0 - decay)
    return args.lr_init * final_ratio

# Loss handed to utils.eval for every evaluation in this script.
criterion = F.cross_entropy

# Load the SGD weights ('state_dict') and/or the SWA weights
# ('swa_state_dict'). SGD is processed first, so when both checkpoints are
# given, start_epoch ends up holding the SWA checkpoint's epoch.
start_epoch = 0
for resume_path, target_model, state_key in (
        (args.sgdresume, model_sgd, 'state_dict'),
        (args.swaresume, model_swa, 'swa_state_dict'),
):
    if resume_path is None:
        continue
    print('Resume training from %s' % resume_path)
    checkpoint = torch.load(resume_path)
    start_epoch = checkpoint['epoch']
    target_model.load_state_dict(checkpoint[state_key])
# Flatten the two reference solutions once. detach() yields plain tensors,
# so the per-point arithmetic below is not recorded by autograd through the
# models' parameters.
vec_sgd = parameters_to_vector(model_sgd.parameters()).detach().cuda()
vec_swa = parameters_to_vector(model_swa.parameters()).detach().cuda()

# One row per random ray, one column per step along that ray.
result_shape = (args.ray_time, args.distances)
sgd_train_loss_randomray_result = np.zeros(result_shape)
sgd_train_acc_randomray_result = np.zeros(result_shape)
sgd_test_loss_randomray_result = np.zeros(result_shape)
sgd_test_acc_randomray_result = np.zeros(result_shape)
swa_train_loss_randomray_result = np.zeros(result_shape)
swa_train_acc_randomray_result = np.zeros(result_shape)
swa_test_loss_randomray_result = np.zeros(result_shape)
swa_test_acc_randomray_result = np.zeros(result_shape)

# (file stem, array) pairs. Every array is rewritten after each evaluated
# point, so partial results are already on disk if the run is interrupted.
result_files = [
    ('sgd_train_loss_randomray_result', sgd_train_loss_randomray_result),
    ('sgd_train_acc_randomray_result', sgd_train_acc_randomray_result),
    ('sgd_test_loss_randomray_result', sgd_test_loss_randomray_result),
    ('sgd_test_acc_randomray_result', sgd_test_acc_randomray_result),
    ('swa_train_loss_randomray_result', swa_train_loss_randomray_result),
    ('swa_train_acc_randomray_result', swa_train_acc_randomray_result),
    ('swa_test_loss_randomray_result', swa_test_loss_randomray_result),
    ('swa_test_acc_randomray_result', swa_test_acc_randomray_result),
]


def _eval_along_ray(center, direction, distance):
    """Evaluate model_temp at the point ``center + distance * direction``.

    The point is copied into model_temp's parameters and utils.bn_update is
    run on the training loader. The model is then evaluated on the train
    and test loaders.

    Returns:
        (train_stats, test_stats) as returned by utils.eval. The caller reads
        their 'loss' and 'accuracy' entries.
    """
    vector_to_parameters(center + distance * direction, model_temp.parameters())
    utils.bn_update(loaders['train'], model_temp)
    train_stats = utils.eval(loaders['train'], model_temp, criterion)
    test_stats = utils.eval(loaders['test'], model_temp, criterion)
    return train_stats, test_stats


# Every random direction has the shape of the flattened parameter vector.
ray_shape = parameters_to_vector(model_randomray.parameters()).shape

for i in range(args.ray_time):
    print('i=', i)
    # Gaussian direction normalised to unit L2 norm over the whole
    # parameter vector; SGD and SWA are probed along the same ray.
    vec_random = torch.randn(ray_shape).cuda()
    vec_random = vec_random / torch.norm(vec_random)
    # Also loaded into model_randomray (not read after this point).
    vector_to_parameters(vec_random, model_randomray.parameters())
    for d in range(args.distances):
        time_start = time.time()
        # d == 0 evaluates the unperturbed solution itself.
        distance = d * args.distances_scale
        sgd_train_temp, sgd_test_temp = _eval_along_ray(vec_sgd, vec_random, distance)
        swa_train_temp, swa_test_temp = _eval_along_ray(vec_swa, vec_random, distance)

        sgd_train_loss_randomray_result[i, d] = sgd_train_temp['loss']
        sgd_train_acc_randomray_result[i, d] = sgd_train_temp['accuracy']
        sgd_test_loss_randomray_result[i, d] = sgd_test_temp['loss']
        sgd_test_acc_randomray_result[i, d] = sgd_test_temp['accuracy']
        swa_train_loss_randomray_result[i, d] = swa_train_temp['loss']
        swa_train_acc_randomray_result[i, d] = swa_train_temp['accuracy']
        swa_test_loss_randomray_result[i, d] = swa_test_temp['loss']
        swa_test_acc_randomray_result[i, d] = swa_test_temp['accuracy']

        for stem, values in result_files:
            np.savetxt(os.path.join(args.dir, stem + '.txt'), values)
        print(i, d, time.time() - time_start)
    # plt.cla()
    # plt.plot(sgd_train_loss_randomray_result.T, c='b')
    # plt.plot(swa_train_loss_randomray_result.T, c='r')
    # plt.savefig(os.path.join(args.dir, 'sgdandswa_train_loss_randomray_result.png'))
    # plt.cla()
    # plt.plot(sgd_test_loss_randomray_result.T, c='b')
    # plt.plot(swa_test_loss_randomray_result.T, c='r')
    # plt.savefig(os.path.join(args.dir, 'sgdandswa_test_loss_randomray_result.png'))


f_out.close()
