import argparse

import numpy as np
import torch
import yaml

import models
import tools


def get_args(argv=None):
    """Parse command-line options for the certification-bound evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse parse ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``config``, ``depth``, ``width``,
        ``epochs``, ``fraction``, ``checkpoint`` and ``logfile``.
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the name
    # shown in the usage line), so the title has to go in ``description``.
    parser = argparse.ArgumentParser(
        description='Computing Certification Bounds for Globally-Robust Neural Networks')

    # The script cannot do anything without a config file, so fail early in argparse.
    parser.add_argument('--config', type=str, required=True, help='path to the config yaml file')
    # Optional overrides. ``None`` is both the default and the value of a bare
    # flag (via ``const``); it means "keep the value from the config file".
    parser.add_argument('--depth', type=int, const=None, nargs='?', help='override for model depth')
    parser.add_argument('--width', type=int, const=None, nargs='?', help='override for model width')
    parser.add_argument('--epochs', type=int, const=None, nargs='?', help='override for training epochs')
    # main() reads ``args.fraction`` and hands it to tools.data_loader.
    parser.add_argument('--fraction', type=float, default=None,
                        help='optional `fraction` argument for tools.data_loader')

    # Checkpoint
    parser.add_argument('--checkpoint', type=str, required=True, help='path to checkpoint file')
    parser.add_argument('--logfile', type=str, required=True,
                        help='target path to store predictions/labels')

    return parser.parse_args(argv)


def _apply_overrides(args, model_cfg, train_cfg):
    """Copy CLI overrides (depth / width / epochs) into the config dicts in place.

    An override is applied only when its option value is not ``None``, so an
    explicit ``0`` is honoured instead of being treated as "not given".
    """
    if args.depth is not None:
        model_cfg['depth'] = args.depth
    if args.width is not None:
        model_cfg['width'] = args.width
    if args.epochs is not None:
        train_cfg['epochs'] = args.epochs


def _collect_predictions(model, loader, eps, use_lln):
    """Run ``models.margin_layer`` over every batch of ``loader`` and gather the results.

    Args:
        model: Network already moved to CUDA with its weights loaded.
        loader: Iterable of ``(inputs, targets)`` tensor batches.
        eps: Value forwarded as ``eps`` to ``models.margin_layer``.
        use_lln: Value forwarded as ``use_lln`` to ``models.margin_layer``.

    Returns:
        ``(predictions, labels)``: NumPy arrays, each concatenated along axis 0
        from the per-batch margin-layer outputs and targets.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    predictions, labels = [], []

    # Computed once and reused as ``subL`` for every batch. It is deliberately
    # kept outside ``no_grad`` and before ``model.eval()``.
    # NOTE(review): confirm whether sub_lipschitz() depends on grad mode or on
    # train/eval state before reordering these two calls.
    sub_lipschitz = model.sub_lipschitz().item()
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            y, _, _ = models.margin_layer(model, x=inputs, label=targets, eps=eps,
                                          use_lln=use_lln, subL=sub_lipschitz,
                                          return_loss=False)
            # Tensors created under no_grad need no detach() before numpy().
            predictions.append(y.cpu().numpy())
            labels.append(targets.cpu().numpy())

    if not predictions:
        raise ValueError('validation loader yielded no batches; nothing to save')
    return np.concatenate(predictions), np.concatenate(labels)


def main():
    """Evaluate a GloRo checkpoint on the validation split and save its outputs.

    Loads the YAML config given by ``--config`` and applies any CLI overrides.
    Builds ``models.GloroNet`` on CUDA and loads ``state['backbone']`` from
    ``--checkpoint``. Finally writes the margin-layer outputs and the true labels
    to ``--logfile`` via ``np.savez_compressed``, under the keys ``predictions``
    and ``labels``.
    """
    args = get_args()

    # NOTE: yaml.Loader can construct arbitrary Python objects; only point
    # --config at trusted files (yaml.safe_load would be the stricter choice).
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    model_cfg = cfg['model']
    train_cfg = cfg['training']
    dataset_cfg = cfg['dataset']
    gloro_cfg = cfg['gloro']

    # Process config overrides
    _apply_overrides(args, model_cfg, train_cfg)

    # ``fraction`` is read defensively because the parser may not define it.
    # It is forwarded only when actually given.
    # NOTE(review): assumes tools.data_loader provides its own default for
    # ``fraction`` — confirm.
    loader_kwargs = {}
    fraction = getattr(args, 'fraction', None)
    if fraction is not None:
        loader_kwargs['fraction'] = fraction

    _, _, val_loader, _ = tools.data_loader(
        data_name=dataset_cfg['name'],
        batch_size=train_cfg['batch_size'],
        num_classes=dataset_cfg['num_classes'],
        auxiliary=None,
        seed=dataset_cfg.get('seed', 2023),  # if seed is not given, use 2023
        **loader_kwargs)

    model = models.GloroNet(**model_cfg, **dataset_cfg)
    model = model.cuda()

    # Load checkpoint.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(args.checkpoint)
    model.load_state_dict(state['backbone'])

    predictions, labels = _collect_predictions(
        model, val_loader, eps=gloro_cfg['eps'], use_lln=model_cfg['use_lln'])

    # Save predictions and labels
    np.savez_compressed(args.logfile, predictions=predictions, labels=labels)


if __name__ == "__main__":
    # Only run the evaluation when executed as a script, never on import.
    main()
