import argparse
import logging

import torch

from lip_convnets import LipConvNet, LipResNet
from utils import *

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):
        """No-op stand-in for ``tqdm`` used when the package is not installed.

        Extra positional and keyword options accepted by the real ``tqdm``
        (e.g. ``desc=``, ``leave=``) are accepted and ignored, so call sites
        behave the same with or without the package. The iterable is returned
        unchanged.
        """
        return iterable

# Module-level logger; this file does not configure any handlers itself.
logger = logging.getLogger(__name__)


def get_args(argv=None):
    """Build the command-line parser for certificate evaluation and parse arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option below (dashes in the
        option names become underscores, e.g. ``--conv-layer`` -> ``conv_layer``).

    ``--checkpoint`` and ``--logfile`` are required. The evaluation always loads
    the checkpoint and always writes the output file, so omitting either would
    otherwise only fail after data loading or after the full evaluation run.
    """
    parser = argparse.ArgumentParser()

    # Training specifications (several are kept for compatibility with the
    # training script's CLI even though evaluation does not use them all)
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--lr-min', default=0., type=float)
    parser.add_argument('--lr-max', default=0.1, type=float)
    parser.add_argument('--scheduler', default='multi-step', type=str, choices=['multi-step', 'one-cycle'])
    parser.add_argument('--weight-decay', default=5e-4, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--gamma', default=0., type=float, help='gamma for certificate regularization')
    parser.add_argument('--opt-level', default='O2', type=str, choices=['O0', 'O2'],
                        help='O0 is FP32 training and O2 is "Almost FP16" Mixed Precision')
    parser.add_argument('--loss-scale', default='1.0', type=str, choices=['1.0', 'dynamic'],
                        help='If loss_scale is "dynamic", adaptively adjust the loss scale over time')

    # Auxiliary specifications
    parser.add_argument('--auxiliary-dir', default=None, type=str)
    parser.add_argument('--auxiliary', default=None, type=str)
    parser.add_argument('--fraction', default=0.7, type=float)

    # Checkpoint / output: both are consumed unconditionally by the evaluation.
    parser.add_argument('--checkpoint', type=str, required=True, help='path to checkpoint file')
    parser.add_argument('--logfile', type=str, required=True,
                        help='target path (.npz) to store certificates/predictions/labels')

    # Model architecture specifications
    parser.add_argument('--conv-layer', default='soc', type=str, choices=['bcop', 'cayley', 'soc', 'lot'],
                        help='BCOP, Cayley, SOC, LOT convolution')
    parser.add_argument('--init-channels', default=32, type=int)
    parser.add_argument('--activation', default='maxmin', choices=['maxmin', 'hh1', 'hh2'],
                        help='Activation function')
    parser.add_argument('--block-size', default=1, type=int, choices=[1, 2, 3, 4, 5, 6, 7, 8],
                        help='size of each block')
    parser.add_argument('--lln', action='store_true', help='set last linear to be linear and normalized')
    parser.add_argument('--residual', action='store_true', help='residual')

    # Dataset specifications
    parser.add_argument('--data-dir', default='./data', type=str)
    parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'cifar100'],
                        help='dataset to use for training')

    # Other specifications
    parser.add_argument('--epsilon', default=36, type=int)
    parser.add_argument('--out-dir', default='LipConvnet', type=str, help='Output directory')
    parser.add_argument('--seed', default=0, type=int, help='Random seed')
    return parser.parse_args(argv)


def init_model(args):
    """Instantiate the network described by the parsed CLI arguments.

    ``args.residual`` picks ``LipResNet``; otherwise ``LipConvNet`` is used.
    Both classes receive the same configuration: convolution type,
    activation, initial width, block size, number of classes, and the
    ``lln`` flag. ``args.num_classes`` is not a CLI option, so the caller
    must set it on ``args`` before calling this function.
    """
    network_cls = LipResNet if args.residual else LipConvNet
    return network_cls(
        args.conv_layer,
        args.activation,
        init_channels=args.init_channels,
        block_size=args.block_size,
        num_classes=args.num_classes,
        lln=args.lln,
    )


def main():
    """Evaluate a trained Lipschitz-constrained network and save its certificates.

    Steps:
      1. Parse the CLI arguments.
      2. Seed the NumPy, PyTorch CPU and PyTorch CUDA RNGs.
      3. Build the test loader for the chosen dataset.
      4. Load the weights from ``--checkpoint`` into a freshly initialized model.
      5. Run ``evaluate_certificates`` on the test set.
      6. Write the certificates, predictions and labels to ``--logfile`` with
         ``np.savez_compressed``.

    Raises:
        ValueError: if ``--opt-level O2`` is combined with the Cayley or LOT
            convolution, or if the dataset is not recognized.
    """
    args = get_args()

    # The CLI choice for LOT is 'lot' (see --conv-layer). This check previously
    # tested for 'ot', which can never match, so LOT + O2 was silently accepted.
    if args.conv_layer in ('cayley', 'lot') and args.opt_level == 'O2':
        raise ValueError('O2 optimization level is incompatible with Cayley/LOT Convolution')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Only the test split is needed here. Auxiliary data is never loaded
    # (auxiliary=None), regardless of --auxiliary.
    _, test_loader = get_loaders(args.data_dir, args.batch_size, args.dataset, auxiliary=None,
                                 fraction=args.fraction)

    if args.dataset == 'cifar10':
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100
    else:
        raise ValueError('Unknown dataset')

    # NOTE(review): cifar10_std is used for CIFAR-100 too; presumably
    # get_loaders normalizes both datasets with CIFAR-10 statistics — confirm in utils.
    std = torch.tensor(cifar10_std).cuda()
    # Scale factor derived from the normalization std. How it enters the
    # certificate computation is defined by evaluate_certificates.
    L = 1 / torch.max(std)

    # Rebuild the architecture and load the trained weights. .float() casts
    # parameters to FP32 (training may have used mixed precision), and
    # .eval() switches to inference mode.
    model_test = init_model(args).cuda()
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    model_test.load_state_dict(torch.load(args.checkpoint))
    model_test.float()
    model_test.eval()

    _, _, certificates, predictions, labels = evaluate_certificates(test_loader, model_test, L, return_labels=True)
    np.savez_compressed(args.logfile, certificates=certificates, predictions=predictions, labels=labels)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()
