#%%
import sys
import random

import numpy as np
from matplotlib import pyplot as plt
from torchvision.datasets import MNIST, SVHN
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import torch
from torch import nn
from torch.nn.utils import spectral_norm

sys.path.append('..')
# from nw_uncertainty.method.nw_method import  NewNW
from nuq.nuq_classifier import NuqClassifier
#%%
SEED = 1

# Seed the Python, NumPy, torch CPU and torch CUDA (current + all devices) RNGs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# No augmentation: both splits are only converted to tensors.
train_transforms = transforms.Compose([transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.ToTensor()])

_mnist_root = '../checkpoint/data'
mnist_train = MNIST(_mnist_root, download=True, train=True, transform=train_transforms)
mnist_test = MNIST(_mnist_root, download=True, train=False, transform=test_transforms)

# Shuffled mini-batches for training; fixed order for evaluation.
train_loader = DataLoader(mnist_train, batch_size=512, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=512)

# Default for SimpleConv(use_sn=...): wrap every conv/linear layer of the
# feature extractor in `spectral_norm`.
sn = True


class SimpleConv(nn.Module):
    """Small CNN with a linear classification head.

    Assumes single-channel 28x28 inputs: the first conv takes 1 channel and the
    spatial size must shrink to 4x4 so the flattened map has 32 * 4 * 4 = 512
    features for the Linear layer.

    `forward` caches a detached copy of the penultimate embedding (shape
    `(batch, width)`) in `self.feature` and returns the logits of `self.linear`.

    Args:
        use_sn: wrap every conv/linear layer of `self.layers` in `spectral_norm`.
            Defaults to the module-level `sn` flag as it is when the class is
            defined, matching the old behaviour of picking the class variant at
            definition time.
        width: embedding dimension produced by `self.layers`.
        num_classes: number of output logits.
    """

    def __init__(self, use_sn=sn, width=32, num_classes=10):
        super().__init__()
        # Identity wrapper when spectral normalisation is disabled; layer order
        # (and therefore parameter init order / state_dict keys) is unchanged.
        wrap = spectral_norm if use_sn else (lambda module: module)
        self.layers = nn.Sequential(
            wrap(nn.Conv2d(1, 16, 3, padding=1, bias=False)),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),  # 14x14

            wrap(nn.Conv2d(16, 32, 3, padding=1, bias=False)),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),  # 7x7

            wrap(nn.Conv2d(32, 32, 3, padding=1, bias=False)),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.AvgPool2d(2, padding=1),  # 4x4

            nn.Flatten(),
            wrap(nn.Linear(32 * 4 * 4, width, bias=False)),  # 512 inputs
            nn.BatchNorm1d(width),
            nn.LeakyReLU(),
        )

        self.feature = None
        self.linear = nn.Linear(width, num_classes)

    def forward(self, x):
        out = self.layers(x)
        # detach() before clone() so the copy is not recorded in the autograd graph.
        self.feature = out.detach().clone()
        return self.linear(out)



def get_model(epochs=10, lr=3e-4, device='cuda'):
    """Train a fresh `SimpleConv` and report its test accuracy.

    Reads the module-level `train_loader` and `test_loader` at call time. Prints
    the mean cross-entropy loss after every epoch and the test accuracy at the end.

    Args:
        epochs: number of passes over `train_loader`.
        lr: Adam learning rate.
        device: device the model and batches are moved to.

    Returns:
        The trained model, left in eval mode.
    """
    model = SimpleConv().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    for _ in range(epochs):
        epoch_losses = []
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x_batch), y_batch)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        print(np.mean(epoch_losses))

    # Evaluation: BatchNorm switches to running statistics; no gradients needed.
    model.eval()
    correct = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            preds = torch.argmax(model(x_batch.to(device)).cpu(), dim=-1)
            correct.extend((preds == y_batch).tolist())
    print('Accuracy', np.mean(correct))
    return model


# Primary classifier: its logits and penultimate-layer features (`model.feature`)
# feed the MaxProb / entropy / NUQ / DDU scores computed further down.
model = get_model()


class Ensemble:
    """Combine the predictions of several independently trained models.

    Args:
        models: callables returning logits with classes on the last dimension.
        device: device inputs are moved to before calling each model; outputs
            are moved back to the CPU.
    """

    def __init__(self, models, device='cuda'):
        self.models = models
        self.device = device

    def _stacked_logits(self, x):
        """Return member logits stacked on a new leading axis: (members, *output_shape)."""
        x = x.to(self.device)
        return torch.stack([m(x).cpu() for m in self.models])

    def __call__(self, x):
        """Return the mean of the members' logits (not of their probabilities)."""
        with torch.no_grad():
            return torch.mean(self._stacked_logits(x), dim=0)

    def ue(self, x):
        """Return the per-sample entropy (nats) of the averaged softmax distribution.

        `torch.special.entr` defines 0 * log(0) as 0, so classes whose averaged
        probability underflows to exactly 0 no longer turn the result into NaN.
        """
        with torch.no_grad():
            probs = torch.mean(torch.softmax(self._stacked_logits(x), dim=-1), dim=0)
            return torch.special.entr(probs).sum(dim=-1)


# Deep-ensemble baseline: ten more SimpleConv models, each trained from scratch by get_model().
ensemble = Ensemble([get_model() for _member in range(10)])
sample_batch, _sample_labels = next(iter(train_loader))
print(ensemble(sample_batch))

#%%
# Corrupted in-distribution set: MNIST test images randomly rotated by an angle
# drawn from the (30, 45) degree range.
rotation = (30, 45)
corrupted_transforms = transforms.Compose([
    transforms.RandomRotation(rotation),
    transforms.ToTensor()
])

mnist_corrupted = MNIST('../checkpoint/data', download=True, train=False, transform=corrupted_transforms)
corrupted_loader = DataLoader(mnist_corrupted, batch_size=10_000)
images, labels = next(iter(corrupted_loader))


with torch.no_grad():
    preds = torch.argmax(model(images.cuda()), dim=-1).cpu()
# A bare expression in the middle of a cell is never displayed (nor in script
# mode), so print the rotated-MNIST accuracy explicitly.
print('Rotated MNIST accuracy', np.mean((preds == labels).numpy()))
# Out-of-distribution set: SVHN test split converted to grayscale and resized to
# 28x28 so it matches the single-channel input SimpleConv expects.
ood_transforms = transforms.Compose(
    [transforms.Grayscale(), transforms.Resize((28, 28)), transforms.ToTensor()]
)

# Alternative OOD set (needs CIFAR10 imported and a `cifar_transforms` pipeline):
# cifar = CIFAR10(
#     '../checkpoint/data', download=True, train=False, transform=cifar_transforms
# )
ood_dataset = SVHN('./data', split='test', download=True, transform=ood_transforms)
ood_loader = DataLoader(ood_dataset, batch_size=10_000)
# Only the first, unshuffled batch (at most 10_000 images) is used downstream.
images_ood, labels_ood = next(iter(ood_loader))

def entropy(x):
    """Return the Shannon entropy (nats) of softmax(x) along the last dimension.

    Uses `log_softmax` instead of `log(softmax(x))`: a probability that underflows
    to 0 then contributes 0 * (finite log-prob) = 0 rather than 0 * -inf = NaN.
    """
    log_p = torch.log_softmax(x, dim=-1)
    return -torch.sum(torch.exp(log_p) * log_p, dim=-1)

# Logits of the primary model on the rotated-MNIST (`images`) and SVHN (`images_ood`) sets.
with torch.no_grad():
    preds, preds_ood = [model(batch.cuda()).cpu() for batch in (images, images_ood)]

# Logits and labels for the whole training set, collected batch-by-batch in the
# same pass so the two stay aligned despite train_loader's shuffling.
train_pred_chunks, train_label_chunks = [], []
with torch.no_grad():
    for x_batch, y_batch in train_loader:
        train_pred_chunks.append(model(x_batch.cuda()).cpu())
        train_label_chunks.append(y_batch)
# Concatenate once: growing a tensor with torch.cat inside the loop is quadratic.
preds_train = torch.cat(train_pred_chunks, dim=0).numpy()
y_train = torch.cat(train_label_chunks).numpy()


# Baseline per-sample uncertainty scores; rotated-MNIST entries first, SVHN entries second.
max_probs = [torch.softmax(logits, dim=-1).max(dim=-1).values for logits in (preds, preds_ood)]
ue_mnist, ue_ood = (1 - p for p in max_probs)
ue = torch.cat((ue_mnist, ue_ood)).numpy()                               # 1 - max softmax probability
ue_entropy = torch.cat((entropy(preds), entropy(preds_ood)))              # predictive entropy
ue_ensemble = torch.cat((ensemble.ue(images), ensemble.ue(images_ood)))  # ensemble entropy

#%%

# A single shuffled batch of up to 60_000 training images with aligned labels.
# Uses its own loader so the mini-batch `train_loader` read by get_model() is
# not replaced by a 60k-batch loader (which would break re-running earlier cells).
full_train_loader = DataLoader(mnist_train, batch_size=60_000, shuffle=True)
print(full_train_loader.batch_size)


def _embed(inputs, batch_size=4096):
    """Return `model.feature` (penultimate SimpleConv features) for `inputs` as a NumPy array.

    Runs the model chunk by chunk so the full training set never has to sit on
    the GPU at once. `model` is in eval mode (get_model() ends with model.eval()),
    so BatchNorm uses running statistics and chunking does not change the result.
    """
    chunks = []
    with torch.no_grad():
        for chunk in torch.split(inputs, batch_size):
            model(chunk.cuda())
            chunks.append(model.feature.cpu())
    return torch.cat(chunks).numpy()


train_images, y_train = next(iter(full_train_loader))
y_train = y_train.numpy()
train_embeddings = _embed(train_images)

embeddings = _embed(images)
print(embeddings[:5, :3])

embeddings_ood = _embed(images_ood)
print(embeddings_ood[:5, :3])

# nuq = NuqClassifier(strategy="isj", tune_bandwidth=True, n_neighbors=20)
# nuq.fit(X=train_embeddings, y=y_train)
# print('Fitted bandwidth', nuq.bandwidth)
#
# ue_nuq = np.concatenate((
#     nuq.predict_uncertainty(embeddings)['total'],
#     nuq.predict_uncertainty(embeddings_ood)['total'],
# ))


EMBEDDINGS = True  # True: fit NUQ on penultimate embeddings; False: fit on logits
n_neighbors = 500

if EMBEDDINGS:
    nuq_train_X, nuq_test_X, nuq_ood_X = train_embeddings, embeddings, embeddings_ood
else:
    # `preds_train` came from an earlier, independently shuffled pass over the
    # 512-batch train_loader, but `y_train` now holds the labels of
    # `train_images`. Recompute the training logits on `train_images` so the
    # features and labels describe the same samples in the same order.
    with torch.no_grad():
        nuq_train_X = torch.cat([
            model(chunk.cuda()).cpu() for chunk in torch.split(train_images, 4096)
        ]).numpy()
    nuq_test_X, nuq_ood_X = preds.numpy(), preds_ood.numpy()

# bandwidth_rand = np.random.exponential(1, preds_train.shape[0])[:, None]
nuq = NuqClassifier(tune_bandwidth='classification', n_neighbors=n_neighbors)
nuq.fit(X=nuq_train_X, y=y_train)

# NOTE(review): index [1] is assumed to be the requested (epistemic) uncertainty
# returned alongside the probabilities — confirm against NuqClassifier.predict_proba.
ue_nuq = np.concatenate((
    nuq.predict_proba(nuq_test_X, return_uncertainty='epistemic')[1],
    nuq.predict_proba(nuq_ood_X, return_uncertainty='epistemic')[1],
))
print(ue_nuq)


#%%
from image_uncertainty.spectral_normalized_models import (
    gmm_fit, logsumexp
)

# DDU baseline: fit a Gaussian density model on the training embeddings
# (presumably class-conditional, num_classes=10 — see gmm_fit).
gaussians_model, jitter_eps = gmm_fit(
    embeddings=torch.tensor(train_embeddings), labels=torch.tensor(y_train), num_classes=10
)


def _ddu_uncertainty(features):
    """Negative logsumexp of the density model's log-probs for each feature row.

    NOTE(review): assumes `logsumexp` reduces over the per-class axis so this is
    -log p(z) up to class weighting — confirm against its definition.
    """
    log_probs = gaussians_model.log_prob(torch.tensor(features)[:, None, :].float())
    return -logsumexp(log_probs).numpy().flatten()


ues_test_ddu = _ddu_uncertainty(embeddings)
ues_ood_ddu = _ddu_uncertainty(embeddings_ood)
ue_ddu = np.concatenate((ues_test_ddu, ues_ood_ddu))

# Ground-truth OOD indicator aligned with the concatenated scores:
# 0 for the rotated-MNIST samples (`images`), 1 for the SVHN samples (`images_ood`).
# Sizes come from the data rather than being hard-coded to 10_000 each.
ue_labels = np.concatenate((np.zeros(len(images)), np.ones(len(images_ood))))
# Cutoffs (number of least-uncertain samples kept), every 200 samples up to all of them.
xs = np.arange(0, len(ue_labels) + 1, 200)

def fractions(uncertainty, labels=None, cutoffs=None):
    """Count OOD samples among the k least-uncertain samples, for every k in `cutoffs`.

    Samples are sorted by ascending `uncertainty`; for each cutoff k the result is
    the sum of `labels` (1 = OOD) over the first k sorted samples.

    Args:
        uncertainty: 1-D NumPy array or CPU tensor of scores, one per sample.
        labels: 0/1 OOD indicator aligned with `uncertainty`; defaults to the
            module-level `ue_labels`.
        cutoffs: non-negative sample counts to evaluate; defaults to the
            module-level `xs`. Values above the sample count are clipped to it.

    Returns:
        list with one count per cutoff.
    """
    labels = ue_labels if labels is None else np.asarray(labels)
    cutoffs = xs if cutoffs is None else np.asarray(cutoffs)
    order = np.argsort(np.asarray(uncertainty))
    # Prefix sums: running[k] == labels[order][:k].sum(), computed in one O(n) pass
    # instead of re-summing a growing slice for every cutoff.
    running = np.concatenate(([0], np.cumsum(labels[order])))
    return running[np.minimum(cutoffs, len(labels))].tolist()

#%%
# Cumulative OOD counts per score; random scores give the chance baseline and
# the ground-truth labels themselves give the optimal ordering. The random
# baseline is sized from ue_labels rather than a hard-coded 20_000.
randomed_sums = fractions(np.random.random(ue_labels.shape))
maxprobed_sums = fractions(ue)
entropy_sums = fractions(ue_entropy)
nwed_sums = fractions(ue_nuq)
ddu_sums = fractions(ue_ddu)
ensemble_sums = fractions(ue_ensemble)
optimal_sums = fractions(ue_labels)


font = {
    'weight': 'normal',
    'size': 18
}

import matplotlib
matplotlib.rc('font', **font)

linewidth = 4
plt.figure(figsize=(9, 8), dpi=150)
plt.subplots_adjust(left=0.21, bottom=0.13, right=0.93)
plt.title('Rotated MNIST vs grayscale SVHN')
plt.ylabel('SVHN objects included')
plt.xlabel('Total objects included')
# Baselines are drawn faintly; each method gets its own colour and line style.
plt.plot(xs, randomed_sums, label='Random', alpha=0.3, linewidth=linewidth)
plt.plot(xs, optimal_sums, label='Optimal', alpha=0.3, linewidth=linewidth)
method_curves = [
    (maxprobed_sums, 'MaxProb', '--', 'tab:green'),
    (entropy_sums, 'Entropy', ':', 'tab:red'),
    (ensemble_sums, 'Ensemble', '-.', 'tab:orange'),
    (ddu_sums, 'DDU', '-.', 'tab:cyan'),
    (nwed_sums, 'NUQ', '-.', 'tab:purple'),
]
for sums, label, style, color in method_curves:
    plt.plot(xs, sums, label=label, linestyle=style, color=color, linewidth=linewidth)
plt.legend(loc='upper left')
plt.show()
print('done')


#%%
# from sklearn.neighbors import KNeighborsClassifier
# from sklearn.metrics import accuracy_score
# kn_model = KNeighborsClassifier()
# kn_model.fit(train_embeddings, y_train)

# accuracy_score(kn_m)
#%%
# accuracy_score(labels, kn_model.predict(embeddings))

#%%
# predictions = torch.argmax(preds, dim=-1).numpy()
# imgs = images.numpy().reshape(-1, 28, 28)
# import matplotlib.pyplot as plt
# from mpl_toolkits.axes_grid1 import ImageGrid
# import numpy as np
#
# fig = plt.figure(figsize=(12., 3.5))
# fig.set_tight_layout({"pad": 2})
#
# grid = ImageGrid(fig, 111,  # similar to subplot(111)
#                  nrows_ncols=(1, 4),  # creates 1x4 grid of axes
#                  axes_pad=0.5,  # pad between axes in inch.
#                  )
#
# for i, ax in enumerate(grid):
#     # Iterating over the grid returns the Axes.
#     ax.axis('off')
#     ax.imshow(imgs[np.random.randint(len(imgs))], cmap='gray')
#
# plt.show()

#%%

#%%
# np.mean()
# corrects = (torch.argmax(preds, dim=-1) == labels).numpy()
#
# xs = np.arange(0, 10001, 100)
#
# def splits(ues):
#     idxs = np.argsort(ues)
#     sorted_corrects = corrects[idxs]
#     ys = [1] + [np.mean(sorted_corrects[:num]) for num in xs[1:]]
#     return ys
#
# plt.figure(figsize=(6, 5), dpi=150)
# plt.subplots_adjust(left=0.15, bottom=0.13, right=0.95)
# plt.title('Accuracy, MNIST rotated')
# plt.ylabel("Accuracy")
# plt.xlabel("Samples selected")
# plt.plot(xs, splits(ue_mnist), label='MaxProb', linestyle='--', color='tab:green', linewidth=linewidth)
# plt.plot(xs, splits(ue_entropy[:10000].numpy()), label='Entropy', linestyle=':', color='tab:red', linewidth=linewidth)
# plt.plot(xs, splits(ues_test_ddu), label='DDU', linestyle='-', color='tab:cyan', linewidth=linewidth)
# plt.plot(xs, splits(ue_nuq[:10000]), label='NUQ', linestyle='-.', color='tab:purple', linewidth=linewidth)
# plt.legend()
# plt.show()
#
# #%%
# np.mean((torch.argmax(preds_ood, dim=-1) == labels_ood).numpy())
#
# def panel(imgs, num=4):
#     fig = plt.figure(figsize=(num*3.0, 3.5))
#     fig.set_tight_layout({"pad": 2})
#
#     grid = ImageGrid(fig, 111,  # similar to subplot(111)
#                      nrows_ncols=(1, num),  # creates 1 x num grid of axes
#                      axes_pad=0.5,  # pad between axes in inch.
#                      )
#
#
#     for i, ax in enumerate(grid):
#         # Iterating over the grid returns the Axes.
#         ax.axis('off')
#         ax.imshow(imgs[np.random.randint(len(imgs))], cmap='gray')
#
#     plt.show()
#
# #%%
# from numpy.random import random
# from skimage.filters import gaussian
# from image_uncertainty.datasets.smooth_random import SmoothRandom
#
# num = 10
# image_size = (32, 64, 3)
# noise_images = random((num, *image_size))
# radiuses = 1.5 * random(num) + 1
# smoothed = [gaussian(img, r, multichannel=3) for img, r in zip(noise_images, radiuses)]
# panel(smoothed)
