from typing import Tuple
import os
import tensorflow as tf

from .paired_protein_serializer import deserialize_paired_sequence


def get_paired_data(directory: str, batch_size: int, max_sequence_length: int) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Build the train, validation and test pipelines for the paired SCOPe records.

    The three TFRecord files are looked up under
    ``<directory>/supervised/paired_scope/``. All of them are checked for
    existence before any dataset is constructed.

    Each split is deserialized with ``deserialize_paired_sequence``, filtered
    so that both ``example['first']['protein_length']`` and
    ``example['second']['protein_length']`` are strictly below
    ``max_sequence_length``, and then padded-batched. Only the training split
    is shuffled (buffer of 1024); the other two are prefetched (buffer of
    1024) instead.

    Args:
        directory: Root data directory containing ``supervised/paired_scope``.
        batch_size: Number of examples per padded batch. It is also handed to
            ``Dataset.map`` as ``num_parallel_calls``.
        max_sequence_length: Exclusive upper bound on the length of either
            protein in a pair.

    Returns:
        A ``(train, valid, test)`` tuple of batched datasets.

    Raises:
        FileNotFoundError: If a record file is missing. The first missing path
            in train, valid, test order is the exception argument.
    """
    record_names = (
        'paired_scope_train.tfrecords',
        'paired_scope_valid.tfrecords',
        'paired_scope_test.tfrecords',
    )
    record_paths = [
        os.path.join(directory, 'supervised', 'paired_scope', record_name)
        for record_name in record_names
    ]

    # Fail fast: every split must be present before anything is built.
    for record_path in record_paths:
        if not os.path.exists(record_path):
            raise FileNotFoundError(record_path)

    raw_train, raw_valid, raw_test = [
        tf.data.TFRecordDataset(record_path) for record_path in record_paths
    ]

    def _both_below_max_length(example):
        # Keep a pair only when each of its proteins is short enough.
        first_ok = example['first']['protein_length'] < max_sequence_length
        second_ok = example['second']['protein_length'] < max_sequence_length
        return first_ok & second_ok

    def _build_pipeline(records: tf.data.Dataset, shuffle: bool) -> tf.data.Dataset:
        # The second argument of Dataset.map is the parallelism degree; the
        # batch size doubles as that value here.
        pairs = records.map(deserialize_paired_sequence, num_parallel_calls=batch_size)
        pairs = pairs.filter(_both_below_max_length)

        if shuffle:
            pairs = pairs.shuffle(1024)
        else:
            pairs = pairs.prefetch(1024)

        # No length bucketing here: it would break the pairwise matching.
        padded_shapes = pairs.output_shapes
        return pairs.padded_batch(batch_size, padded_shapes)

    train_data = _build_pipeline(raw_train, shuffle=True)
    valid_data = _build_pipeline(raw_valid, shuffle=False)
    test_data = _build_pipeline(raw_test, shuffle=False)

    return train_data, valid_data, test_data
