import logging
import numpy as np

from openai import OpenAI
from llm_compiler.utils.util import check_annotated_format
from llm_compiler.datatype import DataType
from llm_compiler.llm_interface.large_language_model import LargeLanguageModel

logger = logging.getLogger("global_logger")

class Corpora_Feature:
    """Annotate text with an LLM, driven by a prompt template read from disk.

    The template is loaded once at construction time from
    ``data/entity_extraction.txt`` (resolved relative to the current working
    directory). ``data_annotate`` fills the template's two placeholders,
    appends the text to annotate, queries the LLM for a single completion and
    validates the response.
    """

    def __init__(self, max_param, datatype: DataType, llm: LargeLanguageModel, temperature, freq_penalty, max_tokens, llm_cache_dir):
        """Store the sampling configuration and load the prompt template.

        Args:
            max_param: Multiplier used to derive ``feature_dim``.
            datatype: Provides ``type`` (the collection of accepted labels),
                ``get_example`` and ``get_type`` for building the prompt.
            llm: Model wrapper exposing ``sample_completions``.
            temperature: Forwarded unchanged to ``llm.sample_completions``.
            freq_penalty: Forwarded unchanged to ``llm.sample_completions``.
            max_tokens: Forwarded unchanged to ``llm.sample_completions``.
            llm_cache_dir: Forwarded unchanged to ``llm.sample_completions``.

        Raises:
            OSError: If ``data/entity_extraction.txt`` cannot be opened.
        """
        self.max_param = max_param
        # One slot per (parameter index, datatype entry) pair.
        self.feature_dim = max_param * len(datatype.type)
        # Passed to ``datatype.get_example`` on every annotation request.
        self.examples = []
        self.datatype = datatype
        self.llm = llm
        self.temperature = temperature
        self.freq_penalty = freq_penalty
        self.max_tokens = max_tokens
        self.llm_cache_dir = llm_cache_dir

        # Explicit encoding so the template reads the same on every platform
        # instead of depending on the locale's default encoding.
        with open("data/entity_extraction.txt", 'r', encoding="utf-8") as file:
            self.entity_extraction_prompt = file.read()

    def data_annotate(self, corpora):
        """Ask the LLM to annotate ``corpora`` and validate its answer.

        Args:
            corpora: Text appended to the end of the filled-in prompt; it must
                be concatenable to a ``str`` (presumably a ``str`` itself).

        Returns:
            ``(origin, label)`` as produced by ``check_annotated_format`` when
            the response is well formed and every label is a member of
            ``self.datatype.type``; otherwise ``([], [])``.
        """
        example_text = self.datatype.get_example(self.examples)
        type_description = self.datatype.get_type()
        # "&&&&&&" and "%%%%%%" are the template's placeholders for the
        # examples and the type description; the replacement order is kept
        # as in the original implementation.
        prompt = (
            self.entity_extraction_prompt
            .replace("&&&&&&", example_text)
            .replace("%%%%%%", type_description)
            + corpora
        )
        completions = self.llm.sample_completions(
            prompt,
            temperature=self.temperature,
            freq_penalty=self.freq_penalty,
            max_tokens=self.max_tokens,
            llm_cache_dir=self.llm_cache_dir,
            num_completions=1,
        )
        if not completions:
            # Nothing to parse: report it the same way as a malformed answer
            # rather than failing with an IndexError below.
            logger.info("no completion returned by the LLM")
            return [], []

        annotated_corpora = completions[0].response_text
        is_well_formed, origin, label = check_annotated_format(annotated_corpora)
        # Accept the annotation only if every predicted label is a known type.
        if is_well_formed and all(x in self.datatype.type for x in label):
            return origin, label

        logger.info("format worng: %s", annotated_corpora)
        return [], []