import fire
import os
import sys
import time
import torch
import torchvision
from torchvision import datasets, transforms
from models import DenseNet
import argparse
import tabulate
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib as mpl

def parameters_to_vector(parameters):
    r"""Flatten an iterable of parameters into one 1-D tensor.

    Arguments:
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.
    Returns:
        Tensor: every parameter viewed as 1-D and concatenated in
        iteration order.
    Raises:
        TypeError: propagated from ``_check_param_device`` when the
            parameters do not all live on the same device.
    """
    # Device of the first parameter; every later one is compared against it.
    device_flag = None
    flat_chunks = []

    for tensor in parameters:
        device_flag = _check_param_device(tensor, device_flag)
        flat_chunks.append(tensor.view(-1))

    return torch.cat(flat_chunks)


def vector_to_parameters(vec, parameters):
    r"""Copy consecutive slices of a flat vector back into parameters.

    Arguments:
        vec (Tensor): a single vector represents the parameters of a model.
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.
    Raises:
        TypeError: if ``vec`` is not a ``torch.Tensor``, or (propagated from
            ``_check_param_device``) if the parameters are spread over
            several devices.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))

    # Device of the first parameter; every later one is compared against it.
    device_flag = None
    # Start offset into ``vec`` of the slice belonging to the next parameter.
    start = 0

    for tensor in parameters:
        device_flag = _check_param_device(tensor, device_flag)

        stop = start + tensor.numel()
        # Rebind the parameter's storage to the reshaped slice.
        tensor.data = vec[start:stop].view_as(tensor).data
        start = stop


def _check_param_device(param, old_param_device):
    r"""Verify that ``param`` lives on the device seen so far.

    Devices are encoded as the CUDA device index, or ``-1`` for CPU.

    Arguments:
        param ([Tensor]): a Tensor of a parameter of a model
        old_param_device (int): device code of the first parameter that was
            inspected, or ``None`` when ``param`` is the first one.
    Returns:
        old_param_device (int): the device code to carry on to the next call
        (the code of ``param`` itself on the first call).
    Raises:
        TypeError: if ``param`` is on a different device than
            ``old_param_device``.
    """
    current_device = param.get_device() if param.is_cuda else -1

    # First parameter: just report where it lives.
    if old_param_device is None:
        return current_device

    if current_device != old_param_device:
        raise TypeError('Found two parameters on different devices, '
                        'this is currently not supported.')
    return old_param_device

# Command-line interface. Parsing happens at import time because `demo` below
# reads `args` for its default argument values.
parser = argparse.ArgumentParser(description='dense net training')
parser.add_argument('--cuda_visible_devices', type=str, default='0', help='cuda_visible_devices (default: GPU0)')

# Output location and data.
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (required)')
parser.add_argument('--dataset', type=str, default='CIFAR100', help='dataset name (default: CIFAR100)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (required)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
# The exploration evaluates 2 * distances + 1 points, at integer offsets in
# [-distances, distances], each multiplied by distances_scale.
parser.add_argument('--distances', type=int, default=20, metavar='N', help='explore radius (default: 20)')
parser.add_argument('--division_part', type=int, default=40, metavar='N', help='division_part (default: 40)')
parser.add_argument('--distances_scale', type=float, default=1.0, metavar='N', help='explore scale (default: 1)')
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')

# NOTE(review): --model is required but is not read anywhere in the visible
# code; confirm whether it is still needed.
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (required)')
parser.add_argument('--rand', type=str, default='rand', metavar='rand',
                    help='rand or randn selection (default: rand)')
parser.add_argument('--model1_resume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume training from (default: None)')

# Optimisation hyper-parameters.
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()


class AverageMeter(object):
    """
    Running-average tracker: remembers the latest value, the weighted sum,
    the total weight and their ratio.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, lr):
    """Set ``lr`` on every parameter group of ``optimizer`` and return it."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return lr

def moving_average(net1, net2, alpha=1):
    """Blend ``net2`` into ``net1`` in place.

    Each parameter of ``net1`` becomes
    ``(1 - alpha) * net1_param + alpha * net2_param``; parameters are paired
    in ``parameters()`` order.
    """
    keep = 1.0 - alpha
    for target, source in zip(net1.parameters(), net2.parameters()):
        target.data *= keep
        target.data += source.data * alpha


def _check_bn(module, flag):
    """Set ``flag[0]`` to True when ``module`` is a batch-norm layer."""
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True


def check_bn(model):
    """Return True if ``model.apply`` visits at least one batch-norm module."""
    # One-element list so the visitor can mutate the result.
    found = [False]

    def visit(submodule):
        _check_bn(submodule, found)

    model.apply(visit)
    return found[0]


def reset_bn(module):
    """Replace a batch-norm layer's running mean/variance with zeros/ones.

    Modules that are not batch-norm layers are left untouched.
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    module.running_mean = torch.zeros_like(module.running_mean)
    module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    """Store a batch-norm layer's momentum in ``momenta`` keyed by the module."""
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    """Write the momentum saved in ``momenta`` back onto a batch-norm layer."""
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]


def bn_update(loader, model):
    """
        BatchNorm buffers update (if any).
        Performs 1 epoch over ``loader`` to re-estimate the running mean and
        variance of every batch-norm layer as a cumulative average.

        The model is switched to train mode and left in it. Does nothing when
        the model has no batch-norm layer.

        :param loader: train dataset loader yielding ``(input, target)``
            batches; targets are ignored.
        :param model: model being updated
        :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))

    # Feed the batches to wherever the model lives; fall back to the current
    # CUDA device (the original hard-coded target) for parameterless models.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    seen = 0
    try:
        # Only the BN buffers matter here, so no autograd graph is needed.
        with torch.no_grad():
            for batch, _ in loader:
                batch = batch.to(device, non_blocking=True)
                batch_size = batch.size(0)

                # momentum = b / (n + b) turns the exponential running
                # average into an exact average over all samples seen so far.
                momentum = batch_size / (seen + batch_size)
                for module in momenta:
                    module.momentum = momentum

                model(batch)
                seen += batch_size
    finally:
        # Restore the original momenta even if a forward pass failed.
        model.apply(lambda module: _set_momenta(module, momenta))

def save_checkpoint(dir, epoch, **kwargs):
    """Save ``{'epoch': epoch, **kwargs}`` to ``<dir>/checkpoint-<epoch>.pt``."""
    payload = {'epoch': epoch, **kwargs}
    target = os.path.join(dir, 'checkpoint-%d.pt' % epoch)
    torch.save(payload, target)

def train_epoch(loader, model, criterion, optimizer):
    """Train ``model`` for one pass over ``loader``.

    :param loader: iterable of ``(input, target)`` batches.
    :param model: network to optimise; it is put in train mode.
    :param criterion: callable ``criterion(output, target)`` returning a
        scalar loss averaged over the batch.
    :param optimizer: optimizer stepping the model's parameters.
    :return: dict with ``'loss'`` (per-sample mean loss) and ``'accuracy'``
        (top-1 accuracy in percent), both computed over the samples actually
        processed.
    """
    # Move batches to wherever the model lives; fall back to the current CUDA
    # device (the original hard-coded target) for parameterless models.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    loss_sum = 0.0
    correct = 0.0
    seen = 0

    model.train()

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        output = model(inputs)
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        # The criterion returns a batch mean; re-weight it by the batch size so
        # a smaller last batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (output.argmax(dim=1) == targets).sum().item()
        seen += batch_size

    # Divide by the samples actually seen, not len(loader.dataset): the two
    # differ when the loader uses a subset sampler or drops the last batch.
    total = max(seen, 1)
    return {
        'loss': loss_sum / total,
        'accuracy': correct / total * 100.0,
    }


def eval(loader, model, criterion):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    :param loader: iterable of ``(input, target)`` batches.
    :param model: network to evaluate; it is put in eval mode.
    :param criterion: callable ``criterion(output, target)`` returning a
        scalar loss averaged over the batch.
    :return: dict with ``'loss'`` (per-sample mean loss) and ``'accuracy'``
        (top-1 accuracy in percent), both computed over the samples actually
        processed.
    """
    # Move batches to wherever the model lives; fall back to the current CUDA
    # device (the original hard-coded target) for parameterless models.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    loss_sum = 0.0
    correct = 0.0
    seen = 0

    model.eval()

    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            output = model(inputs)
            loss = criterion(output, targets)

            batch_size = inputs.size(0)
            # The criterion returns a batch mean; re-weight it by the batch
            # size so a smaller last batch does not skew the average.
            loss_sum += loss.item() * batch_size
            correct += (output.argmax(dim=1) == targets).sum().item()
            seen += batch_size

    # Divide by the samples actually seen, not len(loader.dataset): the two
    # differ when the loader uses a subset sampler.
    total = max(seen, 1)
    return {
        'loss': loss_sum / total,
        'accuracy': correct / total * 100.0,
    }


def train(model, swa_model, train_set, test_set, save, n_epochs=300, valid_size=None,
          batch_size=64, lr=0.1, wd=0.0001, momentum=0.9, seed=None):
    """Train ``model`` with SGD and a two-milestone step schedule.

    The learning rate is divided by 10 at 50% and 75% of ``n_epochs``. One
    table row per epoch is printed, and the same values are appended to
    ``<save>/results.csv``. Evaluation runs on the first epoch, every
    ``args.eval_freq`` epochs and on the last epoch.

    Stochastic weight averaging is active only when the parsed command line
    provides a truthy ``args.swa``; in that case ``args.swa_start`` and
    ``args.swa_c_epochs`` are read as well and ``swa_model`` must be given.

    Args:
        model: network to train (moved to CUDA when available).
        swa_model: network holding the running weight average, or ``None``
            when SWA is not used.
        train_set: training dataset.
        test_set: test dataset.
        save (str): existing directory that receives ``results.csv``.
        n_epochs (int): number of training epochs.
        valid_size (int or None): if set, that many training samples are
            held out at random and evaluated as a validation split.
        batch_size (int): minibatch size of every loader.
        lr (float): initial learning rate.
        wd (float): weight decay.
        momentum (float): SGD momentum.
        seed (int or None): if set, seeds the CPU and CUDA generators.

    Raises:
        ValueError: if SWA is enabled but ``swa_model`` is ``None``.
    """
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    use_cuda = torch.cuda.is_available()
    criterion = torch.nn.functional.cross_entropy

    # The SWA options are not declared by every version of the CLI, so treat
    # a missing flag as "disabled" instead of raising AttributeError.
    use_swa = bool(getattr(args, 'swa', False))
    if use_swa:
        if swa_model is None:
            raise ValueError('SWA is enabled but no swa_model was provided')
        swa_start = args.swa_start
        swa_c_epochs = args.swa_c_epochs

    def make_loader(dataset, sampler=None, shuffle=False):
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           sampler=sampler, pin_memory=use_cuda, num_workers=0)

    # Data loaders (with an optional random train/valid split)
    test_loader = make_loader(test_set)
    if valid_size:
        indices = torch.randperm(len(train_set)).tolist()
        split = len(indices) - valid_size
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        train_loader = make_loader(train_set, sampler=train_sampler)
        valid_loader = make_loader(train_set, sampler=valid_sampler)
    else:
        train_loader = make_loader(train_set, shuffle=True)
        valid_loader = None

    # Models on cuda
    if use_cuda:
        model = model.cuda()
        if swa_model is not None:
            swa_model = swa_model.cuda()

    # Optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=wd
    )
    # Milestones must be integers: the scheduler compares them to the integer
    # epoch counter, so e.g. 12.5 would never trigger.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[int(0.5 * n_epochs), int(0.75 * n_epochs)], gamma=0.1)

    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc']
    if valid_loader is not None:
        columns += ['va_loss', 'va_acc']
    if use_swa:
        columns += ['swa_te_loss', 'swa_te_acc', 'swa_tr_loss', 'swa_tr_acc']
    columns.append('time')

    # Start log
    results_path = os.path.join(save, 'results.csv')
    with open(results_path, 'w') as f:
        f.write(','.join(columns) + '\n')

    # When run as a script the module keeps an open text log in the global
    # ``f_out``; mirror the table there if it exists and is still open.
    log_file = globals().get('f_out')
    if log_file is not None and getattr(log_file, 'closed', False):
        log_file = None

    not_evaluated = {'loss': None, 'accuracy': None}
    swa_n = 0
    for epoch in range(n_epochs):
        time_ep = time.time()
        current_lr = optimizer.param_groups[0]['lr']

        train_temp = train_epoch(train_loader, model, criterion, optimizer)
        # Step the schedule after the optimizer updates of this epoch.
        scheduler.step()

        evaluate_now = (epoch == 0
                        or epoch % args.eval_freq == args.eval_freq - 1
                        or epoch == n_epochs - 1)

        test_temp = eval(test_loader, model, criterion) if evaluate_now else not_evaluated
        values = [epoch + 1, current_lr, train_temp['loss'], train_temp['accuracy'],
                  test_temp['loss'], test_temp['accuracy']]

        if valid_loader is not None:
            valid_temp = eval(valid_loader, model, criterion) if evaluate_now else not_evaluated
            values += [valid_temp['loss'], valid_temp['accuracy']]

        if use_swa:
            swa_test_temp = swa_train_temp = not_evaluated
            if (epoch + 1) >= swa_start and (epoch + 1 - swa_start) % swa_c_epochs == 0:
                # Fold the current weights into the running average.
                moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                if evaluate_now:
                    # Averaged weights need freshly estimated BN statistics.
                    bn_update(train_loader, swa_model)
                    swa_test_temp = eval(test_loader, swa_model, criterion)
                    swa_train_temp = eval(train_loader, swa_model, criterion)
            values += [swa_test_temp['loss'], swa_test_temp['accuracy'],
                       swa_train_temp['loss'], swa_train_temp['accuracy']]

        values.append(time.time() - time_ep)

        table_lines = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f').split('\n')
        if epoch % 40 == 0:
            # Repeat the header every 40 epochs, framed by the separator line.
            table = '\n'.join([table_lines[1]] + table_lines)
        else:
            table = table_lines[2]
        print(table)
        if log_file is not None:
            print(table, file=log_file)
            log_file.flush()

        with open(results_path, 'a') as f:
            f.write(','.join('' if value is None else str(value) for value in values) + '\n')


def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000,
         n_epochs=args.epochs, batch_size=args.batch_size, seed=args.seed):
    """
    A demo to show off training of efficient DenseNets.
    Trains and evaluates a DenseNet-BC on CIFAR-10.

    Args:
        data (str) - path to directory where data should be loaded from/downloaded
        save (str) - path to save the results to (created if missing)

        depth (int) - depth of the network; ``depth - 4`` must be a multiple
            of 3 (default 100)
        growth_rate (int) - number of features added per DenseNet layer (default 12)
        efficient (bool) - use the memory efficient implementation? (default True)

        valid_size (int) - size of validation set (default 5000)
        n_epochs (int) - number of epochs for training (default: --epochs)
        batch_size (int) - size of minibatch (default: --batch_size)
        seed (int) - manually set the random seed (default: --seed)

    Raises:
        Exception: if ``depth`` is invalid or ``save`` exists but is not a
            directory.
    """

    # Get densenet configuration: three dense blocks of equal size
    if (depth - 4) % 3:
        raise Exception('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]

    # Data transforms
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ])

    # Datasets (the test split reuses the files downloaded for the train split)
    train_set = datasets.CIFAR10(data, train=True, transform=train_transforms, download=True)
    test_set = datasets.CIFAR10(data, train=False, transform=test_transforms, download=False)

    # Models
    model = DenseNet(
        growth_rate=growth_rate,
        block_config=block_config,
        num_classes=10,
        small_inputs=True,
        efficient=efficient,
    )
    print(model)

    # Make save directory; exist_ok avoids the check-then-create race.
    if os.path.exists(save) and not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)
    os.makedirs(save, exist_ok=True)

    # Train the model. ``train`` requires ``swa_model``; the demo does not use
    # weight averaging, so pass None explicitly instead of omitting it.
    train(model=model, swa_model=None, train_set=train_set, test_set=test_set, save=save,
          valid_size=valid_size, n_epochs=n_epochs, batch_size=batch_size, seed=seed)
    print('Done!')




# Script entry point: evaluate train/test loss and accuracy along one random,
# unit-norm direction in weight space around model_1.
if __name__ == '__main__':
    import shutil

    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
    os.makedirs(args.dir, exist_ok=True)

    # Keep a record of how the run was launched.
    with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')
    # Snapshot this script next to the results. Best effort, like the shell
    # `cp` it replaces, but also correct for absolute paths and paths that
    # contain spaces or shell metacharacters.
    try:
        shutil.copy(sys.argv[0], args.dir)
    except OSError as err:
        print('could not copy %s to %s: %s' % (sys.argv[0], args.dir, err))

    with open(os.path.join(args.dir, 'output_record.txt'), 'w') as f_out:
        depth = 100
        growth_rate = 12
        efficient = True
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        # Get densenet configuration
        if (depth - 4) % 3:
            raise Exception('Invalid depth')
        block_config = [(depth - 4) // 6 for _ in range(3)]

        # Data transforms
        mean = [0.5071, 0.4867, 0.4408]
        stdv = [0.2675, 0.2565, 0.2761]
        train_transforms = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])
        test_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])

        # Datasets
        ds = getattr(torchvision.datasets, args.dataset)
        path = os.path.join(args.data_path, args.dataset.lower())
        train_set = ds(path, train=True, download=True, transform=train_transforms)
        test_set = ds(path, train=False, download=True, transform=test_transforms)

        loaders = {
            'train': torch.utils.data.DataLoader(
                train_set,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers,
                pin_memory=True
            ),
            'test': torch.utils.data.DataLoader(
                test_set,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.num_workers,
                pin_memory=True
            )
        }
        # Newer torchvision exposes labels as `targets`; older releases used
        # `train_labels`.
        labels = getattr(train_set, 'targets', None)
        if labels is None:
            labels = train_set.train_labels
        num_classes = int(max(labels)) + 1

        # Models: model_1 is the centre point, model_2 receives the displaced
        # weights, model_temp stores the random direction.
        densenet_kwargs = dict(
            growth_rate=growth_rate,
            block_config=block_config,
            num_classes=num_classes,
            small_inputs=True,
            efficient=efficient,
        )
        model_1 = DenseNet(**densenet_kwargs)
        model_1.cuda()
        model_2 = DenseNet(**densenet_kwargs)
        model_2.cuda()
        model_temp = DenseNet(**densenet_kwargs)
        model_temp.cuda()

        # Draw the direction and accumulate its squared L2 norm over all
        # parameters. Any --rand value other than 'rand'/'randn' keeps
        # model_temp's freshly initialised weights as the direction.
        norm_sum = 0
        for param1 in model_temp.parameters():
            if args.rand == 'randn':
                param1.data = torch.randn(param1.data.shape).cuda()
            if args.rand == 'rand':
                param1.data = torch.rand(param1.data.shape).cuda()
            norm_temp = torch.norm(param1.data, 2).cuda()
            norm_sum += norm_temp * norm_temp
        norm_sum = torch.sqrt(norm_sum)
        # Scale the whole direction to unit L2 norm.
        for param1 in model_temp.parameters():
            param1.data = param1.data / norm_sum
        print(norm_sum, file=f_out)
        print(norm_sum)

        criterion = torch.nn.functional.cross_entropy
        if args.model1_resume is not None:
            print('Resume training from %s' % args.model1_resume)
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(args.model1_resume)
            model_1.load_state_dict(checkpoint['state_dict'])
            bn_update(loaders['train'], model_1)
            print(eval(loaders['train'], model_1, criterion))

        f_out.flush()

        # One result slot per offset in [-distances, distances].
        result_shape = args.distances * 2 + 1

        train_loss_results_bnupdate = np.zeros(result_shape)
        test_loss_results_bnupdate = np.zeros(result_shape)
        train_acc_results_bnupdate = np.zeros(result_shape)
        test_acc_results_bnupdate = np.zeros(result_shape)

        for i in range(result_shape):
            print(i)
            print(i, file=f_out)
            # model_2 = model_1 + (i - distances) * direction * scale
            for param1, param2, param3 in zip(model_1.parameters(), model_temp.parameters(),
                                              model_2.parameters()):
                param3.data = param1.data + (i - args.distances) * param2.data * args.distances_scale
            # The displaced weights need freshly estimated BN statistics.
            bn_update(loaders['train'], model_2)

            train_temp = eval(loaders['train'], model_2, criterion)
            test_temp = eval(loaders['test'], model_2, criterion)
            print(train_temp)
            print(train_temp, file=f_out)
            print(test_temp)
            print(test_temp, file=f_out)

            train_loss_results_bnupdate[i] = train_temp['loss']
            train_acc_results_bnupdate[i] = train_temp['accuracy']
            test_loss_results_bnupdate[i] = test_temp['loss']
            test_acc_results_bnupdate[i] = test_temp['accuracy']

            # Rewrite the result files after every point so partial runs are usable.
            np.savetxt(os.path.join(args.dir, "train_loss_results.txt"), train_loss_results_bnupdate)
            np.savetxt(os.path.join(args.dir, "test_loss_results.txt"), test_loss_results_bnupdate)
            np.savetxt(os.path.join(args.dir, "train_acc_results.txt"), train_acc_results_bnupdate)
            np.savetxt(os.path.join(args.dir, "test_acc_results.txt"), test_acc_results_bnupdate)
            f_out.flush()



