import collections
import nest
import numpy as np
import pandas as pd
import plotly.express as px
from tokenizers import BertWordPieceTokenizer


class Visualizer:
    """Accumulate per-message novelty statistics and expose them as a DataFrame.

    Token-id messages are decoded to strings with a WordPiece tokenizer.
    Novelty values are grouped by decoded string, averaged in
    ``compute_means`` and appended to ``novelty_records`` tagged with a step.
    """

    def __init__(self, vocab_file, separate_message_novelty=False):
        """Create the visualizer and its tokenizer.

        Args:
            vocab_file: Vocabulary file passed to ``BertWordPieceTokenizer``.
            separate_message_novelty: If True, ``update_novelty`` also requires
                a ``message_novelty`` value per message, and each record gains
                a ``"message_novelty"`` entry.
        """
        # Aggregated rows: one dict per distinct message per compute_means call.
        self.novelty_records = []
        self.separate_message_novelty = separate_message_novelty
        # Steps passed to the most recent compute_means / get_novelty_df calls.
        self.last_mean_step = 0
        self.last_novelty_step = 0
        self.init_tokenizer(vocab_file)
        self.reset_counts()

    def reset_counts(self):
        """Clear the per-message accumulators gathered since the last aggregation."""
        self.message_counts = collections.Counter()
        # Despite the name, these hold raw value lists; means are taken in
        # compute_means.
        self.mean_novelty = collections.defaultdict(list)
        if self.separate_message_novelty:
            self.mean_message_novelty = collections.defaultdict(list)
        else:
            self.mean_message_novelty = None

    def update_novelty(self, messages, novelty, message_novelty=None):
        """Record one novelty value (and optionally a message novelty) per message.

        Args:
            messages: Token-id tensor whose last dimension is the message
                length; all leading dimensions are flattened by
                ``messages_to_str``.
            novelty: Indexable with one entry per flattened message, in the
                same order. NOTE(review): assumed to be pre-flattened to match
                ``messages``; confirm against callers.
            message_novelty: Same layout as ``novelty``. Required when
                ``separate_message_novelty`` is True, ignored otherwise.

        Raises:
            ValueError: If ``separate_message_novelty`` is True and
                ``message_novelty`` is None.
        """
        # Explicit check (not ``assert``) so it survives ``python -O``; done
        # before decoding so a bad call does no work.
        if self.separate_message_novelty and message_novelty is None:
            raise ValueError(
                "message_novelty is required when separate_message_novelty=True"
            )

        message_strs = self.messages_to_str(messages)
        for i, msg in enumerate(message_strs):
            self.message_counts[msg] += 1
            self.mean_novelty[msg].append(novelty[i])
            if self.separate_message_novelty:
                self.mean_message_novelty[msg].append(message_novelty[i])

    def compute_means(self, step):
        """Average the accumulated values per message and append records for ``step``.

        Each distinct message seen since the last reset becomes one record
        with keys ``message``, ``count``, ``step`` and ``novelty``. When
        ``separate_message_novelty`` is enabled, the record also has
        ``message_novelty``. The accumulators are cleared afterwards.

        Args:
            step: Step value stored in every new record and in
                ``last_mean_step``.
        """
        self.last_mean_step = step
        for msg, values in self.mean_novelty.items():
            record = {
                "message": msg,
                "count": self.message_counts[msg],
                "step": step,
                "novelty": np.mean(values),
            }
            if self.separate_message_novelty:
                record["message_novelty"] = np.mean(self.mean_message_novelty[msg])
            self.novelty_records.append(record)

        # Start a fresh aggregation window.
        self.reset_counts()

    def get_novelty_df(self, step):
        """Return every record accumulated so far as a DataFrame.

        Records are never cleared here. ``step`` is only remembered in
        ``last_novelty_step``.
        """
        self.last_novelty_step = step
        return pd.DataFrame.from_records(self.novelty_records)

    def init_tokenizer(self, vocab_file):
        """Build the lowercasing WordPiece tokenizer used to decode messages."""
        self.tokenizer = BertWordPieceTokenizer(vocab_file, lowercase=True)

    def messages_to_str(self, messages):
        """Decode a batch of token-id messages into strings.

        Args:
            messages: Tensor of token ids whose last dimension is the message
                length. It is moved with ``.cpu()``, so presumably a
                ``torch.Tensor``.

        Returns:
            List of decoded strings, one per row after flattening all leading
            dimensions (row-major order).
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed, work.
        messages_flat = messages.reshape(-1, messages.shape[-1])
        # decode_batch is documented to take List[List[int]]; tolist() also
        # works on tensors that require grad, unlike .numpy().
        return self.tokenizer.decode_batch(messages_flat.cpu().tolist())
