from tape.data_utils import deserialize_secondary_structure_sequence_eight_classes
from tape.data_utils import deserialize_secondary_structure_sequence_three_classes
from .Task import SequenceToSequenceClassificationTask


class SecondaryStructureTask(SequenceToSequenceClassificationTask):
    """Secondary-structure task built on ``SequenceToSequenceClassificationTask``.

    Passes a fixed configuration to the base class:

    - key metric ``'ACC'``, supervised;
    - label name ``'output_sequence'``;
    - input name ``'encoder_output'``;
    - output name ``'sequence_logits'``.

    The deserialization function is chosen from ``tape.data_utils`` by the
    requested number of classes.

    Args:
        n_classes: Number of secondary-structure label classes. Must be 3 or 8.
            It is forwarded to the base class and selects the matching
            deserializer (``..._three_classes`` or ``..._eight_classes``).

    Raises:
        ValueError: If ``n_classes`` is not 3 or 8.
    """

    def __init__(self, n_classes: int = 8):
        # Supported class counts mapped to the deserializer for that label
        # granularity. The keys double as the set of valid ``n_classes`` values.
        deserialization_functions = {
            3: deserialize_secondary_structure_sequence_three_classes,
            8: deserialize_secondary_structure_sequence_eight_classes,
        }
        # Validate explicitly instead of with ``assert``: asserts are removed
        # under ``python -O``, which would turn a bad argument into an opaque
        # KeyError at the lookup below.
        if n_classes not in deserialization_functions:
            raise ValueError(
                f"n_classes must be one of {sorted(deserialization_functions)}, "
                f"got {n_classes!r}")
        super().__init__(
            key_metric='ACC',
            supervised=True,
            deserialization_func=deserialization_functions[n_classes],
            n_classes=n_classes,
            label_name='output_sequence',
            input_name='encoder_output',
            output_name='sequence_logits')
