from argparse import ArgumentParser
import sys
from pathlib import Path

import numpy as np
from nuq import NuqClassifier
import matplotlib
import matplotlib.pyplot as plt
import torch
from torch import nn
from tqdm import tqdm

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import described_plot
from image_uncertainty.uncertainty.metrics import ood_roc_auc
from experiments.imagenet_discrete import dump_ues
from scipy.special import softmax

# Command-line options for the DUQ-on-ImageNet-embeddings experiment.
parser = ArgumentParser()
parser.add_argument(
    '--subsample', action='store_true', default=False,
    # main() keeps every 10th training sample when this flag is set.
    help='set to subsample train embeddings 1 in 10'
)
parser.add_argument(
    '--base-dir', type=str,
    # Directory holding the *_embeddings.npy / *_targets.npy files read by main().
    default='/home/mephody_bro/imagenet_embeddings_full'
)

# NOTE(review): main() does not currently read --ood-name; it always evaluates
# both ImageNet-O and ImageNet-R.
parser.add_argument(
    '--ood-name', type=str,
    default='imagenet_o'
)
# Forwarded to the DUQ head constructor as ``gamma``.
parser.add_argument('--gamma', default=0.9999, type=float)
# Forwarded to the DUQ head constructor as ``sigma``.
parser.add_argument('--length-scale', default=1.0, type=float)
# Selects LinearCentroids ("linear") or MultiLinearCentroids ("multilinear").
parser.add_argument(
    '--architecture', default='linear', choices=["linear", "multilinear"]
)

# Default text style for every matplotlib figure this script produces:
# regular weight at 18pt.
font = dict(weight='normal', size=18)
matplotlib.rc('font', **font)

def calc_gradient_penalty(x, y_pred, *, one_sided=False):
    """Return a penalty pushing the per-sample input-gradient norm towards 1.

    The gradient of ``y_pred`` (all output elements weighted by 1 through an
    all-ones ``grad_outputs``) with respect to ``x`` is flattened per sample
    (dim 0 is treated as the batch) and its L2 norm is penalized.

    Args:
        x: input tensor that ``y_pred`` was computed from; it must have
            ``requires_grad=True`` so autograd can differentiate w.r.t. it.
        y_pred: model output computed from ``x``.
        one_sided: if False (default), use the two-sided penalty
            ``mean((||g|| - 1) ** 2)``. If True, only penalize norms above 1
            with ``mean(relu(||g|| - 1))``.

    Returns:
        Scalar tensor. ``create_graph=True`` keeps it differentiable, so it
        can be added to a training loss and back-propagated.
    """
    gradients = torch.autograd.grad(
        outputs=y_pred,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient.
    grad_norm = gradients.flatten(start_dim=1).norm(2, dim=1)

    if one_sided:
        # One-sided: only gradients steeper than 1 are penalized.
        return torch.relu(grad_norm - 1).mean()

    # Two-sided: deviations in either direction from norm 1 are penalized.
    return ((grad_norm - 1) ** 2).mean()


def _make_loader(x, targets, batch_size, shuffle=False, drop_last=False):
    """Build a DataLoader over ``(float32 features, targets)`` pairs.

    Args:
        x: numpy array of features; row ``i`` is paired with ``targets[i]``.
        targets: tensor with the same first-dimension length as ``x``.
        batch_size: loader batch size.
        shuffle: reshuffle every epoch (used for the training set only).
        drop_last: drop the final incomplete batch.
    """
    dataset = torch.utils.data.TensorDataset(torch.from_numpy(x).float(), targets)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )


def _benchmark(dl_test, model, l_gradient_penalty, epoch=0, loss=0.0):
    """Print ``epoch: accuracy, bce, gradient penalty, train loss`` on one batch.

    Only the first batch of ``dl_test`` is evaluated. With an unshuffled loader
    this is the same batch on every call, so the numbers are comparable across
    calls. The model is left in eval mode; callers that keep training must call
    ``model.train()`` afterwards.

    NOTE(review): ``binary_cross_entropy`` assumes the model outputs values in
    [0, 1] with the same shape as the one-hot targets — confirm for the head.
    """
    x, y = next(iter(dl_test))
    x = x.cuda()
    y = y.cuda()
    use_gp = l_gradient_penalty != 0

    model.eval()
    # Autograd is only needed here when the gradient penalty must be evaluated.
    with torch.set_grad_enabled(use_gp):
        if use_gp:
            x.requires_grad_(True)
        y_pred = model(x)
        gp = (
            (l_gradient_penalty * calc_gradient_penalty(x, y_pred)).item()
            if use_gp else 0.0
        )
        bce = nn.functional.binary_cross_entropy(y_pred, y).item()

    # Targets are one-hot, so their argmax is the class label.
    accuracy = (y.argmax(dim=-1) == y_pred.argmax(dim=-1)).float().mean().item()
    print(f"{epoch}: {accuracy:.3f}, {bce:.3f}, {gp:.3f}, {loss:.3f}")


def _get_ues(dl, model):
    """Return one uncertainty score per sample in ``dl``: minus the max model output.

    A lower best-class output gives a higher score. The loader's targets are
    ignored. Inference runs in eval mode under ``torch.no_grad()``, so no
    autograd graph is built.
    """
    model.eval()
    ues = []
    with torch.no_grad():
        for x, _ in tqdm(dl):
            output = model(x.cuda())
            ues.extend((-output.max(dim=1).values).cpu().tolist())
    return np.array(ues)


def main():
    """Train a DUQ centroid head on precomputed ImageNet embeddings and report OOD ROC AUC.

    Reads train/val embeddings and targets plus ImageNet-O / ImageNet-R OOD
    embeddings from ``--base-dir``. Trains the head selected by
    ``--architecture`` for ``epochs`` epochs, printing a one-batch validation
    benchmark every 100 steps and the OOD ROC AUC for both OOD sets after every
    epoch. Finally saves the head's ``state_dict`` under
    ``experiments/checkpoint/duq/``.
    """
    import torch.nn.functional as F
    from image_uncertainty.models.duq import LinearCentroids, MultiLinearCentroids

    args = parser.parse_args()
    base_dir = Path(args.base_dir)

    def load(name):
        return np.load(str(base_dir / name))

    x_train = load('train_embeddings.npy')
    y_train = load('train_targets.npy')
    x_val = load('val_embeddings.npy')
    y_val = load('val_targets.npy')
    x_ood_o = load('ood_embeddings_imagenet_o.npy')
    x_ood_r = load('ood_embeddings_imagenet_r.npy')

    subsample_step = 10
    if args.subsample:
        # Copy so the full-size arrays can be freed (a plain slice would be a view).
        x_train = x_train[::subsample_step].copy()
        y_train = y_train[::subsample_step].copy()

    print(x_train.shape, len(x_val), len(x_ood_o))

    batch_size = 500
    eval_batch_size = 500
    l_gradient_penalty = 0.0

    gamma = args.gamma
    model_output_size = 2048
    centroid_size = 64
    length_scale = args.length_scale
    epochs = 5
    milestones = [1, 2, 3]
    learning_rate = 1e-2
    weight_decay = 1e-4
    num_classes = 1000

    def one_hot(labels):
        # Fix the width to num_classes: letting one_hot infer it from the largest
        # label would produce too few columns whenever a split lacks the top class.
        return F.one_hot(torch.from_numpy(labels).long(), num_classes=num_classes).float()

    dl_train = _make_loader(
        x_train, one_hot(y_train), batch_size, shuffle=True, drop_last=True
    )
    dl_test = _make_loader(x_val, one_hot(y_val), eval_batch_size)
    # OOD targets are never read (only features are scored); a dummy column
    # keeps the (x, y) batch structure of the other loaders.
    dl_ood_o = _make_loader(x_ood_o, torch.zeros(len(x_ood_o), 1), eval_batch_size)
    dl_ood_r = _make_loader(x_ood_r, torch.zeros(len(x_ood_r), 1), eval_batch_size)

    klass = LinearCentroids if args.architecture == 'linear' else MultiLinearCentroids
    head = klass(
        num_classes=num_classes,
        gamma=gamma,
        embedding_size=model_output_size,
        features=centroid_size,
        feature_extractor=nn.Identity(),  # inputs are already embeddings
        batch_size=batch_size,
        sigma=length_scale
    ).cuda()

    head_optimizer = torch.optim.SGD(
        head.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        head_optimizer, milestones=milestones, gamma=0.2
    )

    use_gp = l_gradient_penalty != 0
    for epoch in range(epochs):
        head.train()
        for i, (x, y) in enumerate(tqdm(dl_train)):
            x = x.cuda()
            y = y.cuda()
            if use_gp:
                # Input gradients are only needed for the gradient penalty.
                x.requires_grad_(True)
            head_optimizer.zero_grad()

            y_pred = head(x)

            loss = F.binary_cross_entropy(y_pred, y)
            if use_gp:
                loss = loss + l_gradient_penalty * calc_gradient_penalty(x, y_pred)

            loss.backward()
            head_optimizer.step()

            with torch.no_grad():
                head.update_embeddings(x, y)

            if (i + 1) % 100 == 0:
                _benchmark(dl_test, head, l_gradient_penalty, epoch, loss.item())
                head.train()
        scheduler.step()

        ues_test = _get_ues(dl_test, head)
        for suffix, dl_ood in (('o', dl_ood_o), ('r', dl_ood_r)):
            score = ood_roc_auc(ues_test, _get_ues(dl_ood, head))
            print(f'OOD ROC AUC {suffix}', score)

    checkpoint = (
        Path('experiments/checkpoint/duq') / f'imagenet_{args.architecture}_head.pth'
    )
    # Create the directory up front so a finished training run cannot fail on save.
    checkpoint.parent.mkdir(parents=True, exist_ok=True)
    torch.save(head.state_dict(), checkpoint)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()