# -*- coding:utf-8 -*- 
import logging
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from seqeval.metrics import classification_report
from utils.data_utils import load_and_cache_examples, tag_to_id, id_to_tag, get_chunks
from flashtool import Logger

def evaluate_ori(args, model, tokenizer, labels, pad_token_label_id, best, mode, logger, prefix="", verbose=True):
    """Evaluate ``model`` on the ``mode`` split and report chunk-level P/R/F1.

    Args:
        args: Run configuration. Reads ``per_gpu_eval_batch_size``, ``n_gpu``,
            ``local_rank``, ``device``, ``data_dir`` and ``dataset``; writes
            ``eval_batch_size``.
        model: Called with ``input_ids``, ``attention_mask`` and ``valid_pos``;
            the fifth element of its output is used as the logits.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        labels: Label names; the index of a name is its label id.
        pad_token_label_id: Label id of positions that are excluded from scoring.
        best: Sequence holding the best precision (``best[0]``), recall
            (``best[1]``) and F1 (``best[-1]``) seen so far.
        mode: Split identifier forwarded to ``load_and_cache_examples``.
        logger: Logger that receives the progress and result lines.
        prefix: Text appended to the log banners.
        verbose: When true, also log the dataset size and batch size.

    Returns:
        Tuple ``(results, preds_list, best, is_updated)``: the metrics dict,
        the predicted label names per sequence (non-padding positions only),
        the possibly replaced ``best`` list, and whether it was replaced.

    Raises:
        ZeroDivisionError: If the data loader yields no batch.
    """
    eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation %s *****", prefix)
    if verbose:
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch arrays are collected and concatenated once after the loop;
    # appending to a growing array would copy all earlier batches every step.
    logit_batches, label_batches, valid_batches = [], [], []
    loss_fct = CrossEntropyLoss()
    model.eval()
    for batch in eval_dataloader:
        batch = tuple(t.to(args.device) for t in batch)
        valid_pos = batch[2]
        pseudo_labels = batch[3]
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": valid_pos}
            _, _, _, _, logits = model(**inputs)
            # The logits are scored against the labels of the valid positions
            # only, so they hold one row per valid position.
            tmp_eval_loss = loss_fct(logits, pseudo_labels[valid_pos > 0])
            if args.n_gpu > 1:
                tmp_eval_loss = tmp_eval_loss.mean()
            eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1
        logit_batches.append(logits.detach().cpu().numpy())
        label_batches.append(pseudo_labels.detach().cpu().numpy())
        valid_batches.append(valid_pos.detach().cpu().numpy())

    eval_loss = eval_loss / nb_eval_steps
    preds = np.argmax(np.concatenate(logit_batches, axis=0), axis=-1)
    out_label_ids = np.concatenate(label_batches, axis=0)
    all_valid = np.concatenate(valid_batches, axis=0)

    # Scatter the flat per-valid-position predictions back onto the
    # (sequence, position) grid; positions that are not valid keep -100.
    valid_mask = all_valid != 0
    token_preds = np.full((out_label_ids.shape[0], out_label_ids.shape[1]), -100, dtype=preds.dtype)
    token_preds[valid_mask] = preds[:np.count_nonzero(valid_mask)]

    label_map = dict(enumerate(labels))
    preds_list, out_id_list, preds_id_list = [], [], []
    for gold_row, pred_row in zip(out_label_ids, token_preds):
        keep = gold_row != pad_token_label_id
        kept_preds = list(pred_row[keep])
        out_id_list.append(list(gold_row[keep]))
        preds_id_list.append(kept_preds)
        preds_list.append([label_map[p] for p in kept_preds])

    # Chunk-level scores. The tag mapping depends only on ``args``, so it is
    # looked up once instead of twice per sentence.
    tag_map = tag_to_id(args.data_dir, args.dataset)
    correct_preds, total_correct, total_preds = 0., 0., 0.
    for ground_truth_id, predicted_id in zip(out_id_list, preds_id_list):
        lab_chunks = set(get_chunks(ground_truth_id, tag_map))
        lab_pred_chunks = set(get_chunks(predicted_id, tag_map))
        correct_preds += len(lab_chunks & lab_pred_chunks)
        total_preds += len(lab_pred_chunks)
        total_correct += len(lab_chunks)

    p = correct_preds / total_preds if correct_preds > 0 else 0
    r = correct_preds / total_correct if correct_preds > 0 else 0
    new_F = 2 * p * r / (p + r) if correct_preds > 0 else 0

    is_updated = False
    if new_F > best[-1]:
        best = [p, r, new_F]
        is_updated = True

    results = {
        "loss": eval_loss,
        "precision": p,
        "recall": r,
        "f1": new_F,
        "best_precision": best[0],
        "best_recall": best[1],
        "best_f1": best[-1],
    }

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))

    return results, preds_list, best, is_updated

def evaluate(args, model, tokenizer, labels, pad_token_label_id, best, mode, logger, prefix="", verbose=True):
    """Evaluate ``model`` on the ``mode`` split with overall and per-type F1.

    Args:
        args: Run configuration. Reads ``per_gpu_eval_batch_size``, ``n_gpu``,
            ``local_rank``, ``device``, ``data_dir`` and ``dataset``; writes
            ``eval_batch_size``.
        model: Called with ``input_ids``, ``attention_mask`` and ``valid_pos``;
            the fifth element of its output is used as the logits.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        labels: Label names; the index of a name is its label id.
        pad_token_label_id: Label id of positions that are excluded from scoring.
        best: Sequence holding the best precision (``best[0]``), recall
            (``best[1]``) and F1 (``best[-1]``) seen so far.
        mode: Split identifier forwarded to ``load_and_cache_examples``.
        logger: Logger that receives the progress and result lines.
        prefix: Text appended to the log banners.
        verbose: When true, also log the dataset size and batch size.

    Returns:
        Tuple ``(results, preds_list, best, is_updated)``. ``results`` always
        contains ``PER_f1``, ``LOC_f1``, ``ORG_f1`` and ``MISC_f1``; any other
        entity type found in the chunks adds its own ``<TYPE>_f1`` entry.
        ``preds_list`` holds the predicted label names per sequence
        (non-padding positions only).

    Raises:
        ZeroDivisionError: If the data loader yields no batch.
    """
    eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation %s *****", prefix)
    if verbose:
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch arrays are collected and concatenated once after the loop;
    # appending to a growing array would copy all earlier batches every step.
    logit_batches, label_batches, valid_batches = [], [], []
    loss_fct = CrossEntropyLoss()
    model.eval()
    for batch in eval_dataloader:
        batch = tuple(t.to(args.device) for t in batch)
        valid_pos = batch[2]
        pseudo_labels = batch[3]
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": valid_pos}
            _, _, _, _, logits = model(**inputs)
            # The logits are scored against the labels of the valid positions
            # only, so they hold one row per valid position.
            tmp_eval_loss = loss_fct(logits, pseudo_labels[valid_pos > 0])
            if args.n_gpu > 1:
                tmp_eval_loss = tmp_eval_loss.mean()
            eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1
        logit_batches.append(logits.detach().cpu().numpy())
        label_batches.append(pseudo_labels.detach().cpu().numpy())
        valid_batches.append(valid_pos.detach().cpu().numpy())

    eval_loss = eval_loss / nb_eval_steps
    preds = np.argmax(np.concatenate(logit_batches, axis=0), axis=-1)
    out_label_ids = np.concatenate(label_batches, axis=0)
    all_valid = np.concatenate(valid_batches, axis=0)

    # Scatter the flat per-valid-position predictions back onto the
    # (sequence, position) grid; positions that are not valid keep -100.
    valid_mask = all_valid != 0
    token_preds = np.full((out_label_ids.shape[0], out_label_ids.shape[1]), -100, dtype=preds.dtype)
    token_preds[valid_mask] = preds[:np.count_nonzero(valid_mask)]

    label_map = dict(enumerate(labels))
    preds_list, out_id_list, preds_id_list = [], [], []
    for gold_row, pred_row in zip(out_label_ids, token_preds):
        keep = gold_row != pad_token_label_id
        kept_preds = list(pred_row[keep])
        out_id_list.append(list(gold_row[keep]))
        preds_id_list.append(kept_preds)
        preds_list.append([label_map[p] for p in kept_preds])

    def _count_by_type(counter, chunks):
        # The first element of a chunk is its entity type. Types outside the
        # default set get a counter on first sight instead of a KeyError.
        for chunk in chunks:
            counter[chunk[0]] = counter.get(chunk[0], 0) + 1

    def _prf(n_correct, n_predicted, n_gold):
        # Precision, recall and F1; all zero when nothing was predicted correctly.
        if not n_correct > 0:
            return 0, 0, 0
        precision = n_correct / n_predicted
        recall = n_correct / n_gold
        return precision, recall, 2 * precision * recall / (precision + recall)

    default_types = ('PER', 'LOC', 'ORG', 'MISC')
    correct_preds_dict = dict.fromkeys(default_types, 0)
    total_correct_dict = dict.fromkeys(default_types, 0)
    total_preds_dict = dict.fromkeys(default_types, 0)

    # Chunk-level scores. The tag mapping depends only on ``args``, so it is
    # looked up once instead of twice per sentence.
    tag_map = tag_to_id(args.data_dir, args.dataset)
    correct_preds, total_correct, total_preds = 0., 0., 0.
    for ground_truth_id, predicted_id in zip(out_id_list, preds_id_list):
        lab_chunks = set(get_chunks(ground_truth_id, tag_map))
        lab_pred_chunks = set(get_chunks(predicted_id, tag_map))
        matched_chunks = lab_chunks & lab_pred_chunks
        correct_preds += len(matched_chunks)
        total_preds += len(lab_pred_chunks)
        total_correct += len(lab_chunks)
        _count_by_type(correct_preds_dict, matched_chunks)
        _count_by_type(total_preds_dict, lab_pred_chunks)
        _count_by_type(total_correct_dict, lab_chunks)

    p, r, new_F = _prf(correct_preds, total_preds, total_correct)

    is_updated = False
    if new_F > best[-1]:
        best = [p, r, new_F]
        is_updated = True

    # Default types keep their original order; any additional types follow,
    # sorted so the output does not depend on set iteration order.
    extra_types = sorted((set(total_preds_dict) | set(total_correct_dict)) - set(default_types), key=str)

    results = {
        "loss": eval_loss,
        "precision": p,
        "recall": r,
    }
    for entity_type in default_types + tuple(extra_types):
        results["%s_f1" % (entity_type,)] = _prf(
            correct_preds_dict.get(entity_type, 0),
            total_preds_dict.get(entity_type, 0),
            total_correct_dict.get(entity_type, 0),
        )[2]
    results.update({
        "f1": new_F,
        "best_precision": best[0],
        "best_recall": best[1],
        "best_f1": best[-1],
    })

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))

    return results, preds_list, best, is_updated

def evaluate_both(args, model1, model2, tokenizer, labels, pad_token_label_id, best, mode, logger, prefix="", verbose=True):
    """Evaluate a two-model ensemble on the ``mode`` split.

    Two ways of combining the models are scored:

    * "mean": the two logit tensors are averaged;
    * "max": for every valid position, the logits of the model whose top
      logit is larger are used (``model2`` on ties).

    Args:
        args: Run configuration. Reads ``per_gpu_eval_batch_size``, ``n_gpu``,
            ``local_rank``, ``device``, ``data_dir`` and ``dataset``; writes
            ``eval_batch_size``.
        model1: First model; called with ``input_ids``, ``attention_mask`` and
            ``valid_pos``, the fifth element of its output being the logits.
        model2: Second model, called the same way.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        labels: Label names; the index of a name is its label id.
        pad_token_label_id: Label id of positions that are excluded from scoring.
        best: Sequence holding the best precision (``best[0]``), recall
            (``best[1]``) and F1 (``best[-1]``) seen so far.
        mode: Split identifier forwarded to ``load_and_cache_examples``.
        logger: Logger that receives the progress and result lines.
        prefix: Text appended to the log banners.
        verbose: When true, also log the dataset size and batch size.

    Returns:
        Tuple ``(results, preds_mean_list, best, is_updated)``. The returned
        label names always come from the "mean" ensemble, even when the "max"
        ensemble set the new best score. The reported loss is that of the
        averaged logits.

    Raises:
        ZeroDivisionError: If the data loader yields no batch.
    """
    eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation %s *****", prefix)
    if verbose:
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch arrays are collected and concatenated once after the loop;
    # appending to a growing array would copy all earlier batches every step.
    mean_batches, max_batches, label_batches, valid_batches = [], [], [], []
    loss_fct = CrossEntropyLoss()
    model1.eval()
    model2.eval()
    for batch in eval_dataloader:
        batch = tuple(t.to(args.device) for t in batch)
        valid_pos = batch[2]
        pseudo_labels = batch[3]
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": valid_pos}
            _, _, _, _, logits1 = model1(**inputs)
            _, _, _, _, logits2 = model2(**inputs)
            logits_mean = 0.5 * logits1 + 0.5 * logits2
            score1, _ = torch.max(logits1, 1)
            score2, _ = torch.max(logits2, 1)
            # Row-wise selection: model1's logits where its top score is
            # strictly larger, model2's otherwise.
            predicted_max = torch.where((score1 > score2).unsqueeze(1), logits1, logits2)

            tmp_eval_loss = loss_fct(logits_mean, pseudo_labels[valid_pos > 0])
            if args.n_gpu > 1:
                tmp_eval_loss = tmp_eval_loss.mean()
            eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1
        mean_batches.append(logits_mean.detach().cpu().numpy())
        max_batches.append(predicted_max.detach().cpu().numpy())
        label_batches.append(pseudo_labels.detach().cpu().numpy())
        valid_batches.append(valid_pos.detach().cpu().numpy())

    eval_loss = eval_loss / nb_eval_steps
    preds_mean = np.argmax(np.concatenate(mean_batches, axis=0), axis=-1)
    preds_max = np.argmax(np.concatenate(max_batches, axis=0), axis=-1)
    out_label_ids = np.concatenate(label_batches, axis=0)
    all_valid = np.concatenate(valid_batches, axis=0)

    # Scatter the flat per-valid-position predictions back onto the
    # (sequence, position) grid; positions that are not valid keep -100.
    valid_mask = all_valid != 0
    n_valid = np.count_nonzero(valid_mask)
    grid_shape = (out_label_ids.shape[0], out_label_ids.shape[1])
    token_preds_mean = np.full(grid_shape, -100, dtype=preds_mean.dtype)
    token_preds_max = np.full(grid_shape, -100, dtype=preds_max.dtype)
    token_preds_mean[valid_mask] = preds_mean[:n_valid]
    token_preds_max[valid_mask] = preds_max[:n_valid]

    label_map = dict(enumerate(labels))
    preds_mean_list, out_id_list, preds_id_mean_list, preds_id_max_list = [], [], [], []
    for gold_row, mean_row, max_row in zip(out_label_ids, token_preds_mean, token_preds_max):
        keep = gold_row != pad_token_label_id
        kept_mean = list(mean_row[keep])
        out_id_list.append(list(gold_row[keep]))
        preds_id_mean_list.append(kept_mean)
        preds_id_max_list.append(list(max_row[keep]))
        preds_mean_list.append([label_map[p] for p in kept_mean])

    # Chunk-level scores. The tag mapping depends only on ``args``, so it is
    # looked up once instead of three times per sentence.
    tag_map = tag_to_id(args.data_dir, args.dataset)
    total_correct = 0.
    correct_preds_mean, total_preds_mean = 0., 0.
    correct_preds_max, total_preds_max = 0., 0.
    for ground_truth_id, predicted_id_mean, predicted_id_max in zip(out_id_list, preds_id_mean_list, preds_id_max_list):
        lab_chunks = set(get_chunks(ground_truth_id, tag_map))
        lab_pred_mean_chunks = set(get_chunks(predicted_id_mean, tag_map))
        lab_pred_max_chunks = set(get_chunks(predicted_id_max, tag_map))
        correct_preds_mean += len(lab_chunks & lab_pred_mean_chunks)
        total_preds_mean += len(lab_pred_mean_chunks)
        correct_preds_max += len(lab_chunks & lab_pred_max_chunks)
        total_preds_max += len(lab_pred_max_chunks)
        total_correct += len(lab_chunks)

    p_mean = correct_preds_mean / total_preds_mean if correct_preds_mean > 0 else 0
    r_mean = correct_preds_mean / total_correct if correct_preds_mean > 0 else 0
    new_F_mean = 2 * p_mean * r_mean / (p_mean + r_mean) if correct_preds_mean > 0 else 0

    p_max = correct_preds_max / total_preds_max if correct_preds_max > 0 else 0
    r_max = correct_preds_max / total_correct if correct_preds_max > 0 else 0
    new_F_max = 2 * p_max * r_max / (p_max + r_max) if correct_preds_max > 0 else 0

    # The second check runs against the possibly just-updated ``best``, so the
    # better of the two ensembles ends up stored.
    is_updated = False
    if new_F_mean > best[-1]:
        best = [p_mean, r_mean, new_F_mean]
        is_updated = True
    if new_F_max > best[-1]:
        best = [p_max, r_max, new_F_max]
        is_updated = True

    results = {
        "loss": eval_loss,
        "precision_mean": p_mean,
        "recall_mean": r_mean,
        "f1_mean": new_F_mean,
        "precision_max": p_max,
        "recall_max": r_max,
        "f1_max": new_F_max,
        "best_precision": best[0],
        "best_recall": best[1],
        "best_f1": best[-1],
    }

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))

    return results, preds_mean_list, best, is_updated

def evaluate_four(args, model1, model2, model3, model4, tokenizer, labels, pad_token_label_id, best, mode, logger, prefix="", verbose=True):
    """Evaluate a four-model ensemble on the ``mode`` split.

    Two ways of combining the models are scored:

    * "mean": the four logit tensors are averaged;
    * "max": for every valid position, the logits of the model with the
      largest top logit are used. The winner is found pairwise (1 vs 2,
      3 vs 4, then the two winners); on a tie the later model or pair wins.

    Args:
        args: Run configuration. Reads ``per_gpu_eval_batch_size``, ``n_gpu``,
            ``local_rank``, ``device``, ``data_dir`` and ``dataset``; writes
            ``eval_batch_size``.
        model1: First model; called with ``input_ids``, ``attention_mask`` and
            ``valid_pos``, the fifth element of its output being the logits.
        model2: Second model, called the same way.
        model3: Third model, called the same way.
        model4: Fourth model, called the same way.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        labels: Label names; the index of a name is its label id.
        pad_token_label_id: Label id of positions that are excluded from scoring.
        best: Sequence holding the best precision (``best[0]``), recall
            (``best[1]``) and F1 (``best[-1]``) seen so far.
        mode: Split identifier forwarded to ``load_and_cache_examples``.
        logger: Logger that receives the progress and result lines.
        prefix: Text appended to the log banners.
        verbose: When true, also log the dataset size and batch size.

    Returns:
        Tuple ``(results, preds_mean_list, best, is_updated)``. The returned
        label names always come from the "mean" ensemble, even when the "max"
        ensemble set the new best score. The reported loss is that of the
        four-model averaged logits.

    Raises:
        ZeroDivisionError: If the data loader yields no batch.
    """
    eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation %s *****", prefix)
    if verbose:
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch arrays are collected and concatenated once after the loop;
    # appending to a growing array would copy all earlier batches every step.
    mean_batches, max_batches, label_batches, valid_batches = [], [], [], []
    loss_fct = CrossEntropyLoss()
    # Every ensemble member must be in eval mode, not only the first two.
    for member in (model1, model2, model3, model4):
        member.eval()
    for batch in eval_dataloader:
        batch = tuple(t.to(args.device) for t in batch)
        valid_pos = batch[2]
        pseudo_labels = batch[3]
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": valid_pos}
            _, _, _, _, logits1 = model1(**inputs)
            _, _, _, _, logits2 = model2(**inputs)
            _, _, _, _, logits3 = model3(**inputs)
            _, _, _, _, logits4 = model4(**inputs)
            logits_mean = 0.25 * (logits1 + logits2 + logits3 + logits4)

            score1, _ = torch.max(logits1, 1)
            score2, _ = torch.max(logits2, 1)
            score3, _ = torch.max(logits3, 1)
            score4, _ = torch.max(logits4, 1)
            # Pairwise tournament over the per-row top scores; strict ">" means
            # the second operand of each comparison wins ties.
            first_wins_12 = score1 > score2
            predicted_max12 = torch.where(first_wins_12.unsqueeze(1), logits1, logits2)
            score12 = torch.where(first_wins_12, score1, score2)
            first_wins_34 = score3 > score4
            predicted_max34 = torch.where(first_wins_34.unsqueeze(1), logits3, logits4)
            score34 = torch.where(first_wins_34, score3, score4)
            predicted_max = torch.where((score12 > score34).unsqueeze(1), predicted_max12, predicted_max34)

            # The loss is computed on the same four-model average that
            # produces the "mean" predictions.
            tmp_eval_loss = loss_fct(logits_mean, pseudo_labels[valid_pos > 0])
            if args.n_gpu > 1:
                tmp_eval_loss = tmp_eval_loss.mean()
            eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1
        mean_batches.append(logits_mean.detach().cpu().numpy())
        max_batches.append(predicted_max.detach().cpu().numpy())
        label_batches.append(pseudo_labels.detach().cpu().numpy())
        valid_batches.append(valid_pos.detach().cpu().numpy())

    eval_loss = eval_loss / nb_eval_steps
    preds_mean = np.argmax(np.concatenate(mean_batches, axis=0), axis=-1)
    preds_max = np.argmax(np.concatenate(max_batches, axis=0), axis=-1)
    out_label_ids = np.concatenate(label_batches, axis=0)
    all_valid = np.concatenate(valid_batches, axis=0)

    # Scatter the flat per-valid-position predictions back onto the
    # (sequence, position) grid; positions that are not valid keep -100.
    valid_mask = all_valid != 0
    n_valid = np.count_nonzero(valid_mask)
    grid_shape = (out_label_ids.shape[0], out_label_ids.shape[1])
    token_preds_mean = np.full(grid_shape, -100, dtype=preds_mean.dtype)
    token_preds_max = np.full(grid_shape, -100, dtype=preds_max.dtype)
    token_preds_mean[valid_mask] = preds_mean[:n_valid]
    token_preds_max[valid_mask] = preds_max[:n_valid]

    label_map = dict(enumerate(labels))
    preds_mean_list, out_id_list, preds_id_mean_list, preds_id_max_list = [], [], [], []
    for gold_row, mean_row, max_row in zip(out_label_ids, token_preds_mean, token_preds_max):
        keep = gold_row != pad_token_label_id
        kept_mean = list(mean_row[keep])
        out_id_list.append(list(gold_row[keep]))
        preds_id_mean_list.append(kept_mean)
        preds_id_max_list.append(list(max_row[keep]))
        preds_mean_list.append([label_map[p] for p in kept_mean])

    # Chunk-level scores. The tag mapping depends only on ``args``, so it is
    # looked up once instead of three times per sentence.
    tag_map = tag_to_id(args.data_dir, args.dataset)
    total_correct = 0.
    correct_preds_mean, total_preds_mean = 0., 0.
    correct_preds_max, total_preds_max = 0., 0.
    for ground_truth_id, predicted_id_mean, predicted_id_max in zip(out_id_list, preds_id_mean_list, preds_id_max_list):
        lab_chunks = set(get_chunks(ground_truth_id, tag_map))
        lab_pred_mean_chunks = set(get_chunks(predicted_id_mean, tag_map))
        lab_pred_max_chunks = set(get_chunks(predicted_id_max, tag_map))
        correct_preds_mean += len(lab_chunks & lab_pred_mean_chunks)
        total_preds_mean += len(lab_pred_mean_chunks)
        correct_preds_max += len(lab_chunks & lab_pred_max_chunks)
        total_preds_max += len(lab_pred_max_chunks)
        total_correct += len(lab_chunks)

    p_mean = correct_preds_mean / total_preds_mean if correct_preds_mean > 0 else 0
    r_mean = correct_preds_mean / total_correct if correct_preds_mean > 0 else 0
    new_F_mean = 2 * p_mean * r_mean / (p_mean + r_mean) if correct_preds_mean > 0 else 0

    p_max = correct_preds_max / total_preds_max if correct_preds_max > 0 else 0
    r_max = correct_preds_max / total_correct if correct_preds_max > 0 else 0
    new_F_max = 2 * p_max * r_max / (p_max + r_max) if correct_preds_max > 0 else 0

    # The second check runs against the possibly just-updated ``best``, so the
    # better of the two ensembles ends up stored.
    is_updated = False
    if new_F_mean > best[-1]:
        best = [p_mean, r_mean, new_F_mean]
        is_updated = True
    if new_F_max > best[-1]:
        best = [p_max, r_max, new_F_max]
        is_updated = True

    results = {
        "loss": eval_loss,
        "precision_mean": p_mean,
        "recall_mean": r_mean,
        "f1_mean": new_F_mean,
        "precision_max": p_max,
        "recall_max": r_max,
        "f1_max": new_F_max,
        "best_precision": best[0],
        "best_recall": best[1],
        "best_f1": best[-1],
    }

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))

    return results, preds_mean_list, best, is_updated

def calculate_f1(total_correct, total_preds, correct_preds):
    """Return ``(precision, recall, f1)`` computed from chunk counts.

    Args:
        total_correct: Number of gold chunks (recall denominator).
        total_preds: Number of predicted chunks (precision denominator).
        correct_preds: Number of predicted chunks that match a gold chunk.

    Returns:
        The three scores; all of them are ``0`` when ``correct_preds`` is not
        greater than zero.
    """
    # Without any correct prediction every score is zero by definition here,
    # which also keeps the divisions below away from the empty case.
    if not correct_preds > 0:
        return 0, 0, 0

    precision = correct_preds / total_preds
    recall = correct_preds / total_correct
    f_score = 2 * precision * recall / (precision + recall)
    return precision, recall, f_score

def evaluate2(args, model, tokenizer, labels, pad_token_label_id, best, mode, logger, prefix="", verbose=True):
    """Evaluate a model that returns per-position logits for the whole sequence.

    Here the model is called with ``input_ids`` and ``attention_mask`` only and
    its output is used directly as the logits. ``valid_pos`` is used solely to
    select the positions that enter the loss.

    Args:
        args: Run configuration. Reads ``per_gpu_eval_batch_size``, ``n_gpu``,
            ``local_rank``, ``device``, ``data_dir`` and ``dataset``; writes
            ``eval_batch_size``.
        model: Model returning the logits for a batch.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        labels: Label names; the index of a name is its label id.
        pad_token_label_id: Label id of positions that are excluded from scoring.
        best: Sequence holding the best precision (``best[0]``), recall
            (``best[1]``) and F1 (``best[-1]``) seen so far.
        mode: Split identifier forwarded to ``load_and_cache_examples``.
        logger: Logger that receives the progress and result lines.
        prefix: Text appended to the log banners.
        verbose: When true, also log the dataset size and batch size.

    Returns:
        Tuple ``(results, preds_list, best, is_updated)``: the metrics dict,
        the predicted label names per sequence (non-padding positions only),
        the possibly replaced ``best`` list, and whether it was replaced.

    Raises:
        ZeroDivisionError: If the data loader yields no batch.
    """
    eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation %s *****", prefix)
    if verbose:
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    # Per-batch arrays are collected and concatenated once after the loop;
    # appending to a growing array would copy all earlier batches every step.
    logit_batches, label_batches = [], []
    loss_fct = CrossEntropyLoss()
    model.eval()
    for batch in eval_dataloader:
        batch = tuple(t.to(args.device) for t in batch)
        valid_pos = batch[2]
        pseudo_labels = batch[3]
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
            logits = model(**inputs)
            # Only the valid positions contribute to the loss.
            scored = valid_pos > 0
            tmp_eval_loss = loss_fct(logits[scored], pseudo_labels[scored])
            if args.n_gpu > 1:
                tmp_eval_loss = tmp_eval_loss.mean()
            eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1
        logit_batches.append(logits.detach().cpu().numpy())
        label_batches.append(pseudo_labels.detach().cpu().numpy())

    eval_loss = eval_loss / nb_eval_steps
    preds = np.argmax(np.concatenate(logit_batches, axis=0), axis=-1)
    out_label_ids = np.concatenate(label_batches, axis=0)

    label_map = dict(enumerate(labels))
    preds_list, out_id_list, preds_id_list = [], [], []
    for gold_row, pred_row in zip(out_label_ids, preds):
        keep = gold_row != pad_token_label_id
        kept_preds = list(pred_row[keep])
        out_id_list.append(list(gold_row[keep]))
        preds_id_list.append(kept_preds)
        preds_list.append([label_map[p] for p in kept_preds])

    # Chunk-level scores. The tag mapping depends only on ``args``, so it is
    # looked up once instead of twice per sentence.
    tag_map = tag_to_id(args.data_dir, args.dataset)
    correct_preds, total_correct, total_preds = 0., 0., 0.
    for ground_truth_id, predicted_id in zip(out_id_list, preds_id_list):
        lab_chunks = set(get_chunks(ground_truth_id, tag_map))
        lab_pred_chunks = set(get_chunks(predicted_id, tag_map))
        correct_preds += len(lab_chunks & lab_pred_chunks)
        total_preds += len(lab_pred_chunks)
        total_correct += len(lab_chunks)

    p = correct_preds / total_preds if correct_preds > 0 else 0
    r = correct_preds / total_correct if correct_preds > 0 else 0
    new_F = 2 * p * r / (p + r) if correct_preds > 0 else 0

    is_updated = False
    if new_F > best[-1]:
        best = [p, r, new_F]
        is_updated = True

    results = {
        "loss": eval_loss,
        "precision": p,
        "recall": r,
        "f1": new_F,
        "best_precision": best[0],
        "best_recall": best[1],
        "best_f1": best[-1],
    }

    logger.info("***** Eval results %s *****", prefix)
    for key in sorted(results.keys()):
        logger.info("  %s = %s", key, str(results[key]))

    return results, preds_list, best, is_updated