import argparse
import os
import shutil
import sys
import time

import numpy as np
import tabulate
import torch
import torch.nn.functional as F
import torchvision

import models
import utils

def parameters_to_vector(parameters):
    r"""Flatten a sequence of parameter tensors into one 1-D tensor.

    Arguments:
        parameters (Iterable[Tensor]): the tensors to flatten (typically
            ``model.parameters()``), concatenated in iteration order.
    Returns:
        Tensor: a new 1-D tensor holding every element of every parameter.
    Raises:
        TypeError: if the parameters are not all on the same device
            (reported by ``_check_param_device``).
        RuntimeError: from ``torch.cat`` if ``parameters`` yields nothing.
    """
    device_seen = None
    flat_pieces = []
    for tensor in parameters:
        # Every tensor must sit on the device of the first one.
        device_seen = _check_param_device(tensor, device_seen)
        flat_pieces.append(tensor.view(-1))
    return torch.cat(flat_pieces)


def vector_to_parameters(vec, parameters):
    r"""Scatter a flat vector back into a sequence of parameter tensors.

    Each parameter's ``.data`` is replaced by a view of consecutive elements
    of ``vec`` reshaped to the parameter's shape. The parameters therefore
    share storage with ``vec`` afterwards; they are not copies.

    Arguments:
        vec (Tensor): a flat tensor, normally produced by
            ``parameters_to_vector`` for the same model.
        parameters (Iterable[Tensor]): the tensors to overwrite, consumed in
            iteration order.
    Raises:
        TypeError: if ``vec`` is not a tensor, or if the parameters are not
            all on the same device.

    Note: the length of ``vec`` is not checked. Extra trailing elements are
    ignored, and a vector that is too short makes ``view_as`` fail.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))

    device_seen = None
    offset = 0  # index in ``vec`` where the next parameter's slice begins
    for tensor in parameters:
        device_seen = _check_param_device(tensor, device_seen)
        count = tensor.numel()
        chunk = vec[offset:offset + count]
        tensor.data = chunk.view_as(tensor).data
        offset += count


def _check_param_device(param, old_param_device):
    r"""Verify that ``param`` lives on the same device as earlier parameters.

    Converting between model parameters and one flat vector only works when
    every parameter is on a single device. Mixing CPU and GPU tensors, or
    tensors on different GPUs, is rejected.

    Arguments:
        param (Tensor): the parameter being inspected.
        old_param_device (int or None): device recorded for the first
            parameter (``-1`` for CPU, otherwise the CUDA device index), or
            ``None`` when ``param`` is the first parameter seen.
    Returns:
        int: the recorded device. On the first call this is ``param``'s own
        device; afterwards it is ``old_param_device`` unchanged.
    Raises:
        TypeError: if ``param`` is on a different device than the recorded one.
    """
    # CPU tensors are encoded as -1, CUDA tensors by their device ordinal.
    this_device = param.get_device() if param.is_cuda else -1
    if old_param_device is None:
        return this_device
    if this_device != old_param_device:
        raise TypeError('Found two parameters on different devices, '
                        'this is currently not supported.')
    return old_param_device

# Command-line interface.
# NOTE(review): --epochs, --momentum, --wd and the --swa* options are only read
# by schedule() and the SGD optimizer, neither of which appears to drive the
# exploration loop. --save_freq, --eval_freq and --rand appear unused in
# this script.
parser = argparse.ArgumentParser(description='SGD/SWA training')
parser.add_argument('--cuda_visible_devices', type=str, default='0', help='cuda_visible_devices (default: GPU0)')
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')

parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')

parser.add_argument('--distances', type=int, default=20, metavar='N',
                    help='number of steps explored on each side of a model (default: 20)')
parser.add_argument('--distances_scale', type=float, default=1.0, metavar='S',
                    help='scale factor applied to each exploration step (default: 1)')
parser.add_argument('--deviate_param', type=float, default=1.0, metavar='D',
                    help='norm of the random noise added to the exploration direction (default: 1)')
parser.add_argument('--explore_times', type=int, default=50, metavar='N', help='explore times (default: 50)')

parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')

parser.add_argument('--model1_resume', type=str, default=None, metavar='CKPT',
                    help="checkpoint whose 'state_dict' is loaded as model 1 (default: None)")
parser.add_argument('--model2_resume', type=str, default=None, metavar='CKPT',
                    help="checkpoint whose 'swa_state_dict' is loaded as model 2 (default: None)")
parser.add_argument('--rand', type=str, default='rand', metavar='rand',
                    help='rand or randn selection (default: rand)')

parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
                    help='SWA model collection frequency/cycle length in epochs (default: 1)')

parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')

args = parser.parse_args()
# Device masking only takes effect before CUDA is initialised. Nothing above
# this line touches the GPU, so setting the variable here is still in time.
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices

# Result grids: one row per exploration direction (args.explore_times) and one
# column per step in [-args.distances, args.distances].
_results_shape = (args.explore_times, 2 * args.distances + 1)
(sgd_train_loss_results, sgd_train_accuracy_results,
 sgd_test_loss_results, sgd_test_accuracy_results,
 swa_train_loss_results, swa_train_accuracy_results,
 swa_test_loss_results, swa_test_accuracy_results) = (np.zeros(_results_shape) for _ in range(8))

print('Preparing directory %s' % args.dir)
os.makedirs(args.dir, exist_ok=True)
# Record the exact command line so the run can be reproduced.
with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
    f.write(' '.join(sys.argv))
    f.write('\n')
# Keep a copy of this script next to its results. shutil.copy avoids the shell.
# The previous ``cp -r ./<script> ./<dir>/`` turned an absolute --dir into a
# relative path (``./`` + ``/abs``), broke on names containing spaces, and
# passed arguments through the shell unescaped.
try:
    shutil.copy(sys.argv[0], args.dir)
except OSError as err:
    # Best effort, as with the old shell call: a failed copy must not abort the run.
    print('Warning: could not copy %s to %s: %s' % (sys.argv[0], args.dir, err))
# Text log that mirrors stdout; it is closed at the end of the script.
f_out = open(os.path.join(args.dir, 'output_record.txt'), 'w')
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path))
ds = getattr(torchvision.datasets, args.dataset)
path = os.path.join(args.data_path, args.dataset.lower())
train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the command-line batch/worker settings."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
    )


loaders = {
    'train': _make_loader(train_set, shuffle=True),
    'test': _make_loader(test_set, shuffle=False),
}

# torchvision renamed ``train_labels`` to ``targets``, and current CIFAR
# datasets no longer expose ``train_labels``. Prefer ``targets`` and fall back
# to ``train_labels`` for old releases. int() normalises tensor-valued labels.
_labels = getattr(train_set, 'targets', None)
if _labels is None:
    _labels = train_set.train_labels
num_classes = int(max(_labels)) + 1

print('Preparing model')


def _new_model():
    """Build one freshly initialised network for the selected model config."""
    return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)


# Four independent instances, constructed in this order:
# - a random-direction holder (referenced only by commented-out code),
# - a scratch model used to evaluate explored weight vectors,
# - the models filled from --model1_resume and --model2_resume.
model_randvector, model_temp, model_1, model_2 = [_new_model() for _ in range(4)]
for _net in (model_randvector, model_temp, model_1, model_2):
    _net.cuda()




def schedule(epoch):
    """Return the learning rate for ``epoch`` under a piecewise-linear decay.

    Training progress is ``epoch`` divided by ``args.swa_start`` (SWA runs) or
    ``args.epochs``. The rate stays at ``args.lr_init`` for the first half of
    training. It then decays linearly until 90% of training, and afterwards
    stays at ``args.lr_init`` times the final ratio. The final ratio is
    ``args.swa_lr / args.lr_init`` for SWA runs and 0.01 otherwise.
    """
    horizon = args.swa_start if args.swa else args.epochs
    progress = epoch / horizon
    end_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if progress <= 0.5:
        return args.lr_init * 1.0
    if progress <= 0.9:
        return args.lr_init * (1.0 - (1.0 - end_ratio) * (progress - 0.5) / 0.4)
    return args.lr_init * end_ratio


criterion = F.cross_entropy
# NOTE(review): this optimizer is never stepped in this script. It is only
# referenced by commented-out checkpointing code.
optimizer = torch.optim.SGD(
    model_1.parameters(),
    lr=args.lr_init,
    momentum=args.momentum,
    weight_decay=args.wd
)


def _resume(path, model, state_key):
    """Load ``checkpoint[state_key]`` from *path* into *model*; return the checkpoint dict."""
    print('Resume training from %s' % path)
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint[state_key])
    return checkpoint


start_epoch = 0
if args.model1_resume is not None:
    checkpoint = _resume(args.model1_resume, model_1, 'state_dict')
    start_epoch = checkpoint['epoch']
    utils.bn_update(loaders['train'], model_1)
    print(utils.eval(loaders['train'], model_1, criterion))
# Model parameters require grad. detach() keeps the flattened copies (and
# everything derived from them) from carrying an autograd graph that is never used.
vec_1 = parameters_to_vector(model_1.parameters()).detach()

if args.model2_resume is not None:
    checkpoint = _resume(args.model2_resume, model_2, 'swa_state_dict')
    start_epoch = checkpoint['epoch']
    # The scratch model also receives the checkpoint's SGD weights. Its
    # parameters are overwritten later via vector_to_parameters.
    model_temp.load_state_dict(checkpoint['state_dict'])
    utils.bn_update(loaders['train'], model_2)
    print(utils.eval(loaders['train'], model_2, criterion))
vec_2 = parameters_to_vector(model_2.parameters()).detach()

# Unit direction pointing from model 2 to model 1.
vec_inter = vec_1 - vec_2
_inter_norm = torch.norm(vec_inter)
if _inter_norm.item() == 0:
    # Normalising would divide by zero and fill every result with NaN.
    raise ValueError('model_1 and model_2 have identical parameters; '
                     'the exploration direction is undefined')
vec_inter = vec_inter / _inter_norm
print('vec inter norm is', torch.norm(vec_inter))
print('vec inter norm is', torch.norm(vec_inter), file=f_out)
f_out.flush()


def _probe(base_vec, direction, step, row, col, results):
    """Evaluate ``model_temp`` at ``base_vec + step * direction`` and store its metrics.

    Arguments:
        base_vec (Tensor): flattened weights the probe is centred on.
        direction (Tensor): flattened exploration direction.
        step (float): signed multiple of ``direction`` to move by.
        row, col (int): cell of the result arrays to fill.
        results (tuple of np.ndarray): ``(train_loss, train_accuracy,
            test_loss, test_accuracy)`` arrays, written in place.
    """
    train_loss, train_acc, test_loss, test_acc = results
    vector_to_parameters(base_vec + step * direction, model_temp.parameters())
    # Presumably refreshes BatchNorm statistics for the new weights; see utils.bn_update.
    utils.bn_update(loaders['train'], model_temp)
    train_res = utils.eval(loaders['train'], model_temp, criterion)
    train_loss[row, col] = train_res['loss']
    train_acc[row, col] = train_res['accuracy']
    test_res = utils.eval(loaders['test'], model_temp, criterion)
    test_loss[row, col] = test_res['loss']
    test_acc[row, col] = test_res['accuracy']


_SGD_RESULTS = (sgd_train_loss_results, sgd_train_accuracy_results,
                sgd_test_loss_results, sgd_test_accuracy_results)
_SWA_RESULTS = (swa_train_loss_results, swa_train_accuracy_results,
                swa_test_loss_results, swa_test_accuracy_results)
# Output file for every result array, in the order they are written.
_RESULT_FILES = (
    ('sgd_train_loss_results.txt', sgd_train_loss_results),
    ('sgd_train_accuracy_results.txt', sgd_train_accuracy_results),
    ('sgd_test_loss_results.txt', sgd_test_loss_results),
    ('sgd_test_accuracy_results.txt', sgd_test_accuracy_results),
    ('swa_train_loss_results.txt', swa_train_loss_results),
    ('swa_train_accuracy_results.txt', swa_train_accuracy_results),
    ('swa_test_loss_results.txt', swa_test_loss_results),
    ('swa_test_accuracy_results.txt', swa_test_accuracy_results),
)

# try/finally so the log file is closed even if an evaluation raises.
try:
    for explore_time in range(1, args.explore_times + 1):
        start = time.time()
        row = explore_time - 1

        # New exploration direction: the interpolation direction plus Gaussian
        # noise rescaled to norm args.deviate_param (the sum is not renormalised).
        noise = torch.randn(vec_inter.shape).to(vec_inter.device)
        vec_inter_explore = vec_inter + noise / torch.norm(noise) * args.deviate_param
        print('vec inter explore norm ', torch.norm(vec_inter_explore))

        print('explore time is' + str(explore_time))
        print('explore time is' + str(explore_time), file=f_out)

        for distance in range(-args.distances, args.distances + 1):
            print(distance)
            print(distance, file=f_out)
            step = distance * args.distances_scale
            col = distance + args.distances
            # Walk the same direction around each endpoint.
            _probe(vec_1, vec_inter_explore, step, row, col, _SGD_RESULTS)
            _probe(vec_2, vec_inter_explore, step, row, col, _SWA_RESULTS)

        # Rewrite every file after each direction so partial results survive an interrupted run.
        for file_name, values in _RESULT_FILES:
            np.savetxt(os.path.join(args.dir, file_name), values)

        time_ep = time.time() - start
        print('time is', time_ep)
        print('time is', time_ep, file=f_out)
finally:
    f_out.close()





