import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


################################################################################
# General helpers.
################################################################################

#
# Count the number of learnable parameters in a model.
#
def count_parameters(model):
    """Return the total number of trainable scalars in ``model``.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count, but only if ``requires_grad`` is set; frozen tensors are ignored.
    A model with no trainable parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


#
# Get the desired optimizer.
#
def get_optimizer(optimname, params, learning_rate, momentum, decay):
    """Construct a ``torch.optim`` optimizer selected by name.

    Args:
        optimname: One of "sgd", "adadelta", "adam", "adamw", "rmsprop",
            "asgd", "adamax". Matching is case-insensitive, so the
            historical spelling "adamW" still works.
        params: Iterable of parameters (or parameter groups) to optimize.
        learning_rate: Learning rate passed as ``lr``.
        momentum: Momentum; only used by SGD and RMSprop.
        decay: Weight decay passed as ``weight_decay`` to every optimizer.

    Returns:
        The constructed optimizer. Any unrecognized name, including a
        non-string value, prints a notice and falls back to SGD.
    """
    # Factories are lazy so that only the selected optimizer consumes
    # ``params``. ``params`` may be a one-shot generator such as
    # ``model.parameters()``.
    factories = {
        "sgd": lambda: optim.SGD(params, lr=learning_rate, momentum=momentum, weight_decay=decay),
        "adadelta": lambda: optim.Adadelta(params, lr=learning_rate, weight_decay=decay),
        "adam": lambda: optim.Adam(params, lr=learning_rate, weight_decay=decay),
        "adamw": lambda: optim.AdamW(params, lr=learning_rate, weight_decay=decay),
        "rmsprop": lambda: optim.RMSprop(params, lr=learning_rate, weight_decay=decay, momentum=momentum),
        "asgd": lambda: optim.ASGD(params, lr=learning_rate, weight_decay=decay),
        "adamax": lambda: optim.Adamax(params, lr=learning_rate, weight_decay=decay),
    }

    # Normalize case so "adamw"/"AdamW"/"Adam" do not silently become SGD.
    key = optimname.lower() if isinstance(optimname, str) else None
    factory = factories.get(key)
    if factory is None:
        print('Your option for the optimizer is not available, I am loading SGD.')
        factory = factories["sgd"]
    return factory()


################################################################################
# Standard dataset loaders.
################################################################################
def load_cifar100(basedir, batch_size, kwargs):
    """Build shuffled CIFAR-100 train and test DataLoaders.

    The training split is downloaded if missing and augmented with a padded
    random 32x32 crop plus a random horizontal flip. Both splits are
    converted to tensors and channel-normalized.

    Args:
        basedir: Prefix for the dataset root. ``'cifar100/'`` is appended by
            plain string concatenation.
        batch_size: Batch size for both loaders.
        kwargs: Dict of extra keyword arguments forwarded to both
            ``torch.utils.data.DataLoader`` calls.

    Returns:
        ``(trainloader, testloader)``. Each underlying dataset also gets its
        labels as a tensor, in ``train_labels`` / ``test_labels``
        respectively.
    """
    root = basedir + 'cifar100/'
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                     std=[0.267, 0.256, 0.276])

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Training split: downloaded if necessary, augmented, shuffled.
    train_set = datasets.CIFAR100(root=root, train=True,
                                  transform=train_transform, download=True)
    trainloader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, **kwargs)
    # Tensor copy of the labels under the attribute name callers read.
    train_set.train_labels = torch.from_numpy(np.array(train_set.targets))

    # Test split: no augmentation. Shuffling is kept as in the original.
    test_set = datasets.CIFAR100(root=root, train=False,
                                 transform=test_transform)
    testloader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=True, **kwargs)
    test_set.test_labels = torch.from_numpy(np.array(test_set.targets))

    return trainloader, testloader


# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
def load_cifar10(basedir, batch_size, kwargs, mean=None, std=None):
    """Build shuffled CIFAR-10 train and test DataLoaders.

    The training split is downloaded if missing and augmented with a padded
    random 32x32 crop plus a random horizontal flip. Both splits are
    converted to tensors and channel-normalized.

    Args:
        basedir: Prefix for the dataset root. ``'cifar10/'`` is appended by
            plain string concatenation, so it should normally end with a
            path separator.
        batch_size: Batch size for both loaders.
        kwargs: Dict of extra keyword arguments forwarded to both
            ``torch.utils.data.DataLoader`` calls. ``None`` is treated as
            an empty dict.
        mean: Optional per-channel mean for ``transforms.Normalize``.
            Defaults to ``[0.507, 0.487, 0.441]``.
        std: Optional per-channel standard deviation. Defaults to
            ``[0.267, 0.256, 0.276]``.

    Returns:
        ``(trainloader, testloader)``. Each underlying dataset also gets its
        labels as a tensor, in ``train_labels`` / ``test_labels``
        respectively.

    NOTE(review): the default statistics are the same values ``load_cifar100``
    uses for CIFAR-100. They remain the defaults so existing runs reproduce.
    The commonly quoted CIFAR-10 statistics are roughly
    mean=(0.4914, 0.4822, 0.4465) and std=(0.2470, 0.2435, 0.2616). Pass them
    via ``mean``/``std`` to normalize with CIFAR-10's own statistics.
    """
    if mean is None:
        mean = [0.507, 0.487, 0.441]
    if std is None:
        std = [0.267, 0.256, 0.276]
    loader_kwargs = kwargs if kwargs is not None else {}
    normalize = transforms.Normalize(mean=mean, std=std)
    root = basedir + 'cifar10/'

    # Training split: downloaded if necessary, augmented, shuffled.
    trainloader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=root, train=True,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, 4),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             normalize,
                         ]), download=True),
        batch_size=batch_size, shuffle=True, **loader_kwargs)
    # Expose labels as a tensor. The attribute name presumably matches what
    # older torchvision provided; kept because callers may read it.
    trainloader.dataset.train_labels = torch.from_numpy(np.array(trainloader.dataset.targets))

    # Test split: no augmentation. Shuffling is kept as before.
    testloader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=root, train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             normalize,
                         ])),
        batch_size=batch_size, shuffle=True, **loader_kwargs)
    testloader.dataset.test_labels = torch.from_numpy(np.array(testloader.dataset.targets))

    return trainloader, testloader


# ------------------------------------------------
class Cifar10Dataset(data.Dataset):
    """CIFAR-10 split that yields each image both untouched and augmented.

    Indexing returns ``(raw_image, transformed_image, label)``:
    ``raw_image`` is the entry of the backing dataset's ``data`` array as
    stored, ``transformed_image`` is that entry run through ``data_transforms``
    (to-PIL, padded random crop, random flip, to-tensor, normalize), and
    ``label`` comes from the backing dataset's ``targets``.
    """

    def __init__(self, basedir, batch_size, train, **kwargs):
        # batch_size is only stored and **kwargs is accepted but ignored;
        # both are kept for caller compatibility.
        self.basedir = basedir
        self.batch_size = batch_size
        # No download=True here: the data must already exist under the root.
        self.total_data = datasets.CIFAR10(root=basedir + 'cifar10/', train=train)

        channel_mean = [0.507, 0.487, 0.441]
        channel_std = [0.267, 0.256, 0.276]
        pipeline = [
            transforms.ToPILImage(),
            transforms.RandomCrop(32, 4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.data_transforms = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.total_data)

    def __getitem__(self, idx):
        target = self.total_data.targets[idx]
        image = self.total_data.data[idx]
        return image, self.data_transforms(image), target
