#%%
import gc
from pathlib import Path
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

from tqdm import tqdm

import sys
sys.path.append('.')
from experiments.imagenet_discrete import test_loaders, load_model, parse_args, show
from experiments.imagenet_ddu import get_train_loader


# Global compute device: prefer the GPU when CUDA is available, else fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def embeddings(loader, model, x_type):
    """Run ``model`` over every batch of ``loader`` and collect features and labels.

    The forward pass is used for its side effect: after ``model(images)`` the
    code reads ``model.feature``. Presumably the model caches its feature
    tensor there during forward; TODO confirm against ``load_model``.

    Args:
        loader: iterable of ``(images, target)`` batches, where both items are
            tensors (e.g. a ``torch.utils.data.DataLoader``).
        model: network whose forward pass sets a ``feature`` tensor attribute.
        x_type: unused here; kept for interface compatibility (``main`` only
            uses it to name the output files).

    Returns:
        Tuple ``(features, targets)``. ``features`` is the per-batch
        ``model.feature`` arrays concatenated along axis 0. ``targets`` is the
        per-batch label arrays concatenated along axis 0.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Switch to evaluation mode (changes dropout / batch-norm behaviour).
    model.eval()
    feature_batches = []
    target_batches = []

    with torch.no_grad():
        # Wrap the loader itself (not ``enumerate(loader)``) so tqdm can query
        # its length and show a progress bar with a total and ETA.
        for images, target in tqdm(loader):
            model(images.to(device))
            feature_batches.append(model.feature.cpu().numpy())
            target_batches.append(target.cpu().numpy())

    if not feature_batches:
        # np.concatenate would raise an opaque ValueError on an empty list.
        raise ValueError("loader yielded no batches; nothing to embed")
    return np.concatenate(feature_batches), np.concatenate(target_batches)


def main():
    """Extract embeddings for the train, validation and OOD splits and save them.

    Arrays are written as ``.npy`` files under ``checkpoint/``:
    ``train_<x_type>``, ``train_targets``, ``val_<x_type>``, ``val_targets``
    and ``ood_<x_type>_<ood_name>``.
    """
    args = parse_args()
    val_loader, ood_loader = test_loaders(args.data_folder, args.ood_folder, args.b)
    train_loader = get_train_loader(batch_size=args.b)

    # NOTE(review): hard-coded; this overrides any dir_to_save that parse_args
    # may have set. Confirm this is intended.
    args.dir_to_save = Path('checkpoint')
    # open(..., 'wb') below raises FileNotFoundError if the directory is missing.
    args.dir_to_save.mkdir(parents=True, exist_ok=True)

    model = load_model(args.net)

    def save(arr, name):
        """Write ``arr`` to ``<dir_to_save>/<name>.npy`` in NumPy binary format."""
        with open(args.dir_to_save / f"{name}.npy", 'wb') as f:
            np.save(f, arr)

    # Save and release each split before computing the next one, so only
    # one split's arrays are held in memory at a time.
    train_embeddings, train_targets = embeddings(train_loader, model, args.x_type)
    save(train_embeddings, f'train_{args.x_type}')
    save(train_targets, 'train_targets')
    del train_embeddings, train_targets
    gc.collect()

    val_embeddings, val_targets = embeddings(val_loader, model, args.x_type)
    save(val_embeddings, f'val_{args.x_type}')
    save(val_targets, 'val_targets')
    del val_embeddings, val_targets
    gc.collect()

    # Only OOD features are saved; the OOD targets are discarded.
    ood_embeddings, _ = embeddings(ood_loader, model, args.x_type)
    save(ood_embeddings, f'ood_{args.x_type}_{args.ood_name}')
    del ood_embeddings
    gc.collect()



# Run the extraction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
