import fire
import os
import sys
import time
import torch
import torchvision
from torchvision import datasets, transforms
from models import DenseNet
import argparse
import tabulate
import numpy as np

# Command-line interface.
# NOTE: the arguments are parsed at import time; ``demo`` below reads ``args``
# for its default parameter values, so ``args`` must exist before it is defined.
parser = argparse.ArgumentParser(description='dense net training')
parser.add_argument('--cuda_visible_devices', type=str, default='0', help='cuda_visible_devices (default: GPU0)')

# --dir and --data_path are mandatory, so their help text says so instead of
# advertising a default that can never be used.
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (required)')
parser.add_argument('--dataset', type=str, default='CIFAR100', help='dataset name (default: CIFAR100)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (required)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')

# NOTE(review): --resume is not read anywhere in the visible part of this file.
parser.add_argument('--resume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume training from (default: None)')
# The entry point loads checkpoint['swa_state_dict'] from this file.
parser.add_argument('--swa_resume', type=str, default=None, metavar='CKPT',
                    help='checkpoint providing the swa_state_dict weights to start from (default: None)')

# The entry point applies this value unchanged at every epoch.
parser.add_argument('--lr_set', type=float, default=0.01, metavar='LR',
                    help='constant learning rate used for every epoch (default: 0.01)')

parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
                    help='SWA model collection frequency/cycle length in epochs (default: 1)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Attributes (all reset to 0 by ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: running sum of ``val * n`` over all updates.
        count: running sum of the weights ``n``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        :param val: the observed value (typically a per-batch mean).
        :param n: weight of the observation, e.g. the batch size.
        :return: None
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch recorded first) used to
        # raise ZeroDivisionError; report an average of 0 until data arrives.
        self.avg = self.sum / self.count if self.count else 0


def adjust_learning_rate(optimizer, lr):
    """
    Overwrite the learning rate of every parameter group.

    :param optimizer: object exposing ``param_groups``, an iterable of dicts.
    :param lr: value stored under the ``'lr'`` key of each group.
    :return: ``lr``, unchanged.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return lr

def moving_average(net1, net2, alpha=1):
    """
    Blend the parameters of ``net2`` into ``net1`` in place.

    Parameters are paired positionally via ``parameters()``; each one in
    ``net1`` becomes ``(1 - alpha) * net1 + alpha * net2``. Only what
    ``parameters()`` yields is touched, and ``net2`` is left unchanged.

    :param net1: model that accumulates the average (modified in place).
    :param net2: model whose parameters are mixed in.
    :param alpha: weight given to ``net2``; the default 1 copies its values.
    :return: None
    """
    keep = 1.0 - alpha
    for target, source in zip(net1.parameters(), net2.parameters()):
        target.data.mul_(keep)
        target.data.add_(source.data * alpha)


def _check_bn(module, flag):
    """
    Set ``flag[0]`` to True when ``module`` is a batch-norm layer.

    :param module: object to test against torch's ``_BatchNorm`` base class.
    :param flag: one-element list used as a mutable out-parameter; it is never
        reset to False here.
    :return: None
    """
    is_batch_norm = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
    if is_batch_norm:
        flag[0] = True


def check_bn(model):
    """
    Tell whether ``model`` contains at least one batch-norm layer.

    :param model: module whose ``apply`` is used to visit its sub-modules
        and itself.
    :return: True if any visited module is a batch-norm layer, else False.
    """
    found = [False]

    def visit(submodule):
        _check_bn(submodule, found)

    model.apply(visit)
    return found[0]


def reset_bn(module):
    """
    Reset the running statistics of a batch-norm layer.

    ``running_mean`` is replaced by zeros and ``running_var`` by ones, each
    with the shape, dtype and device of the buffer it replaces. Modules that
    are not batch-norm layers are left untouched.

    :param module: module to reset.
    :return: None
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    old_mean = module.running_mean
    module.running_mean = torch.zeros_like(old_mean)
    old_var = module.running_var
    module.running_var = torch.ones_like(old_var)


def _get_momenta(module, momenta):
    """
    Remember the momentum of a batch-norm layer.

    :param module: module to inspect; anything that is not a batch-norm layer
        is ignored.
    :param momenta: dict updated in place with ``module -> module.momentum``.
    :return: None
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    momenta[module] = module.momentum


def _set_momenta(module, momenta):
    """
    Write a stored momentum back onto a batch-norm layer.

    :param module: module to update; anything that is not a batch-norm layer
        is ignored.
    :param momenta: dict mapping modules to momentum values; a batch-norm
        layer missing from it raises ``KeyError``.
    :return: None
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    module.momentum = momenta[module]


def bn_update(loader, model):
    """
        BatchNorm buffers update (if any).
        Performs 1 epoch over ``loader`` to re-estimate the running statistics
        of every batch-norm layer as a cumulative average over all samples.

        The model is switched to training mode and left in it. Each layer's
        ``running_mean``/``running_var`` are reset first, and its original
        ``momentum`` is restored afterwards, even if the forward pass fails.
        Models without batch-norm layers are returned untouched.

        :param loader: train dataset loader yielding ``(input, target)`` pairs;
            only the inputs are used.
        :param model: model being updated.
        :return: None
    """
    bn_base = torch.nn.modules.batchnorm._BatchNorm
    # Original momentum of every batch-norm layer, keyed by the layer itself.
    momenta = {module: module.momentum
               for module in model.modules() if isinstance(module, bn_base)}
    if not momenta:
        return

    model.train()
    for module in momenta:
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)

    # Feed batches on the model's own device. The original always used the
    # current CUDA device, which stays the fallback for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    seen = 0
    try:
        # Only the buffers matter here, so no autograd graph is needed.
        with torch.no_grad():
            for inputs, _ in loader:
                # ``async=True`` is a syntax error on Python 3.7+;
                # ``non_blocking`` is its replacement.
                inputs = inputs.to(device, non_blocking=True)
                batch = inputs.size(0)

                # momentum = b / (n + b) turns the exponential update into an
                # exact running average over the ``seen + batch`` samples.
                momentum = batch / (seen + batch)
                for module in momenta:
                    module.momentum = momentum

                model(inputs)
                seen += batch
    finally:
        for module, momentum in momenta.items():
            module.momentum = momentum

def save_checkpoint(dir, epoch, **kwargs):
    """
    Write a checkpoint file named ``checkpoint-<epoch>.pt`` into ``dir``.

    :param dir: existing directory that receives the file.
    :param epoch: epoch number; stored under the ``'epoch'`` key and used in
        the file name.
    :param kwargs: extra entries saved alongside the epoch (for example state
        dicts).
    :return: None
    """
    payload = {'epoch': epoch, **kwargs}
    target = os.path.join(dir, 'checkpoint-%d.pt' % epoch)
    torch.save(payload, target)

def train_epoch(loader, model, criterion, optimizer):
    """
    Train ``model`` for one pass over ``loader``.

    :param loader: data loader yielding ``(input, target)`` batches; it must
        expose ``dataset``, whose length normalises the returned metrics.
    :param model: network to train; it is switched to training mode.
    :param criterion: callable ``criterion(output, target)`` returning the
        mean loss of a batch as a scalar tensor.
    :param optimizer: optimizer stepped once per batch.
    :return: dict with ``'loss'`` (mean loss per sample) and ``'accuracy'``
        (top-1 accuracy in percent), both divided by ``len(loader.dataset)``.
    """
    loss_sum = 0.0
    correct = 0.0

    model.train()

    # Feed batches on the model's own device. The original always used the
    # current CUDA device, which stays the fallback for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    for inputs, target in loader:
        # ``async=True`` is a syntax error on Python 3.7+; ``non_blocking`` is
        # its replacement. Tensors no longer need a ``Variable`` wrapper.
        inputs = inputs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        output = model(inputs)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ``loss`` is a batch mean, so weight it by the batch size.
        # ``loss.data[0]`` fails on 0-dim tensors; ``item()`` replaces it.
        loss_sum += loss.item() * inputs.size(0)
        pred = output.detach().max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum().item()

    return {
        'loss': loss_sum / len(loader.dataset),
        'accuracy': correct / len(loader.dataset) * 100.0,
    }


def eval(loader, model, criterion):
    """
    Evaluate ``model`` on every batch of ``loader`` without updating it.

    NOTE: the name shadows the ``eval`` builtin inside this module; it is kept
    because callers in this file use it.

    :param loader: data loader yielding ``(input, target)`` batches; it must
        expose ``dataset``, whose length normalises the returned metrics.
    :param model: network to evaluate; it is switched to evaluation mode and
        left in it.
    :param criterion: callable ``criterion(output, target)`` returning the
        mean loss of a batch as a scalar tensor.
    :return: dict with ``'loss'`` (mean loss per sample) and ``'accuracy'``
        (top-1 accuracy in percent), both divided by ``len(loader.dataset)``.
    """
    loss_sum = 0.0
    correct = 0.0

    model.eval()

    # Feed batches on the model's own device. The original always used the
    # current CUDA device, which stays the fallback for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for inputs, target in loader:
            # ``async=True`` is a syntax error on Python 3.7+;
            # ``non_blocking`` is its replacement.
            inputs = inputs.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(inputs)
            loss = criterion(output, target)

            # ``loss`` is a batch mean, so weight it by the batch size.
            # ``loss.data[0]`` fails on 0-dim tensors; ``item()`` replaces it.
            loss_sum += loss.item() * inputs.size(0)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    return {
        'loss': loss_sum / len(loader.dataset),
        'accuracy': correct / len(loader.dataset) * 100.0,
    }


def train(model, swa_model, train_set, test_set, save, n_epochs=300, valid_size=None,
          batch_size=64, lr=0.1, wd=0.0001, momentum=0.9, seed=None):
    """
    Train ``model`` with SGD and a step schedule, optionally maintaining an
    SWA average in ``swa_model``, and print one table row per epoch.

    The learning rate is multiplied by 0.1 at 50% and at 75% of ``n_epochs``.
    Evaluation frequency and the SWA settings come from the module-level
    ``args`` (``eval_freq``, ``swa``, ``swa_start``, ``swa_c_epochs``).

    :param model: network to train; moved to CUDA when available.
    :param swa_model: network receiving the running average of ``model``'s
        parameters; only used when ``args.swa`` is set.
    :param train_set: training dataset.
    :param test_set: test dataset.
    :param save: existing directory in which ``results.csv`` is created.
    :param n_epochs: number of epochs to run.
    :param valid_size: if truthy, that many randomly chosen training samples
        are held out of training (they are not evaluated anywhere).
    :param batch_size: batch size of both loaders.
    :param lr: initial SGD learning rate.
    :param wd: NOTE(review): accepted but not used; weight decay is taken from
        ``args.wd`` exactly as before.
    :param momentum: NOTE(review): accepted but not used; momentum is taken
        from ``args.momentum`` exactly as before.
    :param seed: if not None, seeds the torch CPU and CUDA generators.
    :return: None

    NOTE(review): ``train_epoch``/``eval`` divide by ``len(loader.dataset)``,
    so with ``valid_size`` set the reported training metrics are scaled by
    the fraction of the data that is actually sampled.
    """
    if seed is not None:
        # Honour the ``seed`` argument (the original re-read ``args.seed``).
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    # The script entry point defines ``criterion`` and an open log ``f_out``
    # at module level. Use them when present, but do not fail with NameError
    # when this function is called from somewhere that never created them.
    criterion = globals().get('criterion', torch.nn.functional.cross_entropy)
    log_file = globals().get('f_out')

    use_cuda = torch.cuda.is_available()
    loader_kwargs = dict(batch_size=batch_size, pin_memory=use_cuda, num_workers=0)

    # Data loaders
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)
    if valid_size:
        # Hold out ``valid_size`` random samples. The sampler used to be
        # referenced without being defined, which raised NameError.
        indices = torch.randperm(len(train_set)).tolist()
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:len(indices) - valid_size])
        train_loader = torch.utils.data.DataLoader(train_set, sampler=train_sampler, **loader_kwargs)
    else:
        train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)

    # Model on cuda
    # NOTE(review): the original also built a DataParallel wrapper here but
    # never trained through it, so the unused wrapper is dropped.
    if use_cuda:
        model = model.cuda()

    # Optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=args.momentum,
        weight_decay=args.wd
    )
    # Integer milestones: a fractional value such as 12.5 never equals an
    # epoch counter, which would silently skip that decay step.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[int(0.5 * n_epochs), int(0.75 * n_epochs)],
        gamma=0.1
    )

    # Start log
    with open(os.path.join(save, 'results.csv'), 'w') as f:
        f.write('epoch,train_loss,train_error,valid_loss,valid_error,test_error\n')

    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
    if args.swa:
        columns = columns[:-1] + ['swa_te_loss', 'swa_te_acc'] + ['swa_tr_loss', 'swa_tr_acc'] + columns[-1:]
    swa_test_temp = {'loss': None, 'accuracy': None}
    swa_train_temp = {'loss': None, 'accuracy': None}

    # Train model
    swa_n = 0
    for epoch in range(n_epochs):
        time_ep = time.time()
        # Learning rate actually used during this epoch, for the table.
        current_lr = optimizer.param_groups[0]['lr']
        train_temp = train_epoch(train_loader, model, criterion, optimizer)

        # First epoch, every ``eval_freq``-th epoch, and the last epoch of
        # THIS run (previously compared against ``args.epochs``).
        evaluate_now = (epoch == 0
                        or epoch % args.eval_freq == args.eval_freq - 1
                        or epoch == n_epochs - 1)
        if evaluate_now:
            test_temp = eval(test_loader, model, criterion)
        else:
            test_temp = {'loss': None, 'accuracy': None}

        if args.swa and (epoch + 1) >= (args.swa_start) and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
            # Running mean of the collected weights: the new sample gets
            # weight 1 / (number of samples collected so far + 1).
            moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            if evaluate_now:
                # Evaluate on the loaders built above rather than on a
                # module-level ``loaders`` dict that may not exist.
                bn_update(train_loader, swa_model)
                swa_test_temp = eval(test_loader, swa_model, criterion)
                swa_train_temp = eval(train_loader, swa_model, criterion)
            else:
                swa_test_temp = {'loss': None, 'accuracy': None}
                swa_train_temp = {'loss': None, 'accuracy': None}

        # Step the schedule after the epoch's optimizer steps. Stepping first
        # makes current PyTorch skip the initial value of the schedule.
        scheduler.step()

        time_ep = time.time() - time_ep
        values = [epoch + 1, current_lr, train_temp['loss'], train_temp['accuracy'],
                  test_temp['loss'], test_temp['accuracy'], time_ep]
        if args.swa:
            values = (values[:-1]
                      + [swa_test_temp['loss'], swa_test_temp['accuracy']]
                      + [swa_train_temp['loss'], swa_train_temp['accuracy']]
                      + values[-1:])
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        if epoch % 40 == 0:
            # Repeat the header, framed by the separator line, every 40 epochs.
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        if log_file is not None:
            print(table, file=log_file)
            log_file.flush()
        print(table)


def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000,
         n_epochs=args.epochs, batch_size=args.batch_size, seed=args.seed):
    """
    A demo to show off training of efficient DenseNets.
    Trains and evaluates a DenseNet-BC on CIFAR-10.

    Args:
        data (str) - path to directory where data should be loaded from/downloaded
        save (str) - path to save the model to (created if missing)

        depth (int) - depth of the network (number of convolution layers) (default 100)
        growth_rate (int) - number of features added per DenseNet layer (default 12)
        efficient (bool) - use the memory efficient implementation? (default True)

        valid_size (int) - size of validation set (default 5000)
        n_epochs (int) - number of epochs for training (default: --epochs)
        batch_size (int) - size of minibatch (default: --batch_size)
        seed (int) - manually set the random seed (default: --seed)

    Raises:
        Exception - if ``depth - 4`` is not a multiple of 3, or if ``save``
            exists but is not a directory.
    """

    # Get densenet configuration: three dense blocks of equal size
    if (depth - 4) % 3:
        raise Exception('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]

    # Data transforms
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ])
    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv),
    ])

    # Datasets
    train_set = datasets.CIFAR10(data, train=True, transform=train_transforms, download=True)
    test_set = datasets.CIFAR10(data, train=False, transform=test_transforms, download=False)

    # Models
    def build_network():
        return DenseNet(
            growth_rate=growth_rate,
            block_config=block_config,
            num_classes=10,
            small_inputs=True,
            efficient=efficient,
        )

    model = build_network()
    print(model)

    # ``train`` requires a ``swa_model`` argument. It was not passed before,
    # so every call failed with TypeError. The averaging network is only
    # needed when SWA is enabled, and it must live on the same device the
    # trained model will be moved to.
    swa_model = None
    if args.swa:
        swa_model = build_network()
        if torch.cuda.is_available():
            swa_model = swa_model.cuda()

    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save, exist_ok=True)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    # Train the model
    train(model=model, swa_model=swa_model, train_set=train_set, test_set=test_set, save=save,
          valid_size=valid_size, n_epochs=n_epochs, batch_size=batch_size, seed=seed)
    print('Done!')


"""
Script entry point: continues training a DenseNet-BC from an SWA checkpoint.

The weights stored under 'swa_state_dict' in the --swa_resume checkpoint are
loaded, the model is trained for 50 further epochs at the constant learning
rate --lr_set, and a checkpoint is saved whenever the training loss is lower
and the test loss is higher than the values measured for the loaded weights.

Usage:
python <this_script> --dir <path_to_save_dir> --data_path <path_to_data_dir> --swa_resume <path_to_checkpoint>

The network configuration is fixed in the code below (depth 100, growth rate 12,
efficient implementation). The argument parser does not accept --efficient,
--data, --save, --depth, --growth_rate or --n_epochs.

Other args:
    --dataset (str) - torchvision dataset name (default CIFAR100)
    --batch_size (int) - size of minibatch (default 128)
    --num_workers (int) - number of data loader workers (default 4)
    --eval_freq (int) - evaluate on the test set every N epochs (default 5)
    --epochs (int) - overwritten by the script with start epoch + 50
"""
if __name__ == '__main__':
    # Script entry point: load an SWA checkpoint, keep training at a constant
    # learning rate for a fixed number of epochs, and save a checkpoint every
    # time the training loss is lower AND the test loss is higher than those
    # measured for the loaded weights.

    # Only the entry point needs these modules, so they are imported locally.
    import shutil
    import subprocess

    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
    os.makedirs(args.dir, exist_ok=True)
    with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')
    # Keep a copy of this script next to the results. A plain file copy works
    # for absolute paths too and does not pass argv through a shell. Failure
    # is reported but, as before, does not stop the run.
    try:
        shutil.copy(sys.argv[0], args.dir)
    except OSError as err:
        print('could not copy %s to %s: %s' % (sys.argv[0], args.dir, err), file=sys.stderr)

    # --seed used to be parsed but never applied on this code path.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    save = args.dir

    # Fixed network configuration.
    depth = 100
    growth_rate = 12
    efficient = True
    # Number of epochs trained on top of the checkpoint's epoch.
    extra_epochs = 50

    # Get densenet configuration
    if (depth - 4) % 3:
        raise Exception('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]

    # The context manager closes the log even if training raises.
    with open(os.path.join(save, 'output_record.txt'), 'w') as f_out:
        # Data transforms
        mean = [0.5071, 0.4867, 0.4408]
        stdv = [0.2675, 0.2565, 0.2761]
        train_transforms = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])
        test_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=stdv),
        ])

        # Datasets
        ds = getattr(torchvision.datasets, args.dataset)
        path = os.path.join(args.data_path, args.dataset.lower())
        train_set = ds(path, train=True, download=True, transform=train_transforms)
        test_set = ds(path, train=False, download=True, transform=test_transforms)

        loaders = {
            'train': torch.utils.data.DataLoader(
                train_set,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.num_workers,
                pin_memory=True
            ),
            'test': torch.utils.data.DataLoader(
                test_set,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.num_workers,
                pin_memory=True
            )
        }
        # Newer torchvision exposes the labels as ``targets``; fall back to
        # the older ``train_labels`` attribute when it is absent.
        labels = getattr(train_set, 'targets', None)
        if labels is None:
            labels = train_set.train_labels
        num_classes = int(max(labels)) + 1

        # Models
        model = DenseNet(
            growth_rate=growth_rate,
            block_config=block_config,
            num_classes=num_classes,
            small_inputs=True,
            efficient=efficient,
        )
        model.cuda()
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=0.1,
            momentum=args.momentum,
            weight_decay=args.wd
        )

        # Without --swa_resume, ``start_epoch`` used to be undefined and the
        # script died with NameError; start from epoch 0 instead.
        start_epoch = 0
        if args.swa_resume is not None:
            print('Resume training from %s' % args.swa_resume, file=f_out)
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(args.swa_resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['swa_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

        criterion = torch.nn.functional.cross_entropy

        columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']

        # Reference metrics of the starting weights.
        train_res_swa = eval(loaders['train'], model, criterion)
        test_res_swa = eval(loaders['test'], model, criterion)
        print(train_res_swa)
        print(test_res_swa)
        # The value given with --epochs is replaced here.
        args.epochs = start_epoch + extra_epochs

        for epoch in range(start_epoch, args.epochs):
            time_ep = time.time()

            # Constant learning rate. The original rebuilt a MultiStepLR
            # scheduler every epoch and immediately overwrote its value here,
            # so it never influenced training and has been dropped.
            lr = args.lr_set
            adjust_learning_rate(optimizer, lr)

            train_res = train_epoch(loaders['train'], model, criterion, optimizer)
            if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
                test_res = eval(loaders['test'], model, criterion)
            else:
                test_res = {'loss': None, 'accuracy': None}

            # On epochs without a test evaluation the test loss is None, and
            # comparing None with a float raises TypeError. The search
            # criterion is therefore only checked on evaluated epochs.
            if (test_res['loss'] is not None
                    and train_res['loss'] < train_res_swa['loss']
                    and test_res['loss'] > test_res_swa['loss']):
                print('find', file=f_out)
                print('find')
                save_checkpoint(
                    args.dir,
                    epoch + 1,
                    state_dict=model.state_dict(),
                    optimizer=optimizer.state_dict()
                )

            time_ep = time.time() - time_ep
            values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'],
                      test_res['loss'], test_res['accuracy'], time_ep]
            table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
            if epoch % 40 == 0:
                # Repeat the header, framed by the separator line.
                table = table.split('\n')
                table = '\n'.join([table[1]] + table)
            else:
                table = table.split('\n')[2]
            print(table, file=f_out)
            print(table)
            f_out.flush()

    print('Done!')

    # Best-effort e-mail notification. The command line is sent on stdin
    # instead of being spliced into a shell string, so quotes or shell
    # metacharacters in the arguments cannot alter the command.
    # NOTE(review): the recipient address is hard-coded.
    try:
        subprocess.run(
            ['mail', '-s', 'a program finished just now', '962086838@qq.com'],
            input=(' '.join(sys.argv) + '\n').encode(),
            check=False,
        )
    except OSError as err:
        print('could not send notification mail: %s' % err, file=sys.stderr)