import argparse
import os
import shutil
import subprocess
import sys
import time
from copy import deepcopy

import numpy as np
import tabulate
import torch
import torch.nn.functional as F
import torchvision

import models
import utils

# Command-line interface. The visible GPU set comes from --cuda_visible_devices
# instead of being hard-coded in the script.
parser = argparse.ArgumentParser(description='SGD/SWA training')
parser.add_argument('--cuda_visible_devices', type=str, default='0', help='cuda_visible_devices (default: GPU0)')
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')

parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')

# Exploration grid: integer steps in [-distances, +distances] along each of
# `explore_times` random directions.
parser.add_argument('--distances', type=int, default=20, metavar='N', help='explore radius (default: 20)')
parser.add_argument('--distances_scale', type=float, default=1.0, metavar='N', help='explore scale (default: 1)')
parser.add_argument('--explore_times', type=int, default=50, metavar='N', help='explore times (default: 50)')

parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')

# The help text now states the real default; it previously claimed "None".
parser.add_argument('--resume', type=str, default='./train_swa_solid2/checkpoint-340.pt', metavar='CKPT',
                    help='checkpoint to resume training from (default: ./train_swa_solid2/checkpoint-340.pt)')
# Only these two samplers are implemented by the exploration loop, so reject
# anything else at parse time instead of failing later.
parser.add_argument('--rand', type=str, default='rand', choices=['rand', 'randn'], metavar='rand',
                    help='rand or randn selection (default: rand)')

parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
# The help text now matches the actual default of 0.1 (it used to say 0.01).
parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
                    help='SWA model collection frequency/cycle length in epochs (default: 1)')

parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')

args = parser.parse_args()
# NOTE(review): presumably this has to be set before CUDA is first initialised
# in this process for it to take effect - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices

# Result grids: one row per random direction (explore time) and one column per
# integer step in [-distances, +distances]. Each metric gets its own array.
grid_shape = (args.explore_times, 2 * args.distances + 1)

sgd_train_loss_results = np.zeros(grid_shape)
sgd_train_accuracy_results = np.zeros(grid_shape)
sgd_test_loss_results = np.zeros(grid_shape)
sgd_test_accuracy_results = np.zeros(grid_shape)

swa_train_loss_results = np.zeros(grid_shape)
swa_train_accuracy_results = np.zeros(grid_shape)
swa_test_loss_results = np.zeros(grid_shape)
swa_test_accuracy_results = np.zeros(grid_shape)

# Create the output directory and record how this run was launched.
print('Preparing directory %s' % args.dir)
os.makedirs(args.dir, exist_ok=True)
with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
    f.write(' '.join(sys.argv))
    f.write('\n')

# Keep a copy of this script next to the results. shutil.copy is used instead
# of a shell `cp` string so absolute paths and paths containing spaces or shell
# metacharacters work. The copy stays best-effort: a failure is reported but
# does not abort the run.
try:
    shutil.copy(sys.argv[0], args.dir)
except OSError as err:
    print('Could not copy %s into %s: %s' % (sys.argv[0], args.dir, err), file=sys.stderr)

# Log files written during exploration; they are closed at the end of the script.
f_out = open(os.path.join(args.dir, 'output_record.txt'), 'w')
f_out0 = open(os.path.join(args.dir, 'output_record0.txt'), 'w')
f_out1 = open(os.path.join(args.dir, 'output_record1.txt'), 'w')
f_out2 = open(os.path.join(args.dir, 'output_record2.txt'), 'w')

torch.backends.cudnn.benchmark = True
# Seed both the CPU and the CUDA generators from --seed.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Look up the model configuration by name in the project-local `models` module.
print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

# Look up the dataset class by name in torchvision and build train/test sets
# with the transforms supplied by the model configuration.
print('Loading dataset %s from %s' % (args.dataset, args.data_path))
ds = getattr(torchvision.datasets, args.dataset)
path = os.path.join(args.data_path, args.dataset.lower())
train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)
loaders = {
    'train': torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    ),
    'test': torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
}

# Current torchvision datasets expose labels as `targets`; older releases used
# `train_labels`. Prefer the former and fall back to the latter so the script
# runs on both.
labels = getattr(train_set, 'targets', None)
if labels is None:
    labels = train_set.train_labels
# int() turns a 0-d tensor label (when labels are stored as a tensor) into a
# plain Python integer before it is passed on as `num_classes`.
num_classes = int(max(labels)) + 1

print('Preparing model')


def _new_network():
    """Instantiate one network from the selected model configuration."""
    return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)


# Three copies of the architecture:
#   model            - the SGD solution (weights loaded from the checkpoint),
#   model_randvector - holds the random direction in its parameters,
#   model_temp       - scratch network that receives the perturbed weights.
model = _new_network()
model_randvector = _new_network()
model_temp = _new_network()

for net in (model, model_randvector, model_temp):
    net.cuda()

if not args.swa:
    print('SGD training')
else:
    print('SWA training')
    # Separate network for the SWA weights; swa_n may be overwritten when a
    # checkpoint is loaded.
    swa_model = _new_network()
    swa_model.cuda()
    swa_n = 0


def schedule(epoch):
    """Return the learning rate to use at ``epoch``.

    The rate is piecewise in the fraction of the schedule horizon elapsed:
    ``args.lr_init`` up to 50%, a linear decay between 50% and 90%, and a
    constant final rate afterwards. With ``args.swa`` the horizon is
    ``args.swa_start`` and the final rate is ``args.swa_lr``; otherwise the
    horizon is ``args.epochs`` and the final rate is 1% of ``args.lr_init``.
    """
    if args.swa:
        horizon = args.swa_start
    else:
        horizon = args.epochs
    progress = epoch / horizon

    if args.swa:
        final_ratio = args.swa_lr / args.lr_init
    else:
        final_ratio = 0.01

    if progress <= 0.5:
        scale = 1.0
    elif progress <= 0.9:
        # Linear interpolation from 1.0 (at 50%) down to final_ratio (at 90%).
        scale = 1.0 - (1.0 - final_ratio) * (progress - 0.5) / 0.4
    else:
        scale = final_ratio
    return args.lr_init * scale


# Loss used by the evaluation calls, and an SGD optimizer over `model`.
# In this script the optimizer is never stepped: its state is loaded from the
# checkpoint and written back by save_checkpoint.
criterion = F.cross_entropy
sgd_options = {
    'lr': args.lr_init,
    'momentum': args.momentum,
    'weight_decay': args.wd,
}
optimizer = torch.optim.SGD(model.parameters(), **sgd_options)

# Restore weights and optimizer state from --resume (when given).
start_epoch = 0
if args.resume is not None:
    print('Resume training from %s' % args.resume)
    # NOTE(review): torch.load unpickles the file - only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']

    # All three networks start from the same saved weights; the direction
    # network is overwritten with random values by the exploration loop.
    for net in (model, model_temp, model_randvector):
        net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    if args.swa:
        # The SWA entries may be stored as None; in that case the freshly
        # built swa_model and swa_n = 0 are kept.
        swa_state_dict = checkpoint['swa_state_dict']
        if swa_state_dict is not None:
            swa_model.load_state_dict(swa_state_dict)
        swa_n_ckpt = checkpoint['swa_n']
        if swa_n_ckpt is not None:
            swa_n = swa_n_ckpt

# Column headers for a results table; the SWA columns sit just before 'time'.
columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc']
if args.swa:
    columns += ['swa_te_loss', 'swa_te_acc']
    swa_res = {'loss': None, 'accuracy': None}
columns.append('time')

# utils.save_checkpoint(
#     args.dir,
#     start_epoch,
#     state_dict=model.state_dict(),
#     randvector_dict=model_randvector.state_dict,
#     swa_state_dict=swa_model.state_dict() if args.swa else None,
#     swa_n=swa_n if args.swa else None,
#     optimizer=optimizer.state_dict()
# )

# state_dict entries with these substrings in their key are left untouched when
# the random direction is generated.
_SKIPPED_KEY_MARKERS = ('running_mean', 'running_var', 'num_batches_tracked')


def _perturbed_eval(base_model, step):
    """Evaluate ``base_model`` shifted by ``step`` along the current direction.

    Writes ``base + step * direction`` into the parameters of ``model_temp``
    (the direction being the parameters of ``model_randvector``), calls
    ``utils.bn_update`` on the training loader, and evaluates on both loaders.

    Returns:
        ``(train_result, test_result)`` as returned by ``utils.eval``; the
        caller reads their ``'loss'`` and ``'accuracy'`` entries.
    """
    for base_p, dir_p, tmp_p in zip(base_model.parameters(), model_randvector.parameters(),
                                    model_temp.parameters()):
        tmp_p.data = base_p.data + step * dir_p.data
    # presumably refreshes the batch-norm running statistics for the perturbed
    # weights - confirm against utils.bn_update
    utils.bn_update(loaders['train'], model_temp)
    train_result = utils.eval(loaders['train'], model_temp, criterion)
    test_result = utils.eval(loaders['test'], model_temp, criterion)
    return train_result, test_result


for explore_time in range(1, args.explore_times + 1):
    time_ep = time.time()
    row = explore_time - 1
    print('explore time is' + str(explore_time))
    print('explore time is' + str(explore_time), file=f_out)

    # --- generate a random direction in model_randvector -------------------
    direction_state = model_randvector.state_dict()
    print(direction_state, file=f_out0)
    for key, tensor in direction_state.items():
        if any(marker in key for marker in _SKIPPED_KEY_MARKERS):
            continue
        # NOTE(review): only entries whose key contains 'bn' receive random
        # values; every other entry is zeroed, so the exploration moves along
        # those entries only. Confirm this is the intended direction.
        if 'bn' not in key:
            a = torch.zeros(tensor.shape).cuda()
        elif args.rand == 'rand':
            a = torch.rand(tensor.shape).cuda()
        elif args.rand == 'randn':
            a = torch.randn(tensor.shape).cuda()
        else:
            # Previously an unknown value left `a` undefined or stale.
            raise ValueError("--rand must be 'rand' or 'randn', got %r" % (args.rand,))
        tensor.data.copy_(a)
    print(direction_state, file=f_out2)

    # --- normalise the direction to unit L2 norm over all parameters -------
    norm_sum = 0
    for param1 in model_randvector.parameters():
        norm_temp = torch.norm(param1.data, 2)
        norm_sum += norm_temp * norm_temp
    norm_sum = torch.sqrt(norm_sum)
    if norm_sum.item() == 0:
        # Dividing by zero would fill the direction with NaNs and silently
        # corrupt every result of this run.
        raise RuntimeError("random direction has zero norm: no parameter key contains 'bn', "
                           "so every component was zeroed")

    for param1 in model_randvector.parameters():
        param1.data = param1.data / norm_sum
    print(norm_sum, file=f_out)
    # Re-query state_dict(): the assignment above rebinds each parameter's
    # data, so the earlier snapshot no longer refers to the current values.
    print(model_randvector.state_dict(), file=f_out1)

    # save weight vector
    utils.save_checkpoint(
        args.dir,
        explore_time,  # vector, sgd and swa mark
        randvec_state_dict=model_randvector.state_dict(),
        optimizer=optimizer.state_dict()
    )

    # --- walk along the direction ------------------------------------------
    for distance in range(-args.distances, args.distances + 1):
        print(distance)
        print(distance, file=f_out)
        col = distance + args.distances
        # Apply --distances_scale, which used to be parsed but ignored. With
        # the default scale of 1.0 the step equals the integer distance.
        step = distance * args.distances_scale

        sgd_train_temp, sgd_test_temp = _perturbed_eval(model, step)
        sgd_train_loss_results[row, col] = sgd_train_temp['loss']
        sgd_train_accuracy_results[row, col] = sgd_train_temp['accuracy']
        sgd_test_loss_results[row, col] = sgd_test_temp['loss']
        sgd_test_accuracy_results[row, col] = sgd_test_temp['accuracy']

        # swa_model only exists when --swa was given; without this guard the
        # script failed with a NameError here.
        if args.swa:
            swa_train_temp, swa_test_temp = _perturbed_eval(swa_model, step)
            swa_train_loss_results[row, col] = swa_train_temp['loss']
            swa_train_accuracy_results[row, col] = swa_train_temp['accuracy']
            swa_test_loss_results[row, col] = swa_test_temp['loss']
            swa_test_accuracy_results[row, col] = swa_test_temp['accuracy']

    # --- persist results after every direction -----------------------------
    grids_to_save = {
        'sgd_train_loss_results': sgd_train_loss_results,
        'sgd_train_accuracy_results': sgd_train_accuracy_results,
        'sgd_test_loss_results': sgd_test_loss_results,
        'sgd_test_accuracy_results': sgd_test_accuracy_results,
    }
    if args.swa:
        # SWA grids are only written when they were actually computed.
        grids_to_save.update({
            'swa_train_loss_results': swa_train_loss_results,
            'swa_train_accuracy_results': swa_train_accuracy_results,
            'swa_test_loss_results': swa_test_loss_results,
            'swa_test_accuracy_results': swa_test_accuracy_results,
        })
    for grid_name, grid in grids_to_save.items():
        np.savetxt(os.path.join(args.dir, grid_name + '.txt'), grid)

    time_ep = time.time() - time_ep
    print('time is', time_ep)
    print('time is', time_ep, file=f_out)

# Close the log files first so their contents are flushed even if the
# notification below fails or blocks.
for log_file in (f_out, f_out0, f_out1, f_out2):
    log_file.close()

# Best-effort completion e-mail carrying the command line as its body.
# The command is passed as an argument list (no shell), so quotes or shell
# metacharacters in sys.argv can neither break the command nor be executed.
# NOTE(review): the recipient address is hard-coded; consider making it an option.
try:
    subprocess.run(
        ['mail', '-s', 'a program finished just now', '962086838@qq.com'],
        input=' '.join(sys.argv) + '\n',
        universal_newlines=True,
        check=False,
    )
except OSError as err:
    # e.g. the `mail` program is not installed; the run itself has succeeded.
    print('Could not send completion mail: %s' % err, file=sys.stderr)



