from typing import Tuple
import os
import tensorflow as tf
import numpy as np
from .transmembrane_protein_serializer import deserialize_transmembrane_sequence


def get_transmembrane_data(directory: str, batch_size: int, max_sequence_length: int) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Load the transmembrane train/valid TFRecord files as length-bucketed datasets.

    Both files are read from ``<directory>/supervised/transmembrane/``. Each record is
    parsed with ``deserialize_transmembrane_sequence``. Examples are then grouped on
    ``example['protein_length']`` into buckets of width 100 and batched per bucket.

    Args:
        directory: Root data directory that contains the ``supervised`` folder.
        batch_size: Batch size of the last (longest) bucket. Shorter buckets get
            proportionally larger batches. Also used as the parallelism of the
            deserialization ``map``.
        max_sequence_length: Bucket boundaries are placed at 100, 200, ... up to the
            first multiple of 100 that is >= this value. Longer examples are not
            dropped; they fall into the final, open-ended bucket.

    Returns:
        A ``(train_data, valid_data)`` pair. The training split is shuffled and the
        validation split is not.

    Raises:
        FileNotFoundError: If either TFRecord file does not exist.
    """
    data_dir = os.path.join(directory, 'supervised', 'transmembrane')
    train_file = os.path.join(data_dir, 'transmembrane_train.tfrecords')
    valid_file = os.path.join(data_dir, 'transmembrane_valid.tfrecords')

    for path in (train_file, valid_file):
        if not os.path.exists(path):
            raise FileNotFoundError(path)

    # Upper bounds of the length buckets: 100, 200, ..., covering max_sequence_length.
    bucket_boundaries = np.arange(100, max_sequence_length + 100, 100)
    # bucket_by_sequence_length requires exactly len(bucket_boundaries) + 1 batch
    # sizes (one per bucket, including the open-ended last one). Derive the bucket
    # centres from the boundary count so this always holds. A separate np.arange
    # came out one element short whenever max_sequence_length % 100 was in (0, 50].
    bucket_centers = 50 + 100 * np.arange(len(bucket_boundaries) + 1)
    # Scale batch sizes inversely with bucket centre, so batch_size_i * centre_i is
    # (up to int truncation) the same for every bucket. The last bucket gets
    # exactly batch_size.
    bucket_batch_sizes = np.asarray(bucket_centers[-1] / bucket_centers * batch_size, np.int32)

    def prepare_dataset(dataset: tf.data.Dataset, shuffle: bool) -> tf.data.Dataset:
        """Deserialize, optionally shuffle, bucket-batch, and prefetch ``dataset``."""
        # NOTE(review): batch_size doubles as the map parallelism (kept from the
        # original code); tf.data.experimental.AUTOTUNE may be a better fit.
        dataset = dataset.map(deserialize_transmembrane_sequence, num_parallel_calls=batch_size)
        # NOTE: no filtering is applied here (e.g. on protein_length or protein_type);
        # over-length examples go to the last bucket.
        if shuffle:
            dataset = dataset.shuffle(5000)
        dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length(
            lambda example: example['protein_length'],
            bucket_boundaries,
            bucket_batch_sizes))
        # Prefetch finished batches for both splits so input preparation overlaps
        # with consumption.
        return dataset.prefetch(1)

    train_data = tf.data.TFRecordDataset(train_file)
    valid_data = tf.data.TFRecordDataset(valid_file)

    train_data = prepare_dataset(train_data, shuffle=True)
    valid_data = prepare_dataset(valid_data, shuffle=False)

    return train_data, valid_data
