# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache
import json


def convert_sentence_to_json(sentence):
    if '_' in sentence:
        prefix, rest = sentence.split('_', 1)
        query, rest = rest.split('_', 1)
        query_index = len(prefix.rstrip().split(' '))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split('[', 1)
    pronoun, rest = rest.split(']', 1)
    pronoun_index = len(prefix.rstrip().split(' '))

    sentence = sentence.replace('_', '').replace('[', '').replace(']', '')

    return {
        'idx': 0,
        'text': sentence,
        'target': {
            'span1_index': query_index,
            'span1_text': query,
            'span2_index': pronoun_index,
            'span2_text': pronoun,
        },
    }


def extended_noun_chunks(sentence):
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, 'NONE'
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {'NOUN', 'PROPN'} else 'NONE'
        if np_type != cur_np:
            if cur_np != 'NONE':
                noun_chunks.add((np_start, i))
            if np_type != 'NONE':
                np_start = i
            cur_np = np_type
    if cur_np != 'NONE':
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]


def find_token(sentence, start_pos):
    found_tok = None
    for tok in sentence:
        if tok.idx == start_pos:
            found_tok = tok
            break
    return found_tok


def find_span(sentence, search_text, start=0):
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i:].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i:]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i:next_tok.i + 1]
                    return span
    return None


@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer
    detok = MosesDetokenizer(lang='en')
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg
    nlp = en_core_web_lg.load()
    return nlp


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())

            if positive_only and 'label' in sample and not sample['label']:
                # only consider examples where the query is correct
                continue

            target = sample['target']

            # clean up the query
            query = target['span1_text']
            if query is not None:
                if '\n' in query:
                    continue
                if query.endswith('.') or query.endswith(','):
                    query = query[:-1]

            # split tokens
            tokens = sample['text'].split(' ')

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target['span2_index']
            pronoun = strip_pronoun(target['span2_text'])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception('Misaligned pronoun!')
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1:]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = ' ' if pronoun_idx > 0 else ''
            trailing_space = ' ' if len(after) > 0 else ''

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith('.') or pronoun.endswith(','):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith('.') or after.startswith(','):
                trailing_space = ''

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = '_{}_{}'.format(
                    query_span.text,
                    (' ' if query_span.text_with_ws.endswith(' ') else '')
                )
                pronoun_with_ws = '[{}]{}'.format(
                    pronoun_span.text,
                    (' ' if pronoun_span.text_with_ws.endswith(' ') else '')
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[:first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end:second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end:].text
                )
                yield sentence, sample.get('label', None)
            else:
                yield sentence, pronoun_span, query, sample.get('label', None)


def winogrande_jsonl_iterator(input_fname, eval=False):
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = sample['sentence'], sample['option1'],\
                sample['option2']

            pronoun_span = (sentence.index('_'), sentence.index('_') + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample['answer'] == '1' else option2
                cand = option2 if sample['answer'] == '1' else option1
            yield sentence, pronoun_span, query, cand


def filter_noun_chunks(chunks, exclude_pronouns=False, exclude_query=None, exact_match=False):
    if exclude_pronouns:
        chunks = [
            np for np in chunks if (
                np.lemma_ != '-PRON-'
                and not all(tok.pos_ == 'PRON' for tok in np)
            )
        ]

    if exclude_query is not None:
        excl_txt = [exclude_query.lower()]
        filtered_chunks = []
        for chunk in chunks:
            lower_chunk = chunk.text.lower()
            found = False
            for excl in excl_txt:
                if (
                    (not exact_match and (lower_chunk in excl or excl in lower_chunk))
                    or lower_chunk == excl
                ):
                    found = True
                    break
            if not found:
                filtered_chunks.append(chunk)
        chunks = filtered_chunks

    return chunks
