import itertools
import logging

import exrex
from lark.load_grammar import load_grammar

from prompt_compiler.earley_parser.earley.earley import Parser
from prompt_compiler.earley_parser.earley.earley_exceptions import UnexpectedCharacters, UnexpectedEOF

from prompt_compiler.data_structs.option import Option
from prompt_compiler.data_structs.lexer_conf import LexerConf
from prompt_compiler.data_structs.parser_conf import ParserConf

# Module logger, looked up by the fixed name "global_logger" rather than by
# this module's own ``__name__``.
logger = logging.getLogger("global_logger")

class EarleyParser:
    """
    Frontend for Earley parser.

    Loads a grammar with lark's ``load_grammar``, compiles it into terminals,
    rules and ignored tokens, wraps those in ``LexerConf`` / ``ParserConf`` and
    builds the Earley ``Parser`` from them.  ``handle_error`` turns a parse
    failure into a ``(prefix, candidate strings)`` pair.
    """

    # Cap on the number of strings enumerated from a single regex terminal.
    # The same value is passed to exrex as its repetition limit for unbounded
    # quantifiers.
    CANDIDATE_LIMIT = 128

    def __init__(self, grammar: str, **options) -> None:
        """Compile ``grammar`` (grammar source text) and build the parser.

        Args:
            grammar: Grammar definition text handed to ``load_grammar``.
            **options: Forwarded unchanged to ``Option``; this class reads
                ``keep_all_tokens`` and ``start`` from the resulting object.
        """
        self.option = Option(**options)

        # load_grammar returns a pair; only the grammar object is used here.
        self.grammar, _ = load_grammar(
            grammar,
            source="<string>",
            import_paths=[],
            global_keep_all_tokens=self.option.keep_all_tokens,
        )
        self.terminals, self.rules, self.ignore_tokens = self.grammar.compile(
            self.option.start,
            terminals_to_keep=set(),
        )

        self.lexer_conf = LexerConf(self.terminals, self.ignore_tokens)
        self.parser_conf = ParserConf(self.rules, self.option.start)
        self.parser = self.build_parser()

    def build_parser(self):
        """Return a new ``Parser`` built from the stored lexer/parser configs."""
        return Parser(self.lexer_conf, self.parser_conf)

    @classmethod
    def open(cls, grammar_filename: str, **options):
        """Alternate constructor: read the grammar text from a UTF-8 file."""
        with open(grammar_filename, encoding="utf-8") as f:
            return cls(f.read(), **options)

    @classmethod
    def open_by_str(cls, grammar_str: str, **options):
        """Alternate constructor: build from grammar text already in memory."""
        return cls(grammar_str, **options)

    def parse(self, text: str, start=None):
        """Parse ``text`` beginning at the ``start`` symbol.

        Args:
            text: Input to parse.
            start: Start symbol to use.  When ``None``, the single configured
                start symbol is used.

        Returns:
            Whatever the underlying ``Parser.parse`` returns.

        Raises:
            AssertionError: If ``start`` is ``None`` while the configuration
                does not hold exactly one start symbol, or if ``start`` is not
                one of the configured start symbols.
        """
        # Raised explicitly (instead of via ``assert``) so the validation is
        # not stripped under ``python -O``; the exception type is unchanged.
        if start is None:
            if len(self.option.start) != 1:
                raise AssertionError("multiple start symbol, please specify one")
            start = self.option.start[0]
        elif start not in self.option.start:
            raise AssertionError(f"unknown start symbol: {start!r}")

        return self.parser.parse(text, start)

    def handle_error(self, e):
        """Translate a parse error into a prefix and candidate continuations.

        Args:
            e: Exception raised while parsing.

        Returns:
            A ``(prefix, candidates)`` tuple where ``candidates`` is a set of
            strings.  For ``UnexpectedCharacters`` the prefix is
            ``e.parsed_prefix`` and the candidates come from ``e.allowed``;
            when none are found the set ``{""}`` is returned.  For
            ``UnexpectedEOF`` the prefix is ``e.text`` and the candidates come
            from ``e.expected``.

        Raises:
            AssertionError: If an ``UnexpectedEOF`` yields no candidate.
            Exception: ``e`` itself is re-raised when it is neither of the two
                handled exception types.
        """
        if isinstance(e, UnexpectedCharacters):
            candidate_terminals = self._collect_candidates(e.allowed)

            # TODO: handle case where no candidate is found
            if not candidate_terminals:
                # Kept as a set so every return path has the same type.
                candidate_terminals = {""}
            return e.parsed_prefix, candidate_terminals

        if isinstance(e, UnexpectedEOF):
            candidate_terminals = self._collect_candidates(e.expected)
            if not candidate_terminals:
                raise AssertionError("no candidate terminal found for unexpected EOF")
            return e.text, candidate_terminals

        raise e

    def _collect_candidates(self, terminal_names):
        """Return the union of candidate strings for the named terminals."""
        terminals_by_name = self.lexer_conf.terminals_by_name
        candidates = set()
        for terminal_name in terminal_names:
            pattern = terminals_by_name[terminal_name].pattern
            candidates |= self._pattern_to_candidates(pattern)
        return candidates

    @classmethod
    def _pattern_to_candidates(cls, pattern):
        """Return the candidate strings for one terminal pattern.

        A ``"str"`` pattern yields its literal value, a ``"regex"`` pattern is
        enumerated, and any other pattern type yields an empty set.
        """
        if pattern.type == "str":
            return {pattern.value}
        if pattern.type == "regex":
            return cls._regex_to_candidates(pattern.value)
        return set()

    @classmethod
    def _regex_to_candidates(cls, regex):
        """Enumerate at most ``CANDIDATE_LIMIT`` strings matching ``regex``."""
        limit = cls.CANDIDATE_LIMIT
        # Use the same repetition limit for counting and generating so the
        # logged check describes what is actually enumerated.
        if exrex.count(regex, limit=limit) > limit:
            logger.info("regex %s has too many candidates", regex)
        # exrex's ``limit`` bounds quantifier repetition, not the number of
        # results, so the enumeration itself is truncated with islice.
        return set(itertools.islice(exrex.generate(regex, limit=limit), limit))