import math
import sys

import torch

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import  default_weights
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test, misclassification_detection
)
from image_uncertainty.uncertainty.methods import mcd_ue, mcd_runs


"""
Ensemble uncertainty estimation
It assumes you already trained the models and put them in checkpoint folder
"""

MODEL_SIZE: int = 5  # NOTE(review): main() does not read this; it chooses its own ensemble size locally.


class EnsembleWrapper:
    """Wrap several classifiers so they can be called like a single model.

    Calling the wrapper runs every member on the same input, averages the
    members' softmax probabilities, and returns the log of that average, so
    ``softmax(output)`` recovers the averaged probabilities.
    """

    def __init__(self, models):
        # Sequence of callables mapping an input batch to logits; the class
        # dimension is assumed to be the last one (TODO confirm for all nets).
        self.models = models

    def __call__(self, x):
        """Return ``log(mean_m softmax(model_m(x)))`` over the last dimension.

        Computed as ``logsumexp(log_softmax(preds), dim=0) - log(M)``, which is
        mathematically identical but stays finite when a member's probability
        underflows to 0 (a plain ``torch.log`` of the averaged softmax would
        produce ``-inf`` there).
        """
        # Shape: (M, *model_output_shape), one slice per ensemble member.
        preds = torch.stack([model(x) for model in self.models])
        log_probs = torch.log_softmax(preds, dim=-1)
        # log of the mean = log of the sum - log(M), done in log space.
        return torch.logsumexp(log_probs, dim=0) - math.log(len(self.models))


def _load_ensemble(args, n_models):
    """Load ``n_models`` checkpoints of ``args.net`` and return them as a list.

    Member ``i`` is loaded from
    ``default_weights(args.net, args.ood_name, args.data_seed, i)``, so
    checkpoints must exist for indices ``0 .. n_models - 1``.
    """
    models = []
    for i in range(n_models):
        weights = default_weights(args.net, args.ood_name, args.data_seed, i)
        models.append(load_model(args.net, weights, args.gpu))
    return models


def main():
    """Evaluate ensemble OOD detection for several ensemble sizes.

    Loads the ensemble and collects its runs on the in-distribution test set
    and on each OOD dataset. For every acquisition function and ensemble size
    it computes uncertainties from the first ``model_size`` members and passes
    them to ``described_plot``.
    """
    args = get_eval_args()
    print(vars(args))

    model_sizes = [2, 3, 5, 7, 10]
    ood_names = ['svhn', 'lsun', 'smooth']
    acquisitions = ['max_prob', 'entropy', 'std', 'bald']

    test_loader = cifar_test(args.b, False, args.ood_name)
    # Load enough members for the largest subset; smaller sizes use a prefix.
    models = _load_ensemble(args, max(model_sizes))

    id_runs, correct = mcd_runs(test_loader, models, args.gpu, ensemble=True)
    # NOTE(review): `correct` comes from running all loaded members, so this
    # accuracy is reported unchanged for every model_size below.
    accuracy = sum(correct) / len(correct)

    for ood_name in ood_names:
        print('\n')
        print(ood_name)
        ood_loader = cifar_test(args.b, True, ood_name)
        ood_runs, _ = mcd_runs(ood_loader, models, args.gpu, ensemble=True)

        for acquisition in acquisitions:
            print()
            print(acquisition)
            for model_size in model_sizes:
                print(model_size)
                ues = mcd_ue(id_runs[:model_size], acquisition)
                ood_ues = mcd_ue(ood_runs[:model_size], acquisition)

                # Label with the subset size actually used, not the total
                # number of loaded members.
                described_plot(
                    ues, ood_ues, ood_name, args.net, accuracy,
                    f'ensemble ({acquisition}), T={model_size}'
                )

    # NOTE: single-model performance on 'smooth' was observed to vary a lot
    # between members; to explore it, repeat the evaluation above with
    # models[i:i + 1] for each member i.

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
