"""Code for creating (Fashion) MNIST derived binary classification datasets."""
from typing import Tuple

import numpy as np
from sklearn.decomposition import PCA
import tensorflow as tf


def _create_binary_mnist(x, y, label0: int, label1: int):
    keep_mask = (y == label0) | (y == label1)
    x = x[keep_mask].reshape([-1, 28 * 28]).astype(np.float32) / 255.0
    y = (y[keep_mask] == label1).astype(np.int32)
    return x, y


def _load_raw_dataset(name: str):
    if name.startswith('mnist') or name.startswith('fashion_mnist'):
        if name == 'mnist49':
            label0, label1 = 4, 9
        elif name == 'fashion_mnist_pullover_coat':
            label0, label1 = 2, 4
        else:
            raise ValueError(name)

        if name.startswith('mnist'):
            (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        else:
            (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

        x_train, y_train = _create_binary_mnist(x_train, y_train, label0, label1)
        x_test, y_test = _create_binary_mnist(x_test, y_test, label0, label1)

    else:
        raise ValueError(name)

    return (x_train, y_train), (x_test, y_test)


def make_dataset(
    name: str,
    n_components: int
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray], PCA]:
    """Build a PCA-reduced binary classification dataset.

    Available datasets:
        mnist49: Binary classification between the 4 and 9 digits in MNIST.
            The 9 gets the positive label (1) and the 4 gets the zero label
            in the returned dataset.
        fashion_mnist_pullover_coat: Binary classification between the
            pullover and coat classes in Fashion MNIST. The coat gets the
            positive label (1) and the pullover gets the zero label in the
            returned dataset.

    Args:
        name: {"mnist49", "fashion_mnist_pullover_coat"}, name of the dataset.
        n_components: The number of principal components to keep.

    Returns:
        (x_train, y_train): An array with shape [N_train, n_components]
            holding the PCA-projected train images, and an int32 array with
            shape [N_train] holding the binary train labels.
        (x_test, y_test): An array with shape [N_test, n_components] holding
            the PCA-projected test images, and an int32 array with shape
            [N_test] holding the binary test labels.
        pca: The fitted sklearn PCA object that maps between the original
            flattened images and the principal components in the returned
            datasets. It is constructed with only `n_components` set, so
            sklearn's whitening option is not enabled here.

    Raises:
        ValueError: If `name` is not a supported dataset name (raised by
            `_load_raw_dataset`).
    """
    train_split, test_split = _load_raw_dataset(name)
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # The projection is fitted on the train images only and then applied,
    # unchanged, to the test images.
    pca = PCA(n_components=n_components)
    train_features = pca.fit_transform(train_images)
    test_features = pca.transform(test_images)

    return (train_features, train_labels), (test_features, test_labels), pca
