from argparse import ArgumentParser
import sys
from pathlib import Path

import numpy as np
from nuq import NuqClassifier
import matplotlib
import matplotlib.pyplot as plt
from ncvis import NCVis
import ray
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

sys.path.append('.')
from image_uncertainty.uncertainty.metrics import accuracy_and_ood

from experiments.imagenet_discrete import dump_ues
from scipy.special import softmax

# Command-line interface of the experiment; the parsed values are consumed by main().
parser = ArgumentParser(
    description='Evaluate NUQ uncertainty estimates on precomputed embeddings.'
)
parser.add_argument(
    '--subsample', action='store_true', default=False,
    # The subsampling step used by main() is 50, so the help states exactly that.
    help='set to subsample train embeddings, keeping every 50th sample'
)
parser.add_argument(
    '--base-dir', type=str,
    default='./checkpoint',
    help='directory holding the *_embeddings.npy / *_targets.npy files'
)

parser.add_argument(
    '--strategy', default='classification',
    help='value forwarded to NuqClassifier as tune_bandwidth'
)
parser.add_argument(
    '--log-pn', type=int, default=0,
    help='value forwarded to NuqClassifier as log_pN'
)

parser.add_argument(
    '--n-neighbors', type=int, default=50,
    help='value forwarded to NuqClassifier as n_neighbors'
)

parser.add_argument(
    '--normalize', action='store_true', default=False,
    help='set to standardize embeddings with a scaler fitted on the train set'
)
parser.add_argument(
    '--ood-name', type=str,
    default='imagenet_o',
    help='suffix of the OOD file to load: ood_embeddings_<ood-name>.npy'
)

# Font settings applied globally to every matplotlib figure drawn by this script.
font = dict(weight='normal', size=18)
matplotlib.rc('font', **font)


def visualize(x, y, x_test, y_test, x_ood, ood_name, num_classes=15, alpha=0.1):
    """Show NCVis projections of the first classes, with and without OOD data.

    Two figures are created and then displayed with a single ``plt.show()``:

    1. the train points whose label is below ``num_classes``;
    2. the train and test points whose label is below ``num_classes``
       projected jointly with all OOD points.

    Args:
        x: Train embeddings, indexable by a boolean mask built from ``y``.
        y: Train labels; compared against ``num_classes``.
        x_test: Test embeddings, concatenated with ``x`` for the second figure.
        y_test: Test labels, concatenated with ``y`` for the second figure.
        x_ood: Out-of-distribution embeddings (no labels required).
        ood_name: Name of the OOD dataset, used in the legend entry.
        num_classes: Only classes with label ``< num_classes`` are plotted.
        alpha: Marker transparency used for every scatter plot.
    """
    # --- Figure 1: in-distribution train points only -------------------------
    plt.figure(figsize=(9, 8), dpi=150)
    plt.subplots_adjust(left=0.15, bottom=0.13, right=0.95)

    train_mask = y < num_classes
    vis = NCVis()
    dims = vis.fit_transform(x[train_mask])
    plt.scatter(dims[:, 0], dims[:, 1], c=y[train_mask], alpha=alpha, cmap='tab20')
    plt.title(f'First {num_classes} classes on ImageNet train logits')

    # --- Figure 2: train + test points together with the OOD points ----------
    x_in = np.concatenate((x, x_test))
    y_in = np.concatenate((y, y_test))
    in_mask = y_in < num_classes
    x_in, y_in = x_in[in_mask], y_in[in_mask]
    n_in = len(x_in)

    plt.figure(figsize=(9, 8), dpi=150)
    plt.subplots_adjust(left=0.15, bottom=0.13, right=0.95)

    # In-distribution and OOD points are embedded jointly so they share one
    # projection; the first n_in rows of the result are the in-distribution ones.
    vis = NCVis()
    dims = vis.fit_transform(np.concatenate((x_in, x_ood)))

    # The two groups are drawn separately so that the legend entry refers to
    # the OOD points only instead of labelling every point as OOD.
    plt.scatter(dims[:n_in, 0], dims[:n_in, 1], c=y_in, alpha=alpha, cmap='tab20')
    plt.scatter(
        dims[n_in:, 0], dims[n_in:, 1],
        c='black', alpha=alpha, label=f'OOD ({ood_name})'
    )
    plt.title(f'First {num_classes} classes on Imagenet full embeddings')
    plt.legend()
    plt.show()


def k_means(x_train, y_train, x_test, y_test, gmm=False):
    """Plot the held-out score of KMeans / GMM fits on class 0 versus model size.

    Only samples with label ``0`` are used. For every component count from 1
    to 39 a model is fitted on the train subset and ``model.score`` is
    evaluated on the test subset; the scores are printed and plotted.

    Args:
        x_train: Train embeddings, filtered with ``y_train == 0``.
        y_train: Train labels.
        x_test: Test embeddings, filtered with ``y_test == 0``.
        y_test: Test labels.
        gmm: If true, fit ``GaussianMixture`` (``n_init=2``); otherwise fit
            ``KMeans`` (``n_init=20``).
    """
    def build_model(n_components):
        # A named factory instead of a lambda bound to a variable.
        if gmm:
            return GaussianMixture(n_components, n_init=2)
        return KMeans(n_components, n_init=20)

    x_train = x_train[y_train == 0]
    x_test = x_test[y_test == 0]
    print(x_train.shape, x_test.shape)

    ns = range(1, 40)
    ys = []
    for n in ns:
        model = build_model(n)
        model.fit(x_train)
        ys.append(model.score(x_test))
        print(ys[-1])

    print(ys)

    model_name = 'GMM' if gmm else 'KMeans'
    # KMeans has clusters, not gaussians, so the axis label follows the model.
    component_name = 'gaussians' if gmm else 'clusters'

    plt.figure(figsize=(8, 6), dpi=80)
    plt.plot(ns, ys, linewidth=3)
    plt.xlabel(f'N {component_name}')
    plt.ylabel('Score')
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
    plt.title(f"{model_name} score for 0 class of ImageNet")
    plt.show()


def main():
    """Fit NUQ on stored train embeddings and evaluate uncertainty on val/OOD data.

    Reads the ``.npy`` arrays from ``--base-dir``, optionally subsamples and
    standardizes them, fits a ``NuqClassifier``, prints the tuned bandwidth,
    and passes the epistemic uncertainties of the validation and OOD sets to
    ``accuracy_and_ood`` and ``dump_ues``.
    """
    args = parser.parse_args()
    checkpoint_dir = Path(args.base_dir)

    def load_array(file_name):
        # The path is handed to np.load as a plain string.
        return np.load(str(checkpoint_dir / file_name))

    x_train = load_array('train_embeddings.npy')
    y_train = load_array('train_targets.npy')
    x_val = load_array('val_embeddings.npy')
    y_val = load_array('val_targets.npy')
    x_ood = load_array(f'ood_embeddings_{args.ood_name}.npy')

    if args.subsample:
        # Keep every 50th train sample together with its target.
        kept = range(0, len(x_train), 50)
        x_train = x_train[kept]
        y_train = y_train[kept]

    if args.normalize:
        from sklearn.preprocessing import StandardScaler
        # Statistics come from the train set only and are reused for val/OOD.
        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)
        x_val = scaler.transform(x_val)
        x_ood = scaler.transform(x_ood)

    print(x_train.shape, len(x_val), len(x_ood))

    nuq = NuqClassifier(
        tune_bandwidth=args.strategy,
        n_neighbors=args.n_neighbors,
        verbose=True,
        log_pN=args.log_pn,
    )
    nuq.fit(X=x_train, y=y_train)
    # Manual bandwidth alternatives kept for reference:
    # nuq.fit(X=x_train, y=y_train, bandwidth=np.array([40]))
    # nuq.bandwidth_ref_ = ray.put(np.array(0.5))

    # The bandwidth is stored as a ray object reference, so fetch it to print.
    print(ray.get(nuq.bandwidth_ref_))

    ue_type = 'epistemic'
    preds, ues_test = nuq.predict_proba(x_val, return_uncertainty=ue_type)
    _, ues_ood = nuq.predict_proba(x_ood, return_uncertainty=ue_type)

    # preds is converted with .toarray() before it is scored against y_val.
    accuracy_and_ood(ues_test, ues_ood, preds.toarray(), y_val)
    dump_ues(ues_test, ues_ood, f'nuq_{ue_type}', 'imagenet', args.ood_name)

    # Optional plot of the embeddings:
    # visualize(x_train, y_train, x_val, y_val, x_ood, args.ood_name)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
