from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import copy
import random

import torchtext


class InputExample(object):
    """One raw example for sequence classification.

    Attributes:
        guid: Identifier of the example.
        text_a: Text of the first sequence.
        text_b: Optional text of the second sequence (``None`` by default).
        label: Optional label of the example (``None`` by default).
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        # Attributes are assigned in the order guid, text_a, text_b, label.
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label


class InputFeatures(object):
    """Container holding the feature values of one example.

    Attributes:
        input_ids: Stored as given by the caller.
        input_mask: Stored as given by the caller.
        segment_ids: Stored as given by the caller.
        label_id: Stored as given by the caller.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        # Attributes are assigned in the same order as the parameters.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id


def _subsample_by_classes(all_examples, labels, proportions=None, num_per_class=None):
    """Randomly subsample examples independently for every class.

    Examples are grouped by ``example.label``, each group is shuffled with the
    module-level ``random`` generator, and a prefix of each group is kept.

    Args:
        all_examples: Iterable of objects exposing a ``label`` attribute.
        labels: List of class labels; the result is grouped in this order.
        proportions: Optional mapping ``label -> fraction`` of the class to
            keep. Takes precedence over ``num_per_class`` when both are given.
        num_per_class: Optional mapping ``label -> number`` of examples to
            keep for the class.

    Returns:
        A new list with the selected examples of every class. When neither
        ``proportions`` nor ``num_per_class`` is given, every example is kept.

    Raises:
        KeyError: If an example carries a label that is not in ``labels``, or
            if the mapping in use has no entry for one of ``labels``.
    """
    examples_by_label = {label: [] for label in labels}
    for example in all_examples:
        examples_by_label[example.label].append(example)

    selected_examples = []
    for label in labels:
        candidates = examples_by_label[label]
        random.shuffle(candidates)

        if proportions is not None:
            num_in_class = int(proportions[label] * len(candidates))
        elif num_per_class is not None:
            num_in_class = num_per_class[label]
        else:
            # No budget given: keep the whole class instead of failing on
            # ``None[label]``.
            num_in_class = len(candidates)

        # Clamp so that the reported number is the number actually selected
        # and a negative request cannot slice from the end of the list.
        num_in_class = max(0, min(num_in_class, len(candidates)))

        selected_examples.extend(candidates[:num_in_class])
        print('number of examples with label \'{}\': {}'.format(label, num_in_class))

    return selected_examples


def _split_by_classes(all_examples, labels, num_select_per_class):
    """Split examples into a per-class selection and the remainder.

    Every class listed in ``labels`` is shuffled with the module-level
    ``random`` generator; its first ``num_select_per_class`` examples go to
    the selection and the rest to the remainder.

    Args:
        all_examples: Iterable of objects exposing a ``label`` attribute.
        labels: List of class labels; both results are grouped in this order.
        num_select_per_class: Number of examples to select from every class.

    Returns:
        A tuple ``(selected_examples, remaining_examples)`` of two lists.
    """
    buckets = dict((label, []) for label in labels)
    for item in all_examples:
        buckets[item.label].append(item)

    picked, leftover = [], []
    for label in labels:
        bucket = buckets[label]
        # Every class must be able to supply the requested number.
        assert num_select_per_class <= len(bucket)

        random.shuffle(bucket)
        picked.extend(bucket[:num_select_per_class])
        leftover.extend(bucket[num_select_per_class:])

    return picked, leftover


class IMDBProcessor:
    """Processor for the IMDB data set.

    The torchtext train split is converted to ``InputExample`` objects once
    and divided into a dev set (10 examples per label) and a train set; the
    torchtext test split is converted on demand.
    """

    def __init__(self):
        """Load the IMDB splits and carve a dev set out of the train split."""
        TEXT = torchtext.data.Field()
        LABEL = torchtext.data.Field(sequential=False)

        self._train_set, self._test_set = \
            torchtext.datasets.IMDB.splits(TEXT, LABEL)

        # train set: pos 12500 ; neg 12500

        all_examples = self._create_examples(self._train_set, "train")
        # The selected examples become the dev set, the remainder the train set.
        self._full_dev_examples, self._full_train_examples = \
            _split_by_classes(all_examples, self.get_labels(), num_select_per_class=10)

    def get_train_examples(self, proportions=None, num_per_class=None, noise_rate=0.):
        """Return training examples, optionally label-noised and subsampled.

        Args:
            proportions: Optional mapping ``label -> fraction`` forwarded to
                ``_subsample_by_classes``.
            num_per_class: Optional mapping ``label -> count`` forwarded to
                ``_subsample_by_classes``.
            noise_rate: Probability with which an example's label is replaced
                by a label drawn at random from ``get_labels()`` (the draw may
                return the original label).

        Returns:
            A new list of examples. When neither ``proportions`` nor
            ``num_per_class`` is given, all training examples are returned.
        """
        print('get train examples')
        labels = self.get_labels()

        # Add noise. A flipped example is copied first so that the cached
        # training set keeps its original labels across repeated calls.
        all_examples = []
        for example in self._full_train_examples:
            if random.random() < noise_rate:
                example = copy.copy(example)
                example.label = random.choice(labels)
            all_examples.append(example)

        # Subsample
        if proportions is None and num_per_class is None:
            return all_examples
        return _subsample_by_classes(
            all_examples, labels, proportions, num_per_class)

    def get_dev_examples(self, proportions=None, num_per_class=None):
        """Return the held-out dev examples, optionally subsampled.

        Args:
            proportions: Optional mapping ``label -> fraction`` forwarded to
                ``_subsample_by_classes``.
            num_per_class: Optional mapping ``label -> count`` forwarded to
                ``_subsample_by_classes``.

        Returns:
            The cached dev list itself when no argument is given, otherwise a
            new list with the subsampled examples.
        """
        print('get dev examples')
        all_examples = self._full_dev_examples

        if proportions is None and num_per_class is None:
            return all_examples
        return _subsample_by_classes(
            all_examples, self.get_labels(), proportions, num_per_class)

    def get_test_examples(self):
        """Return the examples built from the test split."""
        print('get test examples')
        return self._create_examples(self._test_set, "test")

    def get_labels(self):
        """Return the list of class labels."""
        return ['pos', 'neg']

    def _create_examples(self, dataset, set_type):
        """Convert dataset items into ``InputExample`` objects.

        Args:
            dataset: Iterable of items exposing ``text`` (joined with single
                spaces) and ``label``.
            set_type: Split name used as guid prefix (``"<set_type>-<index>"``).

        Returns:
            A list with one ``InputExample`` per dataset item.
        """
        examples = []
        for (i, data) in enumerate(dataset):
            guid = "%s-%s" % (set_type, i)
            examples.append(InputExample(
                guid=guid,
                text_a=' '.join(data.text),
                text_b=None,
                label=data.label))
        return examples


class TRECProcessor:
    """Processor for the TREC data set.

    The torchtext train split (coarse labels) is converted to ``InputExample``
    objects once and divided into a dev set (10 examples per label) and a
    train set; the torchtext test split is converted on demand.
    """

    def __init__(self):
        """Load the TREC splits and carve a dev set out of the train split."""
        TEXT = torchtext.data.Field()
        LABEL = torchtext.data.Field(sequential=False)

        self._train_set, self._test_set = \
            torchtext.datasets.TREC.splits(
                TEXT, LABEL, fine_grained=False)

        all_examples = self._create_examples(self._train_set, "train")
        # The selected examples become the dev set, the remainder the train set.
        self._full_dev_examples, self._full_train_examples = \
            _split_by_classes(all_examples, self.get_labels(), num_select_per_class=10)

    def get_train_examples(self, proportions=None, num_per_class=None, noise_rate=0.):
        """Return training examples, optionally label-noised and subsampled.

        Args:
            proportions: Optional mapping ``label -> fraction`` forwarded to
                ``_subsample_by_classes``.
            num_per_class: Optional mapping ``label -> count`` forwarded to
                ``_subsample_by_classes``.
            noise_rate: Probability with which an example's label is replaced
                by a label drawn at random from ``get_labels()`` (the draw may
                return the original label).

        Returns:
            A new list of examples. When neither ``proportions`` nor
            ``num_per_class`` is given, all training examples are returned.
        """
        print('get train examples')
        labels = self.get_labels()

        # Add noise. A flipped example is copied first so that the cached
        # training set keeps its original labels across repeated calls.
        all_examples = []
        for example in self._full_train_examples:
            if random.random() < noise_rate:
                example = copy.copy(example)
                example.label = random.choice(labels)
            all_examples.append(example)

        # Subsample
        if proportions is None and num_per_class is None:
            return all_examples
        return _subsample_by_classes(
            all_examples, labels, proportions, num_per_class)

    def get_dev_examples(self, proportions=None, num_per_class=None):
        """Return the held-out dev examples, optionally subsampled.

        Args:
            proportions: Optional mapping ``label -> fraction`` forwarded to
                ``_subsample_by_classes``.
            num_per_class: Optional mapping ``label -> count`` forwarded to
                ``_subsample_by_classes``.

        Returns:
            The cached dev list itself when no argument is given, otherwise a
            new list with the subsampled examples.
        """
        print('get dev examples')
        all_examples = self._full_dev_examples

        if proportions is None and num_per_class is None:
            return all_examples
        return _subsample_by_classes(
            all_examples, self.get_labels(), proportions, num_per_class)

    def get_test_examples(self):
        """Return the examples built from the test split."""
        print('get test examples')
        return self._create_examples(self._test_set, "test")

    def get_labels(self):
        """Return the list of class labels."""
        return ['NUM', 'LOC', 'HUM', 'DESC', 'ENTY', 'ABBR']

    def _create_examples(self, dataset, set_type):
        """Convert dataset items into ``InputExample`` objects.

        Args:
            dataset: Iterable of items exposing ``text`` (joined with single
                spaces) and ``label``.
            set_type: Split name used as guid prefix (``"<set_type>-<index>"``).

        Returns:
            A list with one ``InputExample`` per dataset item.
        """
        examples = []
        for (i, data) in enumerate(dataset):
            guid = "%s-%s" % (set_type, i)
            examples.append(InputExample(
                guid=guid,
                text_a=' '.join(data.text),
                text_b=None,
                label=data.label))
        return examples


class SST5Processor:
    """Processor for the SST-5 data set."""

    def __init__(self):
        """Load the fine-grained SST train/dev/test splits via torchtext."""
        text_field = torchtext.data.Field()
        label_field = torchtext.data.Field(sequential=False)

        self._train_set, self._dev_set, self._test_set = \
            torchtext.datasets.SST.splits(
                text_field, label_field, fine_grained=True)

    def get_train_examples(self, proportions=None, num_per_class=None, noise_rate=0.):
        """Build the training examples, apply label noise, then subsample.

        Each label is replaced with probability ``noise_rate`` by a label
        drawn at random from ``get_labels()``. Without ``proportions`` and
        ``num_per_class`` all examples are returned.
        """
        print('get train examples')
        train_examples = self._create_examples(self._train_set, "train")

        # Add noise
        for example in train_examples:
            if random.random() < noise_rate:
                example.label = random.choice(self.get_labels())

        return self._maybe_subsample(train_examples, proportions, num_per_class)

    def get_dev_examples(self, proportions=None, num_per_class=None):
        """Build the dev examples and subsample them when requested."""
        print('get dev examples')
        dev_examples = self._create_examples(self._dev_set, "dev")
        return self._maybe_subsample(dev_examples, proportions, num_per_class)

    def get_test_examples(self):
        """Build the examples of the test split."""
        print('get test examples')
        test_examples = self._create_examples(self._test_set, "test")
        return test_examples

    def get_labels(self):
        """Return the list of the five class labels."""
        return ['negative', 'very positive', 'neutral',
                'positive', 'very negative']

    def _maybe_subsample(self, examples, proportions, num_per_class):
        """Return ``examples`` unchanged unless a subsampling budget is given."""
        if proportions is not None or num_per_class is not None:
            return _subsample_by_classes(
                examples, self.get_labels(), proportions, num_per_class)
        return examples

    def _create_examples(self, dataset, set_type):
        """Turn dataset items into ``InputExample`` objects.

        The guid is ``"<set_type>-<index>"`` and ``text_a`` is the item's
        ``text`` joined with single spaces.
        """
        return [
            InputExample(
                guid="%s-%s" % (set_type, index),
                text_a=' '.join(item.text),
                text_b=None,
                label=item.label)
            for index, item in enumerate(dataset)
        ]



class SST2Processor:
    """Processor for the SST-2 data set (GLUE version).

    Examples are read from the parallel text files
    ``data/sst2.<split>.sentences.txt`` and ``data/sst2.<split>.labels.txt``
    (one sentence / one label per line, paths relative to the working
    directory).
    """

    def get_train_examples(self, proportions=None, num_per_class=None, noise_rate=0.):
        """Return training examples, optionally label-noised and subsampled.

        Args:
            proportions: Optional mapping ``label -> fraction`` forwarded to
                ``_subsample_by_classes``.
            num_per_class: Optional mapping ``label -> count`` forwarded to
                ``_subsample_by_classes``.
            noise_rate: Probability with which an example's label is replaced
                by a label drawn at random from ``get_labels()`` (the draw may
                return the original label).

        Returns:
            A list of examples. When neither ``proportions`` nor
            ``num_per_class`` is given, all training examples are returned.
        """
        print('get train examples')
        all_examples = self._create_examples("train")
        labels = self.get_labels()

        # Add noise. The examples were freshly read above, so changing their
        # labels in place does not affect any other call.
        for example in all_examples:
            if random.random() < noise_rate:
                example.label = random.choice(labels)

        # Subsample
        if proportions is None and num_per_class is None:
            return all_examples
        return _subsample_by_classes(
            all_examples, labels, proportions, num_per_class)

    def get_dev_examples(self, proportions=None, num_per_class=None):
        """Return the dev examples, optionally subsampled.

        Args:
            proportions: Optional mapping ``label -> fraction`` forwarded to
                ``_subsample_by_classes``.
            num_per_class: Optional mapping ``label -> count`` forwarded to
                ``_subsample_by_classes``.

        Returns:
            All dev examples when no argument is given, otherwise the
            subsampled examples.
        """
        print('get dev examples')
        all_examples = self._create_examples("dev")

        if proportions is None and num_per_class is None:
            return all_examples
        return _subsample_by_classes(
            all_examples, self.get_labels(), proportions, num_per_class)

    def get_test_examples(self):
        """Return the examples read from the test files."""
        return self._create_examples("test")

    def get_labels(self):
        """Return the list of class labels."""
        return ["0", "1"]

    def _create_examples(self, set_type):
        """Read the sentence and label files of one split.

        Args:
            set_type: Split name inserted into both file names.

        Returns:
            A list of ``InputExample`` objects, one per sentence/label pair.
            Reading stops at the first empty label line or at the end of the
            shorter file.
        """
        sentences_path = 'data/sst2.{}.sentences.txt'.format(set_type)
        labels_path = 'data/sst2.{}.labels.txt'.format(set_type)

        examples = []
        # Context managers guarantee both files are closed, also on errors.
        with open(sentences_path) as sentence_file, \
                open(labels_path) as labels_file:
            for sentence, label in zip(sentence_file, labels_file):
                label = label.strip('\n')
                sentence = sentence.strip('\n')

                if label == '':
                    break
                # NOTE: every example of a split shares the same guid (the
                # split name); kept as is for compatibility.
                examples.append(InputExample(
                    guid=set_type, text_a=sentence, text_b=None, label=label))
        return examples


# processor = SST2Processor()
# train_examples = processor.get_train_examples(
#     proportions={'0': 0.001, '1': 0.01})
#
# # for example in train_examples:
# #     print(example.text_a, example.label)
