from typing import List
import os

from tape.data_utils import deserialize_binary_scop_sequence
from .Task import SequenceBinaryClassificationTask


class BinaryScopFoldTask(SequenceBinaryClassificationTask):
    """Binary SCOP fold classification task for one numbered sub-task.

    Each instance loads the TFRecord files of a single ``task_num``. Files are
    looked up at::

        <data_folder>[/supervised]/<str(self)>/<str(self)>_task_<task_num>_<split>.tfrecords
    """

    def __init__(self, task_num: int):
        """Configure the task.

        Args:
            task_num: Index of the binary sub-task; selects which
                ``*_task_<task_num>_<split>.tfrecords`` files are read.
        """
        super().__init__(
            key_metric='Acc',
            supervised=True,
            deserialization_func=deserialize_binary_scop_sequence,
            label='label',
            input_name='encoder_output',
            output_name='logit')
        self._task_num = task_num

    def _split_file_path(self, data_folder: str, split: str) -> str:
        """Build the TFRecord path for ``split`` and verify that it exists.

        Args:
            data_folder: Root data directory.
            split: Split name embedded in the filename (``'train'`` or ``'valid'``).

        Returns:
            The path to the split's ``.tfrecords`` file.

        Raises:
            FileNotFoundError: If the computed path does not exist.
        """
        # NOTE(review): `self.supervised` is presumably set by the base class
        # from the `supervised=True` passed in __init__ — confirm.
        if self.supervised:
            data_folder = os.path.join(data_folder, 'supervised')
        # str(self) / format(self) presumably yield the task name (defined on a
        # base class); it is used both as a sub-directory and as a filename prefix.
        filename = '{}_task_{}_{}.tfrecords'.format(self, self._task_num, split)
        path = os.path.join(data_folder, str(self), filename)
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        return path

    def get_train_files(self, data_folder: str) -> List[str]:
        """Return the training file for this sub-task as a one-element list.

        Raises:
            FileNotFoundError: If the training file is missing.
        """
        return [self._split_file_path(data_folder, 'train')]

    def get_valid_files(self, data_folder: str) -> List[str]:
        """Return the validation file for this sub-task as a one-element list.

        Raises:
            FileNotFoundError: If the validation file is missing.
        """
        return [self._split_file_path(data_folder, 'valid')]
