import sys
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib

sys.path.append(".")
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test,
    misclassification_detection
)
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader
from image_uncertainty.cifar import settings
from image_uncertainty.uncertainty.metrics import ood_roc_auc
from nuq import NuqClassifier
from experiments.imagenet_discrete import dump_ues


def get_embeddings(model, loader, gpu=None):
    """Run ``model`` over every batch of ``loader`` and collect outputs and labels.

    Args:
        model: network whose forward output is used as the embedding.
        loader: iterable of ``(images, labels)`` batches.
        gpu: if True, images are moved to CUDA before the forward pass; if
            False, they stay on the CPU. If None (default), images are moved
            to the device of the model's first parameter (CPU when the model
            has no parameters). This replaces the old dependency on a
            module-global ``args.gpu``, which only existed when this file was
            run as a script.

    Returns:
        ``(embeddings, labels)``: the per-batch model outputs concatenated
        along axis 0, and a numpy array of the labels in loader order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if gpu is None:
        first_param = next(model.parameters(), None)
        device = (first_param.device if first_param is not None
                  else torch.device('cpu'))
    else:
        device = torch.device('cuda' if gpu else 'cpu')

    labels = []
    embeddings = []
    # Inference only: no autograd graph is needed for any batch.
    with torch.no_grad():
        for images, batch_labels in tqdm(loader):
            images = images.to(device)
            embeddings.append(model(images).cpu().numpy())
            labels.extend(batch_labels.tolist())

    if not embeddings:
        # np.concatenate would fail here with an unhelpful message.
        raise ValueError("get_embeddings: loader yielded no batches")

    return np.concatenate(embeddings), np.array(labels)



import ray

# Global matplotlib font settings for every figure this script draws:
# regular weight, 18pt.
font = dict(weight='normal', size=18)
matplotlib.rc('font', **font)


# Names (without the ``.npy`` suffix) of the cached embedding files, in the
# order the cache helpers produce/consume them.
_CACHE_NAMES = (
    't_x_train', 't_y_train', 't_x_test', 't_y_test', 't_x_ood', 't_x_ood_val',
)


def _cache_path(base_dir, name):
    """Return the cache file path for ``name``; ``base_dir`` is a plain prefix."""
    # base_dir is concatenated verbatim, so it must end with a path separator.
    return f'{base_dir}{name}.npy'


def _compute_embeddings(args, base_dir):
    """Embed all splits with the trained network and write them to the cache.

    Returns the arrays in ``_CACHE_NAMES`` order.
    """
    train_loader, _ = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=True,
        ood_name=args.ood_name,
        seed=args.data_seed
    )
    test_loader = cifar_test(args.b, False, args.ood_name)
    ood_loader = cifar_test(args.b, True, args.ood_name)
    # OOD set used only for bandwidth tuning. NOTE(review): it is always
    # SVHN, so it coincides with the evaluation OOD set when
    # args.ood_name == 'svhn'.
    ood_val_loader = cifar_test(args.b, True, 'svhn')

    model = load_model(args.net, args.weights, args.gpu)
    model.eval()

    x_train, y_train = get_embeddings(model, train_loader)
    x_test, y_test = get_embeddings(model, test_loader)
    x_ood, _ = get_embeddings(model, ood_loader)
    x_ood_val, _ = get_embeddings(model, ood_val_loader)

    arrays = (x_train, y_train, x_test, y_test, x_ood, x_ood_val)
    for name, array in zip(_CACHE_NAMES, arrays):
        np.save(_cache_path(base_dir, name), array)
    return arrays


def _load_embeddings(base_dir):
    """Read the cached arrays written by ``_compute_embeddings``."""
    return tuple(np.load(_cache_path(base_dir, name)) for name in _CACHE_NAMES)


def _tune_bandwidth(nuq, x_in, x_ood, mults):
    """Grid-search the NUQ bandwidth that maximizes OOD ROC-AUC.

    The grid is ``mults`` scaled by the mean distance of reference points to
    their second-nearest reference point. Leaves ``nuq.bandwidth_ref_`` at the
    last grid value; the caller is expected to set the chosen one.

    Returns:
        ``(best_band, scores)``: the best grid value and the ROC-AUC for every
        grid value.
    """
    _, dists = ray.get(
        nuq.index_.knn_query.remote(nuq.X_ref_, return_dist=True)
    )
    # Column 0 is presumably each point's distance to itself (the query set is
    # the index set), so column 1 gives the nearest-other-point scale.
    left = dists.mean(axis=0)[1]
    grid = [m * left for m in mults]
    print(grid)

    scores = []
    for band in grid:
        nuq.bandwidth_ref_ = ray.put(np.array([band]))
        _, ues_in = nuq.predict_proba(x_in, return_uncertainty='epistemic')
        _, ues_ood = nuq.predict_proba(x_ood, return_uncertainty='epistemic')
        scores.append(ood_roc_auc(ues_in, ues_ood))

    return grid[int(np.argmax(scores))], scores


def main(args):
    """Fit NUQ on CIFAR embeddings, tune its bandwidth, and plot OOD uncertainty.

    Embeddings are recomputed with the network when ``args.calculate`` is
    truthy (default False); otherwise they are loaded from ``.npy`` files
    under ``args.cache_dir`` (default ``'./'``).
    """
    calculate = getattr(args, 'calculate', False)
    base_dir = getattr(args, 'cache_dir', './')

    if calculate:
        embeddings = _compute_embeddings(args, base_dir)
    else:
        embeddings = _load_embeddings(base_dir)
    x_train, y_train, x_test, _, x_ood, x_ood_val = embeddings
    print(x_ood_val.shape)

    nuq = NuqClassifier(n_neighbors=500, tune_bandwidth=None, verbose=True)
    nuq.fit(x_train, y_train, bandwidth=np.array([1]))

    # NOTE(review): tuning uses the in-distribution *test* split, which is
    # also used for the final evaluation below.
    best_band, scores = _tune_bandwidth(
        nuq, x_test, x_ood_val, np.logspace(-2, 4, 14)
    )
    print(scores)

    # Keep the same shape-(1,) array form used for fit() and during tuning.
    nuq.bandwidth_ref_ = ray.put(np.array([best_band]))

    _, ues_test = nuq.predict_proba(x_test, return_uncertainty='epistemic')
    _, ues_ood = nuq.predict_proba(x_ood, return_uncertainty='epistemic')
    described_plot(
        ues_test, ues_ood, args.ood_name, args.net,
        title_extras='NUQ'
    )


if __name__ == '__main__':
    # Bound at module level on purpose: other code in this file may read the
    # global ``args``.
    args = get_eval_args()
    # Echo the full configuration, then the weights path, before running.
    print(vars(args))
    print(args.weights)
    main(args)
