from image_uncertainty.cifar import settings
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader
from image_uncertainty.uncertainty.calibration import (
    draw_reliability_graph, scale_temperature, WrappedModel
)

from image_uncertainty.cifar.cifar_evaluate import (
    evaluate, load_model, inference, get_eval_args
)


def main():
    """Evaluate a CIFAR model before and after temperature scaling.

    Parses evaluation arguments with ``get_eval_args``, prints them, and
    loads the network with ``load_model``. It then runs the same pipeline
    twice: once on the raw model and once on a ``WrappedModel`` built from
    the temperature returned by ``scale_temperature``. Each pass does the
    following:

    * runs ``inference`` on the validation loader and draws a reliability
      graph, tagged ``'uncalibrated'`` or ``'calibrated'``;
    * calls ``evaluate``; the calibrated pass also passes
      ``title_extras=' calibrated'``.
    """
    args = get_eval_args()
    print(vars(args))

    model = load_model(args.net, args.weights, args.gpu)

    # Only the validation loader is used below; the training loader is
    # discarded.
    _, val_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=True,
        ood_name=args.ood_name,
    )

    # Baseline: reliability graph and metrics for the unmodified model.
    _, probs, labels = inference(model, val_loader, args.gpu)
    draw_reliability_graph(probs, labels, args.ood_name, 'uncalibrated')
    evaluate(model, args)

    # NOTE(review): scale_temperature receives val_loader, presumably to fit
    # the temperature. The "calibrated" reliability graph below is drawn on
    # the same loader, so that plot is likely in-sample and optimistic.
    # Consider a held-out split.
    temperature = scale_temperature(model, val_loader)
    model_t = WrappedModel(model, temperature)

    # Same pipeline, repeated for the temperature-wrapped model.
    _, probs, labels = inference(model_t, val_loader, args.gpu)
    draw_reliability_graph(probs, labels, args.ood_name, 'calibrated')
    evaluate(model_t, args, title_extras=' calibrated')


# Run the calibration/evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
