import argparse

import torch
from autoattack import AutoAttack

import models
from tools import data_loader


def get_args(argv=None):
    """Parse command-line options for the AutoAttack evaluation script.

    ``num_classes`` and ``input_size`` are always derived from ``--dataset``
    after parsing, so values given explicitly for ``--num_classes`` /
    ``--input_size`` on the command line are ignored.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (the default when ``None``).

    Returns:
        argparse.Namespace holding every option, with ``num_classes`` and
        ``input_size`` set according to ``dataset``.

    Raises:
        ValueError: if ``--dataset`` is not ``cifar10``, ``cifar100`` or
            ``tiny_imagenet``.
        SystemExit: raised by argparse on invalid arguments, including a
            missing ``--checkpoint``.
    """
    # dataset name -> (num_classes, input_size)
    dataset_config = {
        'cifar10': (10, 32),
        'cifar100': (100, 32),
        'tiny_imagenet': (200, 64),
    }

    # The first positional argument of ArgumentParser is `prog`; pass the
    # text as `description` so it is not shown as the program name.
    parser = argparse.ArgumentParser(
        description='Auto Attack for Globally-Robust Neural Networks')

    # Model
    parser.add_argument('--model', default='ReScale', type=str)
    parser.add_argument('--dataset', default='cifar10', type=str)
    parser.add_argument('--num_classes',
                        default=10,
                        type=int,
                        help='ignored: derived from --dataset')
    parser.add_argument('--use_lln',
                        action='store_true',
                        default=False,
                        help='last layer normalization')
    parser.add_argument('--depth',
                        default=12,
                        type=int,
                        help='number of residual blocks or conv layers')
    parser.add_argument('--width',
                        default=128,
                        type=int,
                        help='number of channels in the backbone')
    # For Linear Residual Block
    parser.add_argument('--use_affine', action='store_true', default=False)
    parser.add_argument('--use_bias', action='store_true', default=False)
    parser.add_argument('--use_rescale', action='store_true', default=False)
    parser.add_argument('--use_wn', action='store_true', default=False)
    parser.add_argument('--use_centering', action='store_true', default=False)

    parser.add_argument('--input_size',
                        default=32,
                        type=int,
                        help='ignored: derived from --dataset')
    parser.add_argument('--epsilon', default=0.14117647, type=float)  # 36/255
    parser.add_argument('--batch_size', default=256, type=int)
    # Required: without a checkpoint there is nothing to evaluate, and the
    # failure would otherwise only surface later inside torch.load.
    parser.add_argument('--checkpoint',
                        type=str,
                        required=True,
                        help='the checkpoint to verify')

    args = parser.parse_args(argv)

    try:
        args.num_classes, args.input_size = dataset_config[args.dataset]
    except KeyError:
        raise ValueError('dataset %s not supported.' % args.dataset) from None
    return args


def _load_backbone(model, checkpoint_path):
    """Load the ``'backbone'`` state dict stored in ``checkpoint_path`` into ``model``.

    Loading stays non-strict, but missing and unexpected keys are printed
    instead of being silently discarded. This makes an architecture/checkpoint
    mismatch visible before an evaluation of a partially initialised model.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location='cpu')['backbone']
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys:
        print('WARNING: parameters missing from checkpoint (left at init): %s'
              % incompatible.missing_keys)
    if incompatible.unexpected_keys:
        print('WARNING: unexpected keys in checkpoint (ignored): %s'
              % incompatible.unexpected_keys)


def _collect_split(loader):
    """Concatenate every ``(inputs, targets)`` batch of ``loader`` along dim 0."""
    inputs, targets = [], []
    for batch_inputs, batch_targets in loader:
        inputs.append(batch_inputs)
        targets.append(batch_targets)
    if not inputs:
        raise RuntimeError('validation loader yielded no batches')
    return torch.cat(inputs, dim=0), torch.cat(targets, dim=0)


@torch.no_grad()
def main():
    """Evaluate a checkpointed model with the standard L2 AutoAttack suite.

    Steps:
    1. Parse CLI options.
    2. Build ``models.<--model>`` from them.
    3. Load the ``'backbone'`` weights.
    4. Gather the whole validation split into memory.
    5. Run ``AutoAttack.run_standard_evaluation`` on it with
       ``eps=--epsilon``.

    NOTE(review): this function runs under ``torch.no_grad()``. That relies on
    AutoAttack re-enabling gradients internally for its gradient-based attacks.
    Confirm against the installed autoattack version.
    """
    args = get_args()
    torch.backends.cudnn.benchmark = True

    # The third element returned by data_loader is used as the validation
    # loader. NOTE(review): the dataset is presumably selected from
    # num_classes, since args.dataset is not passed. Confirm.
    _, _, val_loader, _ = data_loader(batch_size=args.batch_size,
                                      num_classes=args.num_classes,
                                      distributed=False)

    # Every CLI option is forwarded; the model constructor picks what it needs.
    model = models.__dict__[args.model](**vars(args))
    _load_backbone(model, args.checkpoint)
    model.eval()
    print(model)

    model = model.cuda()
    adversary = AutoAttack(model,
                           norm='L2',
                           eps=args.epsilon,
                           version='standard')

    all_inputs, all_targets = _collect_split(val_loader)

    # The adversarial examples are not used further; AutoAttack reports the
    # robust accuracy itself.
    _ = adversary.run_standard_evaluation(all_inputs,
                                          all_targets,
                                          bs=args.batch_size)


# Run the evaluation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
