#
# Helper functions that can be used by all scripts.
#
# Most notable helpers:
# - Load datasets.
# - Subsample a dataset.
# - Count the number of parameters in a model.
#
# Hyperspherical Prototype Networks
#

import os
import numpy as np
from   six.moves import cPickle as pickle
import torch
import torch.optim as optim
import torch.utils.data as data
from   torchvision import datasets, transforms

################################################################################
# General helpers.
################################################################################

#
# Count the number of learnable parameters in a model.
#
def count_parameters(model):
    """Return the total element count of all trainable parameters.

    Parameters whose ``requires_grad`` flag is False (frozen weights) are
    skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

#
# Get the desired optimizer.
#
def get_optimizer(optimname, params, learning_rate, momentum, decay):
    """Construct a ``torch.optim`` optimizer by name.

    Args:
        optimname: one of "sgd", "adadelta", "adam", "rmsprop".
        params: iterable of parameters (or parameter groups) to optimize.
        learning_rate: value passed as ``lr``.
        momentum: momentum factor; only used by SGD and RMSprop.
        decay: value passed as ``weight_decay``.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``optimname`` is not one of the supported names.
    """
    if optimname == "sgd":
        return optim.SGD(params, lr=learning_rate, momentum=momentum,
                         weight_decay=decay)
    if optimname == "adadelta":
        # Adadelta takes no momentum argument.
        return optim.Adadelta(params, lr=learning_rate, weight_decay=decay)
    if optimname == "adam":
        # Adam takes no momentum argument.
        return optim.Adam(params, lr=learning_rate, weight_decay=decay)
    if optimname == "rmsprop":
        return optim.RMSprop(params, lr=learning_rate, weight_decay=decay,
                             momentum=momentum)
    # Fail with a descriptive error instead of an UnboundLocalError.
    raise ValueError("Unknown optimizer %r; expected one of 'sgd', "
                     "'adadelta', 'adam', 'rmsprop'" % (optimname,))

#
# Compute the classification accuracy of each class separately.
#
def per_class_accuracy(ty, tp, nr_classes):
    """Return the accuracy of each class as a float array.

    Args:
        ty: true labels, non-negative integers (array-like).
        tp: predicted labels, same length as ``ty``.
        nr_classes: minimum number of classes to report.

    Returns:
        Float array whose entry ``c`` is (#correct for class c) /
        (#examples of class c). Its length is ``nr_classes``, or more if a
        label in ``ty`` is >= ``nr_classes``. Classes with no examples in
        ``ty`` get NaN (0/0).
    """
    ty = np.asarray(ty)
    tp = np.asarray(tp)
    # Both counts must have the same length; labels beyond nr_classes
    # would otherwise make the two bincounts differ in size.
    nbins = nr_classes
    if ty.size:
        nbins = max(nr_classes, int(ty.max()) + 1)
    correct = np.bincount(ty[ty == tp], minlength=nbins).astype(float)
    total = np.bincount(ty, minlength=nbins).astype(float)
    # Absent classes yield 0/0 = NaN by design; silence the warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        return correct / total

#
# Select a specific number of examples per class.
#
# x (torch Tensor) - Features.
# y (torch Tensor) - Labels.
# nr_ex (int)      - Number of examples per class.
#
def subselect(x, y, nr_ex):
    """Randomly keep ``nr_ex`` examples of every class present in ``y``.

    Classes are visited in ascending label order (``np.unique`` sorts). For
    each class, ``nr_ex`` distinct indices are drawn without replacement from
    NumPy's global RNG, so ``np.random.choice`` raises if a class has fewer
    than ``nr_ex`` examples.

    Returns the ``(x, y)`` pair restricted to the chosen indices, grouped by
    class.
    """
    labels = y.numpy()
    picked = [
        np.random.choice(np.nonzero(labels == label)[0], nr_ex, replace=False)
        for label in np.unique(labels)
    ]
    keep = torch.from_numpy(np.concatenate(picked))
    return x[keep], y[keep]

################################################################################
# Standard dataset loaders.
################################################################################

#
# Load the MNIST dataset.
#
def load_mnist(batch_size, kwargs, basedir='../data/'):
    """Build MNIST train and test DataLoaders.

    Args:
        batch_size: number of examples per batch for both loaders.
        kwargs: dict of extra keyword arguments forwarded to both
            DataLoaders (e.g. ``num_workers``, ``pin_memory``).
        basedir: data root; MNIST is read from
            ``os.path.join(basedir, 'mnist/')``. Defaults to ``'../data/'``,
            i.e. ``'../data/mnist/'``.

    Returns:
        (trainloader, testloader). Only the training loader shuffles, so
        evaluation order is deterministic.
    """
    root = os.path.join(basedir, 'mnist/')
    # Both splits share the same preprocessing: tensor conversion, then
    # normalization with (0.1307, 0.3081), the commonly used MNIST mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    trainloader = torch.utils.data.DataLoader(
        datasets.MNIST(root, train=True, transform=transform),
        batch_size=batch_size, shuffle=True, **kwargs)
    testloader = torch.utils.data.DataLoader(
        datasets.MNIST(root, train=False, transform=transform),
        batch_size=batch_size, shuffle=False, **kwargs)
    return trainloader, testloader

#
# Load the CIFAR 10 dataset.
#
def load_cifar10(basedir, batch_size, kwargs):
    """Build CIFAR-10 train and test DataLoaders.

    Data is read from ``os.path.join(basedir, 'cifar10/')``. Training images
    get a random 32x32 crop with 4 pixels of padding and a random horizontal
    flip. Both splits are normalized per channel with the constants below.

    Args:
        basedir: root directory containing the ``cifar10/`` folder.
        batch_size: number of examples per batch for both loaders.
        kwargs: dict of extra keyword arguments forwarded to both
            DataLoaders (e.g. ``num_workers``, ``pin_memory``).

    Returns:
        (trainloader, testloader). Only the training loader shuffles, so
        evaluation order is deterministic. The labels are also exposed as
        LongTensors on ``trainloader.dataset.train_labels`` and
        ``testloader.dataset.test_labels``.
    """
    root = os.path.join(basedir, 'cifar10/')
    # Input channels normalization.
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])

    # Train data.
    train_data = datasets.CIFAR10(root=root, train=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, 4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
    trainloader = torch.utils.data.DataLoader(train_data,
            batch_size=batch_size, shuffle=True, **kwargs)
    # Older torchvision releases store labels in `train_labels`; newer ones
    # use `targets`. Read whichever exists and expose a LongTensor under
    # `train_labels`.
    train_labels = getattr(train_data, 'train_labels', None)
    if train_labels is None:
        train_labels = train_data.targets
    train_data.train_labels = torch.from_numpy(np.array(train_labels))

    # Test data.
    test_data = datasets.CIFAR10(root=root, train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]))
    testloader = torch.utils.data.DataLoader(test_data,
            batch_size=batch_size, shuffle=False, **kwargs)
    # Same compatibility handling as for `train_labels` above.
    test_labels = getattr(test_data, 'test_labels', None)
    if test_labels is None:
        test_labels = test_data.targets
    test_data.test_labels = torch.from_numpy(np.array(test_labels))

    return trainloader, testloader

#
# Load the CIFAR 100 dataset.
#
def load_cifar100(basedir, batch_size, kwargs):
    """Build CIFAR-100 train and test DataLoaders.

    Data is read from ``os.path.join(basedir, 'cifar100/')`` and is not
    downloaded. Training images get a random 32x32 crop with 4 pixels of
    padding and a random horizontal flip. Both splits are normalized per
    channel with the constants below.

    Args:
        basedir: root directory containing the ``cifar100/`` folder.
        batch_size: number of examples per batch for both loaders.
        kwargs: dict of extra keyword arguments forwarded to both
            DataLoaders (e.g. ``num_workers``, ``pin_memory``).

    Returns:
        (trainloader, testloader). Only the training loader shuffles, so
        evaluation order is deterministic. The labels are also exposed as
        LongTensors on ``trainloader.dataset.train_labels`` and
        ``testloader.dataset.test_labels``.
    """
    root = os.path.join(basedir, 'cifar100/')
    # Input channels normalization.
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                     std=[0.267, 0.256, 0.276])

    # Train data.
    train_data = datasets.CIFAR100(root=root, train=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, 4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]), download=False)
    trainloader = torch.utils.data.DataLoader(train_data,
            batch_size=batch_size, shuffle=True, **kwargs)
    # Older torchvision releases store labels in `train_labels`; newer ones
    # use `targets`. Read whichever exists and expose a LongTensor under
    # `train_labels`.
    train_labels = getattr(train_data, 'train_labels', None)
    if train_labels is None:
        train_labels = train_data.targets
    train_data.train_labels = torch.from_numpy(np.array(train_labels))

    # Test data.
    test_data = datasets.CIFAR100(root=root, train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]), download=False)
    testloader = torch.utils.data.DataLoader(test_data,
            batch_size=batch_size, shuffle=False, **kwargs)
    # Same compatibility handling as for `train_labels` above.
    test_labels = getattr(test_data, 'test_labels', None)
    if test_labels is None:
        test_labels = test_data.targets
    test_data.test_labels = torch.from_numpy(np.array(test_labels))

    return trainloader, testloader


#
# Load the ImageNet-200 dataset.
#
def load_imagenet200(basedir, batch_size, kwargs):
    """Build ImageNet-200 train/test DataLoaders from ImageFolder directories.

    Images are read from ``basedir + "imagenet200/train/"`` and
    ``basedir + "imagenet200/test/"``. Both splits are normalized with the
    ImageNet per-channel mean/std.

    Args:
        basedir: root directory; must end with a path separator.
        batch_size: number of examples per batch for both loaders.
        kwargs: dict of extra keyword arguments forwarded to both
            DataLoaders.

    Returns:
        (trainloader, testloader). Training images get a random 64x64 crop
        with 4 pixels of padding and a random horizontal flip, and are
        shuffled. Test images are only converted and normalized, in a fixed
        order.
    """
    root = basedir + "imagenet200/"
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomCrop(64, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainloader = torch.utils.data.DataLoader(
        datasets.ImageFolder(root + "train/", transform=train_transform),
        batch_size=batch_size, shuffle=True, **kwargs)
    testloader = torch.utils.data.DataLoader(
        datasets.ImageFolder(root + "test/", transform=test_transform),
        batch_size=batch_size, shuffle=False, **kwargs)
    return trainloader, testloader

#
# Load the CUB-200-2011 dataset.
#
def load_cub(basedir, batch_size, kwargs):
    """Build CUB-200-2011 train/test DataLoaders from ImageFolder directories.

    Images are read from ``basedir + "cub/train/"`` and
    ``basedir + "cub/test/"``. Both splits are normalized with the ImageNet
    per-channel mean/std.

    Args:
        basedir: root directory; must end with a path separator.
        batch_size: number of examples per batch for both loaders.
        kwargs: dict of extra keyword arguments forwarded to both
            DataLoaders.

    Returns:
        (trainloader, testloader). Training uses random resized 224x224
        crops plus random horizontal flips, and is shuffled. Testing resizes
        the shorter image side to 256 and takes the central 224x224 crop, so
        evaluation is deterministic.
    """
    # Correct basedir.
    basedir += "cub/"

    # Normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Train loader.
    train_data = datasets.ImageFolder(basedir + "train/",
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    trainloader = torch.utils.data.DataLoader(train_data,
            batch_size=batch_size, shuffle=True, **kwargs)

    # Test loader. A random resized crop here would make test accuracy
    # vary between runs; use a deterministic resize + center crop instead.
    test_data = datasets.ImageFolder(basedir + "test/",
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    testloader = torch.utils.data.DataLoader(test_data,
            batch_size=batch_size, shuffle=False, **kwargs)

    return trainloader, testloader


################################################################################
# Storing helpers.
################################################################################

#
# Store cross-entropy baseline results.
#
def store_results_ce(args, testscores, ext, model=None):
    """Write cross-entropy baseline scores (and optionally a model) to disk.

    The result directory is built from ``args.resdir``, ``dataset``,
    ``network``, ``seed``, ``drop1``, ``drop2``, ``epochs``, ``optimizer``,
    ``learning_rate``, ``momentum`` and ``decay``. It is created if missing.

    Args:
        args: object providing the attributes listed above.
        testscores: 2-D array. Each row is written as
            "<int> | <float> <float> ..." (first column as integer, the rest
            with 8 decimals). Other shapes fall back to ``np.savetxt``'s
            default format.
        ext: file name, relative to the result directory, for the scores.
        model: optional model whose weights are saved as "model.pt".
    """
    resdir = ("%s/%s/softmax-ce/n-%s_s-%d/e-%d-%d-%d/r-%s_l-%.4f_m-%.4f_c-%.4f/"
              % (args.resdir, args.dataset, args.network, args.seed,
                 args.drop1, args.drop2, args.epochs, args.optimizer,
                 args.learning_rate, args.momentum, args.decay))
    if not os.path.exists(resdir):
        os.makedirs(resdir)

    testscores = np.asarray(testscores)
    if testscores.ndim == 2 and testscores.shape[1] >= 2:
        # "%d | %.8f" for 2 columns, "%d | %.8f %.8f" for 3, and so on.
        fmt = "%d | " + " ".join(["%.8f"] * (testscores.shape[1] - 1))
        np.savetxt(resdir + ext, testscores, fmt=fmt)
    else:
        # Never silently drop scores with an unexpected shape.
        np.savetxt(resdir + ext, testscores)

    if model is not None:
        path = resdir + "model.pt"
        if hasattr(model, "save_state_dict"):
            model.save_state_dict(path)
        else:
            # torch.nn.Module has no save_state_dict(); use the standard
            # state_dict + torch.save idiom.
            torch.save(model.state_dict(), path)

#
# Store hyperspherical prototype network results.
#
def store_results_polar(args, testscores, ext, model=None):
    """Write hyperspherical-prototype scores (and optionally a model) to disk.

    The result directory includes the last two path components of
    ``args.polarfile`` with its final four characters removed (presumably a
    ".npy"-style extension — confirm against callers). It also includes
    ``args.resdir``, ``dataset``, ``network``, ``seed``, ``drop1``,
    ``drop2``, ``epochs``, ``optimizer``, ``learning_rate``, ``momentum`` and
    ``decay``. The directory is created if missing.

    Args:
        args: object providing the attributes listed above.
        testscores: 2-D array. Each row is written as
            "<int> | <float> <float> ..." (first column as integer, the rest
            with 8 decimals). Other shapes fall back to ``np.savetxt``'s
            default format.
        ext: file name, relative to the result directory, for the scores.
        model: optional model whose weights are saved as "model.pt".
    """
    polardir = "/".join(args.polarfile.split("/")[-2:])[:-4]
    # NOTE: "hypershperical" is misspelled, but it is kept as-is so results
    # keep landing in the same existing directories.
    resdir = ("%s/%s/hypershperical/n-%s_s-%d/%s/e-%d-%d-%d/r-%s_l-%.4f_m-%.4f_c-%.4f/"
              % (args.resdir, args.dataset, args.network, args.seed, polardir,
                 args.drop1, args.drop2, args.epochs, args.optimizer,
                 args.learning_rate, args.momentum, args.decay))
    if not os.path.exists(resdir):
        os.makedirs(resdir)

    testscores = np.asarray(testscores)
    if testscores.ndim == 2 and testscores.shape[1] >= 2:
        # "%d | %.8f" for 2 columns, "%d | %.8f %.8f" for 3, and so on.
        fmt = "%d | " + " ".join(["%.8f"] * (testscores.shape[1] - 1))
        np.savetxt(resdir + ext, testscores, fmt=fmt)
    else:
        # Never silently drop scores with an unexpected shape.
        np.savetxt(resdir + ext, testscores)

    if model is not None:
        path = resdir + "model.pt"
        if hasattr(model, "save_state_dict"):
            model.save_state_dict(path)
        else:
            # torch.nn.Module has no save_state_dict(); use the standard
            # state_dict + torch.save idiom.
            torch.save(model.state_dict(), path)

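#
# Store results under a caller-provided base directory, keyed by approach,
# network, seed, schedule, output dimensionality, optimizer settings and task.
#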
def store_results_reg(resdir, args, testscores, ext, model=None):
    """Write scores (and optionally a model) under a caller-given base dir.

    The subdirectory appended to ``resdir`` is built from ``args.approach``,
    ``network``, ``seed``, ``drop1``, ``drop2``, ``epochs``, ``output_dims``,
    ``optimizer``, ``learning_rate``, ``momentum``, ``decay`` and ``task``.
    It is created if missing.

    Args:
        resdir: base directory; must end with a path separator.
        args: object providing the attributes listed above.
        testscores: 2-D array. Each row is written as
            "<int> | <float> <float> ..." (first column as integer, the rest
            with 8 decimals). Other shapes fall back to ``np.savetxt``'s
            default format.
        ext: file name, relative to the result directory, for the scores.
        model: optional model whose weights are saved as "model.pt".
    """
    resdir += ("a-%s_n-%s/seed-%d/e-%d-%d-%d_o-%d/r-%s_l-%.4f_m-%.4f_c-%.4f/task-%d/"
               % (args.approach, args.network, args.seed, args.drop1,
                  args.drop2, args.epochs, args.output_dims, args.optimizer,
                  args.learning_rate, args.momentum, args.decay, args.task))
    if not os.path.exists(resdir):
        os.makedirs(resdir)

    testscores = np.asarray(testscores)
    if testscores.ndim == 2 and testscores.shape[1] >= 2:
        # "%d | %.8f %.8f" for the usual 3 columns; adapts to other widths.
        fmt = "%d | " + " ".join(["%.8f"] * (testscores.shape[1] - 1))
        np.savetxt(resdir + ext, testscores, fmt=fmt)
    else:
        np.savetxt(resdir + ext, testscores)

    if model is not None:
        path = resdir + "model.pt"
        if hasattr(model, "save_state_dict"):
            model.save_state_dict(path)
        else:
            # torch.nn.Module has no save_state_dict(); use the standard
            # state_dict + torch.save idiom.
            torch.save(model.state_dict(), path)
