import argparse
import os
import shutil
import sys
import time

import numpy as np
import tabulate
import torch
import torch.nn.functional as F
import torchvision

import models
import utils


def parameters_to_vector(parameters):
    r"""Flatten an iterable of parameter tensors into a single 1-D tensor.

    Arguments:
        parameters (Iterable[Tensor]): the tensors of a model, visited in
            iteration order.
    Returns:
        One tensor holding every parameter's elements, concatenated in the
        order the parameters were yielded.
    Raises:
        TypeError: if two parameters live on different devices.
    """
    def _flattened():
        # Device of the first parameter; every later one must match it.
        device = None
        for tensor in parameters:
            device = _check_param_device(tensor, device)
            yield tensor.view(-1)

    return torch.cat(list(_flattened()))


def vector_to_parameters(vec, parameters):
    r"""Scatter a flat vector back into the given parameter tensors.

    Each parameter's ``.data`` is replaced by the matching slice of ``vec``,
    reshaped to the parameter's shape. Slices are consumed in iteration order.

    Arguments:
        vec (Tensor): flat vector holding the values for all parameters.
        parameters (Iterable[Tensor]): the tensors of a model to overwrite.
    Raises:
        TypeError: if ``vec`` is not a tensor, or if two parameters live on
            different devices.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))

    # Device of the first parameter; every later one must match it.
    device = None
    # Start offset of the current parameter inside ``vec``.
    start = 0
    for tensor in parameters:
        device = _check_param_device(tensor, device)

        stop = start + tensor.numel()
        # Rebind the parameter's storage to the reshaped slice of ``vec``.
        tensor.data = vec[start:stop].view_as(tensor).data
        start = stop


def _check_param_device(param, old_param_device):
    r"""Verify that ``param`` lives on the same device as earlier parameters.

    Devices are encoded as an int: the CUDA device index for GPU tensors and
    -1 for CPU tensors. Converting between parameters and a single vector is
    not supported when parameters are spread over several devices.

    Arguments:
        param (Tensor): a parameter tensor of a model.
        old_param_device (int or None): device code of the first parameter
            seen so far, or None when ``param`` is the first one.
    Returns:
        int: the device code of the first parameter (that of ``param`` itself
        when ``old_param_device`` is None).
    Raises:
        TypeError: if ``param`` is on a different device than
            ``old_param_device``.
    """
    current_device = param.get_device() if param.is_cuda else -1

    # First parameter: just report where it lives.
    if old_param_device is None:
        return current_device

    if current_device != old_param_device:
        raise TypeError('Found two parameters on different devices, '
                        'this is currently not supported.')
    return old_param_device

def moving_average_bn(net1, net2, alpha=1):
    """Blend only the batch-norm affine parameters of ``net2`` into ``net1``.

    State-dict entries whose key contains ``'bn'`` and either ``'weight'`` or
    ``'bias'`` are updated in place as ``(1 - alpha) * net1 + alpha * net2``.
    Every other entry (other layers, running statistics, counters) is
    overwritten with the value from ``net2``.

    Arguments:
        net1: module updated in place; must contain every key of ``net2``.
        net2: module providing the new values.
        alpha (float): weight given to ``net2`` in the blend (default: 1).
    """
    # Build each state dict once; its tensors share storage with the modules,
    # so in-place updates on them modify net1 directly.
    target_state = net1.state_dict()
    source_state = net2.state_dict()
    for key, source in source_state.items():
        target = target_state[key].data
        if 'bn' in key and ('weight' in key or 'bias' in key):
            target.mul_(1.0 - alpha)
            target.add_(source.data * alpha)
        else:
            target.copy_(source.data)

def moving_average_conv(net1, net2, alpha=1):
    """Blend only the convolution weights/biases of ``net2`` into ``net1``.

    State-dict entries whose key contains ``'conv'`` and either ``'weight'``
    or ``'bias'`` are updated in place as
    ``(1 - alpha) * net1 + alpha * net2``. Every other entry is overwritten
    with the value from ``net2``.

    Arguments:
        net1: module updated in place; must contain every key of ``net2``.
        net2: module providing the new values.
        alpha (float): weight given to ``net2`` in the blend (default: 1).
    """
    # Build each state dict once; its tensors share storage with the modules,
    # so in-place updates on them modify net1 directly.
    target_state = net1.state_dict()
    source_state = net2.state_dict()
    for key, source in source_state.items():
        target = target_state[key].data
        if 'conv' in key and ('weight' in key or 'bias' in key):
            target.mul_(1.0 - alpha)
            target.add_(source.data * alpha)
        else:
            target.copy_(source.data)

def moving_average_conv_l1(net1, net2, alpha=1):
    """Blend only the ``layer1`` convolution parameters of ``net2`` into ``net1``.

    State-dict entries whose key contains both ``'conv'`` and ``'layer1'``,
    and either ``'weight'`` or ``'bias'``, are updated in place as
    ``(1 - alpha) * net1 + alpha * net2``. Every other entry is overwritten
    with the value from ``net2``.

    Arguments:
        net1: module updated in place; must contain every key of ``net2``.
        net2: module providing the new values.
        alpha (float): weight given to ``net2`` in the blend (default: 1).
    """
    # Build each state dict once; its tensors share storage with the modules,
    # so in-place updates on them modify net1 directly.
    target_state = net1.state_dict()
    source_state = net2.state_dict()
    for key, source in source_state.items():
        target = target_state[key].data
        selected = 'conv' in key and 'layer1' in key
        if selected and ('weight' in key or 'bias' in key):
            target.mul_(1.0 - alpha)
            target.add_(source.data * alpha)
        else:
            target.copy_(source.data)

def moving_average_conv_l2(net1, net2, alpha=1):
    """Blend only the ``layer2`` convolution parameters of ``net2`` into ``net1``.

    State-dict entries whose key contains both ``'conv'`` and ``'layer2'``,
    and either ``'weight'`` or ``'bias'``, are updated in place as
    ``(1 - alpha) * net1 + alpha * net2``. Every other entry is overwritten
    with the value from ``net2``.

    Arguments:
        net1: module updated in place; must contain every key of ``net2``.
        net2: module providing the new values.
        alpha (float): weight given to ``net2`` in the blend (default: 1).
    """
    # Build each state dict once; its tensors share storage with the modules,
    # so in-place updates on them modify net1 directly.
    target_state = net1.state_dict()
    source_state = net2.state_dict()
    for key, source in source_state.items():
        target = target_state[key].data
        selected = 'conv' in key and 'layer2' in key
        if selected and ('weight' in key or 'bias' in key):
            target.mul_(1.0 - alpha)
            target.add_(source.data * alpha)
        else:
            target.copy_(source.data)

def conv_layer_marker(net1, alpha=1):
    """Overwrite ``net1`` with a 0/1 marker of its convolution parameters.

    Every state-dict entry whose key contains ``'conv'`` and either
    ``'weight'`` or ``'bias'`` is filled with ones; every other entry
    (including buffers such as running statistics) is filled with zeros.
    The previous contents of ``net1`` are destroyed.

    Arguments:
        net1: module overwritten in place.
        alpha: unused; kept so the signature matches the moving-average
            helpers.
    """
    # Fill in place on the tensor's own device instead of allocating a CPU
    # tensor of ones/zeros and copying it over for every entry.
    for key, tensor in net1.state_dict().items():
        is_conv_param = 'conv' in key and ('weight' in key or 'bias' in key)
        tensor.data.fill_(1 if is_conv_param else 0)

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='SGD/SWA training')
parser.add_argument('--cuda_visible_devices', type=str, default='0', help='cuda_visible_devices (default: GPU0)')
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')

# Data and model selection.
parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')

parser.add_argument('--resume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume training from (default: None)')

# Optimisation hyper-parameters.
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
# The help text previously advertised 0.01 although the default is 0.1.
parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

# Stochastic weight averaging.
parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
                    help='SWA model collection frequency/cycle length in epochs (default: 1)')

parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')

args = parser.parse_args()
# Restrict the visible GPUs before any of the CUDA calls further down.
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices


print('Preparing directory %s' % args.dir)
os.makedirs(args.dir, exist_ok=True)
# Record the exact command line used for this run.
with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
    f.write(' '.join(sys.argv))
    f.write('\n')
# Snapshot the running script next to the results. shutil.copy replaces a
# shell-built `cp` command, which broke on absolute paths and on paths
# containing spaces or shell metacharacters. The copy stays best-effort: a
# failure is reported but does not abort training.
try:
    shutil.copy(sys.argv[0], args.dir)
except OSError as exc:
    print('Could not copy %s to %s: %s' % (sys.argv[0], args.dir, exc), file=sys.stderr)
# Log files for the run; they stay open for the rest of the script.
f_out = open(os.path.join(args.dir, 'output_record.txt'), 'w')
f_temp = open(os.path.join(args.dir, 'output_record_temp.txt'), 'w')
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# The random parameter mask is drawn with np.random, so NumPy has to be
# seeded as well for --seed to make a run reproducible.
np.random.seed(args.seed)

print('Using model %s' % args.model, file=f_out)
# Model configuration object looked up by name in the project's `models` module.
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path), file=f_out)
ds = getattr(torchvision.datasets, args.dataset)
path = os.path.join(args.data_path, args.dataset.lower())
train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)
loaders = {
    'train': torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    ),
    'test': torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
}
# Newer torchvision releases expose the labels as `targets`; older ones used
# `train_labels`. Support both so the script runs on either.
train_targets = getattr(train_set, 'targets', None)
if train_targets is None:
    train_targets = train_set.train_labels
# int() normalises tensor-valued labels to a plain Python integer.
num_classes = int(max(train_targets)) + 1

print('Preparing model', file=f_out)


def _new_network():
    """Instantiate a fresh network from the selected model configuration."""
    return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)


# `model` is the network trained by SGD; the others hold partial averages.
model = _new_network()
model_avr_bn = _new_network()
model_avr_conv = _new_network()
model_avr_conv_layer1 = _new_network()
model_avr_conv_layer2 = _new_network()
model_avr_conv_rand = _new_network()

for _network in (model_avr_bn, model_avr_conv, model_avr_conv_layer1,
                 model_avr_conv_layer2, model_avr_conv_rand, model):
    _network.cuda()


if not args.swa:
    print('SGD training', file=f_out)
else:
    print('SWA training', file=f_out)
    # Full SWA average and the number of models folded into it so far.
    swa_model = _new_network()
    swa_model.cuda()
    swa_n = 0


def schedule(epoch):
    """Return the learning rate for ``epoch``.

    The rate stays at ``args.lr_init`` for the first half of the schedule,
    decays linearly until 90% of it, and is constant afterwards. With
    ``--swa`` the schedule spans ``args.swa_start`` epochs and ends at
    ``args.swa_lr``; otherwise it spans ``args.epochs`` and ends at 1% of the
    initial rate.
    """
    progress = epoch / (args.swa_start if args.swa else args.epochs)
    final_ratio = args.swa_lr / args.lr_init if args.swa else 0.01

    if progress > 0.9:
        scale = final_ratio
    elif progress > 0.5:
        # Linear interpolation from 1.0 (at 0.5) down to final_ratio (at 0.9).
        scale = 1.0 - (1.0 - final_ratio) * (progress - 0.5) / 0.4
    else:
        scale = 1.0
    return args.lr_init * scale


# Loss function and optimizer for the SGD-trained network.
criterion = F.cross_entropy
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_init,
                            momentum=args.momentum, weight_decay=args.wd)

# Optionally restore training state from a checkpoint file.
start_epoch = 0
if args.resume is not None:
    print('Resume training from %s' % args.resume, file=f_out)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if args.swa:
        # The SWA entries may be stored as None; only restore what is there.
        swa_state_dict = checkpoint['swa_state_dict']
        if swa_state_dict is not None:
            swa_model.load_state_dict(swa_state_dict)

        swa_n_ckpt = checkpoint['swa_n']
        swa_n = swa_n if swa_n_ckpt is None else swa_n_ckpt


# utils.save_checkpoint(
#     args.dir,
#     start_epoch,
#     state_dict=model.state_dict(),
#     swa_state_dict=swa_model.state_dict() if args.swa else None,
#     swa_n=swa_n if args.swa else None,
#     optimizer=optimizer.state_dict()
# )


# Per-epoch metric histories, one slot per epoch; each group holds
# (train loss, train accuracy, test loss, test accuracy).

# Network trained by plain SGD.
sgd_train_loss, sgd_train_acc, sgd_test_loss, sgd_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Full SWA average.
swa_train_loss, swa_train_acc, swa_test_loss, swa_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Average restricted to batch-norm affine parameters.
swa_bn_train_loss, swa_bn_train_acc, swa_bn_test_loss, swa_bn_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Average restricted to convolution parameters.
swa_conv_train_loss, swa_conv_train_acc, swa_conv_test_loss, swa_conv_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Average restricted to layer1 convolution parameters.
swa_conv_l1_train_loss, swa_conv_l1_train_acc, swa_conv_l1_test_loss, swa_conv_l1_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Average restricted to layer2 convolution parameters.
swa_conv_l2_train_loss, swa_conv_l2_train_acc, swa_conv_l2_test_loss, swa_conv_l2_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Average restricted to a random subset of convolution parameters.
swa_conv_rand_train_loss, swa_conv_rand_train_acc, swa_conv_rand_test_loss, swa_conv_rand_test_acc = (
    np.zeros(args.epochs) for _ in range(4))

# Build a 0/1 mask over the flattened parameter vector that selects a random
# subset of the convolution weights/biases (original note: "random select
# conv 1.6%"). Only the selected entries are averaged into
# model_avr_conv_rand during training.

# Number of selected conv parameters the mask must contain.
RAND_CONV_TARGET = 16160

# Derive the vector length from the model instead of hard-coding the
# parameter count of one particular architecture.
num_params = sum(p.numel() for p in model_avr_conv_layer2.parameters())

# Initial draw: every parameter is selected with probability 10000/num_params.
threshold = 10000 / num_params
vec_temp = torch.rand(num_params).cuda()
vec_temp_threshold = torch.ones((vec_temp.shape)).cuda() * threshold
vec_temp = (vec_temp < vec_temp_threshold).float()
print(torch.sum(vec_temp))

# model_avr_conv_layer2 is used as scratch space: it is overwritten with
# ones on conv weights/biases and zeros elsewhere, then flattened so the
# marker lines up with the flattened parameter vector.
vec_conv_marker = torch.zeros((vec_temp.shape)).cuda()
vector_to_parameters(vec_conv_marker, model_avr_conv_layer2.parameters())
conv_layer_marker(model_avr_conv_layer2)
vec_conv_marker = parameters_to_vector(model_avr_conv_layer2.parameters())

# Without this check the loop below would never terminate for a model with
# fewer conv parameters than the target.
num_conv_params = int(torch.sum(vec_conv_marker).item())
if num_conv_params < RAND_CONV_TARGET:
    raise ValueError('model has only %d conv parameters, cannot select %d of them'
                     % (num_conv_params, RAND_CONV_TARGET))

# Keep switching on random positions until enough conv parameters are
# selected. The count is tracked incrementally rather than re-summing the
# whole vector on every iteration.
num_selected = int(torch.sum(vec_conv_marker * vec_temp).item())
while num_selected < RAND_CONV_TARGET:
    rand_location = np.random.randint(0, high=num_params, size=None)
    if vec_temp[rand_location] == 1:
        continue
    vec_temp[rand_location] = 1
    if vec_conv_marker[rand_location] == 1:
        num_selected += 1
print(torch.sum(vec_conv_marker))
# Final mask: conv parameters that were also randomly selected.
vec_conv_marker = vec_conv_marker * vec_temp
print(torch.sum(vec_conv_marker))

# Index of the first selected parameter, used for the per-epoch debug prints.
first_conv_marker = torch.argmax(vec_conv_marker)
print(first_conv_marker)



# Entries written to the per-epoch debug log: (state_dict key, index) pairs.
DEBUG_SLICES = (
    ('conv1.weight', (0, 0)),
    ('layer1.0.bn1.weight', slice(0, 3)),
    ('layer1.0.bn1.running_var', slice(0, 3)),
)

# Main training loop. The try/finally guarantees that both log files are
# flushed and closed even when training stops with an exception.
try:
    for epoch in range(start_epoch, args.epochs):
        time_ep = time.time()

        # One epoch of SGD on `model`, followed by evaluation on the test set.
        lr = schedule(epoch)
        utils.adjust_learning_rate(optimizer, lr)
        train_result_temp = utils.train_epoch(loaders['train'], model, criterion, optimizer)
        test_result_temp = utils.eval(loaders['test'], model, criterion)
        sgd_train_loss[epoch] = train_result_temp['loss']
        sgd_train_acc[epoch] = train_result_temp['accuracy']
        sgd_test_loss[epoch] = test_result_temp['loss']
        sgd_test_acc[epoch] = test_result_temp['accuracy']

        if args.swa and (epoch + 1) >= args.swa_start and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
            # NOTE(review): on --resume only swa_model and swa_n are restored;
            # model_avr_bn and model_avr_conv_rand start from fresh weights —
            # confirm this is intended.
            moving_average_bn(model_avr_bn, model, 1.0 / (swa_n + 1))
            # moving_average_conv(model_avr_conv, model, 1.0 / (swa_n + 1))
            # moving_average_conv_l1(model_avr_conv_layer1, model, 1.0 / (swa_n + 1))
            # moving_average_conv_l2(model_avr_conv_layer2, model, 1.0 / (swa_n + 1))
            utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))

            vec_model_convrand = parameters_to_vector(model_avr_conv_rand.parameters())
            vec_model = parameters_to_vector(model.parameters())
            if swa_n==0:
                vector_to_parameters(vec_model, model_avr_conv_rand.parameters())
            else:
                # Where the marker is 1 this is the running mean
                # (avg * swa_n + model) / (swa_n + 1); where it is 0 the
                # current model weight is taken unchanged.
                vector_to_parameters((vec_model_convrand * swa_n * vec_conv_marker+ vec_model * vec_conv_marker) / (swa_n + vec_conv_marker) + vec_model * (1-vec_conv_marker),
                                     model_avr_conv_rand.parameters())



            swa_n += 1

            utils.bn_update(loaders['train'], model_avr_bn)
            # utils.bn_update(loaders['train'], model_avr_conv)
            # utils.bn_update(loaders['train'], model_avr_conv_layer1)
            # utils.bn_update(loaders['train'], model_avr_conv_layer2)
            utils.bn_update(loaders['train'], model_avr_conv_rand)
            utils.bn_update(loaders['train'], swa_model)

            train_result_temp = utils.eval(loaders['train'], model_avr_bn, criterion)
            test_result_temp = utils.eval(loaders['test'], model_avr_bn, criterion)
            swa_bn_train_loss[epoch] = train_result_temp['loss']
            swa_bn_train_acc[epoch] = train_result_temp['accuracy']
            swa_bn_test_loss[epoch] = test_result_temp['loss']
            swa_bn_test_acc[epoch] = test_result_temp['accuracy']



            train_result_temp = utils.eval(loaders['train'], model_avr_conv_rand, criterion)
            test_result_temp = utils.eval(loaders['test'], model_avr_conv_rand, criterion)
            swa_conv_rand_train_loss[epoch] = train_result_temp['loss']
            swa_conv_rand_train_acc[epoch] = train_result_temp['accuracy']
            swa_conv_rand_test_loss[epoch] = test_result_temp['loss']
            swa_conv_rand_test_acc[epoch] = test_result_temp['accuracy']

            train_result_temp = utils.eval(loaders['train'], swa_model, criterion)
            test_result_temp = utils.eval(loaders['test'], swa_model, criterion)
            swa_train_loss[epoch] = train_result_temp['loss']
            swa_train_acc[epoch] = train_result_temp['accuracy']
            swa_test_loss[epoch] = test_result_temp['loss']
            swa_test_acc[epoch] = test_result_temp['accuracy']

        # Debug dump of a few weights per network. swa_model only exists when
        # --swa is given, so it is included conditionally (it used to raise
        # NameError in plain SGD runs). Keys a network does not have are
        # skipped instead of raising KeyError.
        debug_models = [model, model_avr_bn, model_avr_conv_rand]
        if args.swa:
            debug_models.append(swa_model)
        debug_states = [net.state_dict() for net in debug_models]
        debug_values = [
            state[key][index]
            for key, index in DEBUG_SLICES
            for state in debug_states
            if key in state
        ]
        print(epoch, *debug_values, file=f_temp)
        f_temp.flush()

        # Rewrite every metric history after each epoch so partial results
        # survive an interrupted run. Histories whose averaging step is
        # commented out above are saved as all zeros.
        np.savetxt(os.path.join(args.dir, 'sgd_train_loss.txt'), sgd_train_loss)
        np.savetxt(os.path.join(args.dir, 'sgd_train_acc.txt'), sgd_train_acc)
        np.savetxt(os.path.join(args.dir, 'sgd_test_acc.txt'), sgd_test_acc)
        np.savetxt(os.path.join(args.dir, 'sgd_test_loss.txt'), sgd_test_loss)

        np.savetxt(os.path.join(args.dir, 'swa_bn_test_acc.txt'), swa_bn_test_acc)
        np.savetxt(os.path.join(args.dir, 'swa_bn_test_loss.txt'), swa_bn_test_loss)
        np.savetxt(os.path.join(args.dir, 'swa_bn_train_acc.txt'), swa_bn_train_acc)
        np.savetxt(os.path.join(args.dir, 'swa_bn_train_loss.txt'), swa_bn_train_loss)

        np.savetxt(os.path.join(args.dir, 'swa_conv_test_acc.txt'), swa_conv_test_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_test_loss.txt'), swa_conv_test_loss)
        np.savetxt(os.path.join(args.dir, 'swa_conv_train_acc.txt'), swa_conv_train_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_train_loss.txt'), swa_conv_train_loss)

        np.savetxt(os.path.join(args.dir, 'swa_conv_l1_test_acc.txt'), swa_conv_l1_test_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_l1_test_loss.txt'), swa_conv_l1_test_loss)
        np.savetxt(os.path.join(args.dir, 'swa_conv_l1_train_acc.txt'), swa_conv_l1_train_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_l1_train_loss.txt'), swa_conv_l1_train_loss)

        np.savetxt(os.path.join(args.dir, 'swa_conv_l2_test_acc.txt'), swa_conv_l2_test_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_l2_test_loss.txt'), swa_conv_l2_test_loss)
        np.savetxt(os.path.join(args.dir, 'swa_conv_l2_train_acc.txt'), swa_conv_l2_train_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_l2_train_loss.txt'), swa_conv_l2_train_loss)

        np.savetxt(os.path.join(args.dir, 'swa_conv_rand_test_acc.txt'), swa_conv_rand_test_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_rand_test_loss.txt'), swa_conv_rand_test_loss)
        np.savetxt(os.path.join(args.dir, 'swa_conv_rand_train_acc.txt'), swa_conv_rand_train_acc)
        np.savetxt(os.path.join(args.dir, 'swa_conv_rand_train_loss.txt'), swa_conv_rand_train_loss)

        np.savetxt(os.path.join(args.dir, 'swa_train_loss.txt'), swa_train_loss)
        np.savetxt(os.path.join(args.dir, 'swa_train_acc.txt'), swa_train_acc)
        np.savetxt(os.path.join(args.dir, 'swa_test_acc.txt'), swa_test_acc)
        np.savetxt(os.path.join(args.dir, 'swa_test_loss.txt'), swa_test_loss)
        time_ep = time.time() - time_ep

        # Console spot-check: compare the SGD weights with the randomly
        # averaged weights at the first three positions from the first
        # masked index, plus position 0. Each network is flattened once per
        # epoch rather than once per printed value.
        vec_debug_sgd = parameters_to_vector(model.parameters())
        vec_debug_rand = parameters_to_vector(model_avr_conv_rand.parameters())
        for offset in range(3):
            print(vec_debug_sgd[first_conv_marker + offset])
            print(vec_debug_rand[first_conv_marker + offset])
        print(vec_debug_rand[0])
        print(epoch, lr, time_ep)
        print(epoch, lr, time_ep, file=f_out)
finally:
    f_out.close()
    f_temp.close()