from typing import Tuple
import os
import tensorflow as tf
import numpy as np
from .binary_scop_protein_serializer import deserialize_binary_scop_sequence


def get_binary_scop_task_data(directory: str,
                              batch_size: int,
                              max_sequence_length: int,
                              task_type: str,
                              task_num: int) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset]:
    """
    Returns data for an individual binary SCOP 1.67 classification task.

    Records are read from the directory
    ``<directory>/supervised/binary_<task_type>_task/``, from the files
    ``binary_<task_type>_task_<task_num>_train.tfrecords`` and
    ``binary_<task_type>_task_<task_num>_valid.tfrecords``.

    Both splits are deserialized, restricted to examples whose
    ``'protein_length'`` is strictly below ``max_sequence_length``, and
    batched by length in buckets 100 wide. The batch size of a bucket is
    inversely proportional to the bucket's center length, scaled so that
    the longest bucket holds ``batch_size`` examples.

    Args:
        directory: Root data directory containing the ``supervised`` folder.
        batch_size: Batch size used for the longest bucket; shorter buckets
            get proportionally larger batches.
        max_sequence_length: Exclusive upper bound on ``'protein_length'``.
        task_type: Task family name inserted into the directory/file names.
        task_num: Task index inserted into the file names.

    Returns:
        A ``(train_data, valid_data)`` tuple of batched datasets. Only the
        training dataset is shuffled.

    Raises:
        FileNotFoundError: If the train or validation file does not exist;
            the missing path is the exception argument.
    """

    # Format the directory name and the file names separately: the directory
    # carries only the task type, the file names add the task number.
    task_name = 'binary_{}_task'.format(task_type)
    task_dir = os.path.join(directory, 'supervised', task_name)
    train_file = os.path.join(
        task_dir, '{}_{}_train.tfrecords'.format(task_name, task_num))
    valid_file = os.path.join(
        task_dir, '{}_{}_valid.tfrecords'.format(task_name, task_num))

    if not os.path.exists(train_file):
        raise FileNotFoundError(train_file)
    if not os.path.exists(valid_file):
        raise FileNotFoundError(valid_file)

    train_data = tf.data.TFRecordDataset(train_file)
    valid_data = tf.data.TFRecordDataset(valid_file)

    def prepare_dataset(dataset: tf.data.Dataset, shuffle: bool) -> tf.data.Dataset:
        """Deserialize, length-filter, optionally shuffle and bucket-batch `dataset`."""
        # NOTE(review): batch_size is used as the map parallelism here, exactly
        # as the original positional call did -- confirm this is intentional.
        dataset = dataset.map(deserialize_binary_scop_sequence,
                              num_parallel_calls=batch_size)
        dataset = dataset.filter(lambda example: example['protein_length'] < max_sequence_length)
        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)

        bucket_width = 100
        bucket_boundaries = np.arange(
            bucket_width, max_sequence_length + bucket_width, bucket_width)
        # bucket_by_sequence_length needs exactly one batch size per bucket,
        # i.e. len(bucket_boundaries) + 1 entries (the extra one is the bucket
        # past the last boundary). Deriving the centers from the boundaries
        # guarantees that for every max_sequence_length, not only for those
        # where two independent aranges happen to line up.
        centers = np.concatenate(([0], bucket_boundaries)) + bucket_width // 2
        # Shorter buckets get larger batches; the longest bucket gets
        # batch_size. Fractions are truncated by the int32 conversion.
        bucket_batch_sizes = np.asarray(
            (centers[-1] / centers) * batch_size, np.int32)

        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            lambda example: example['protein_length'],
            bucket_boundaries,
            bucket_batch_sizes)
        return dataset.apply(batch_fun)

    train_data = prepare_dataset(train_data, shuffle=True)
    valid_data = prepare_dataset(valid_data, shuffle=False)

    return train_data, valid_data
