import pandas as pd
import pickle
import numpy as np
from n_gram_sim import *
import random
from diversity_metrics import *
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
from prediction_models import *
import os
from sklearn.neighbors import KNeighborsClassifier

# Default directory holding the pickled embedding files; prepare_doc_datasets
# reads 'embeddings_<dataset>_<model>.pkl' from here.
path ="./doc_text/embeddings/"

def prepare_all_doc_datasets(dataset_names, n_samples = 500, n_size=300, n_dims=384, normalise=True, path=path):
    """Collect the dataset and model labels produced for several datasets.

    Calls ``prepare_doc_datasets`` once per entry of ``dataset_names`` and
    concatenates the label lists it returns.

    Parameters
    ----------
    dataset_names : iterable of str
        Dataset identifiers forwarded to ``prepare_doc_datasets``.
    n_samples, n_size, n_dims, normalise, path
        Forwarded unchanged to ``prepare_doc_datasets``.

    Returns
    -------
    (list, list)
        The concatenated dataset labels and the concatenated model labels.
    """
    dat = []
    mod = []
    idx = 0  # running offset: number of entries collected so far
    for dataset_name in dataset_names:
        # prepare_doc_datasets returns four values (labels, labels, embeddings,
        # index lists); only the two label lists are aggregated here, so the
        # remaining values are discarded instead of breaking the unpacking.
        dataset_list, model_list, *_ = prepare_doc_datasets(
            dataset_name, n_samples=n_samples, n_size=n_size, n_dims=n_dims,
            normalise=normalise, idx=idx, path=path)
        idx += len(dataset_list)
        dat.extend(dataset_list)
        mod.extend(model_list)
    return dat, mod

def prepare_doc_datasets(dataset_name, path=path, n_samples = 500, n_size=300, n_dims=384, normalise=True, idx=0):
    """Load every model's embeddings for one dataset and draw shared subsamples.

    ``n_samples`` index lists of ``n_size`` indices each are drawn once (with
    ``random.seed(42)``) and reused for every embedding model, so each model
    is evaluated on the same subsets.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier used in the embedding file name
        ``embeddings_<dataset_name>_<model_name>.pkl``.
    path : str
        Directory prefix (including trailing separator) of the pickle files.
    n_samples : int
        Number of subsamples drawn.
    n_size : int
        Number of indices per subsample.
    n_dims : int
        Number of PCA components passed to ``normalise_emb``; only used when
        ``normalise`` is true.
    normalise : bool
        If true, embeddings are passed through ``normalise_emb``; otherwise
        they are stored exactly as loaded.
    idx : int
        Unused; kept so existing callers that pass it keep working.

    Returns
    -------
    dataset_list, model_list, embs_all, indx_list
        ``dataset_list``, ``model_list`` and ``indx_list`` are parallel lists
        with one entry per (model, sample) pair: the dataset name, the model
        name, and the list of sampled indices. ``embs_all`` maps each model
        name to its (optionally normalised) embeddings.
    """
    model_names=["ada-002", "mistral-embed", "all-MiniLM-L6-v2","all-distilroberta-v1", 
            "all-mpnet-base-v2", "multi-qa-distilbert-cos-v1"
            ]

    # Draw all subsamples up front and keep every one of them (previously each
    # draw overwrote the last, leaving a single sample).
    # NOTE(review): 16384 is presumably the number of rows in each embedding
    # file (the dataset names end in "_16384") — confirm.
    random.seed(42)
    all_indices = list(range(16384))
    sample_indices = [random.sample(all_indices, n_size) for _ in range(n_samples)]

    dataset_list = []
    model_list = []
    indx_list = []
    embs_all = {}
    for model_name in model_names:
        file_name = path+'embeddings_'+dataset_name+'_'+model_name+'.pkl'
        print(model_name)
        print(file_name)
        print(os.path.isfile(file_name))
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(file_name, 'rb') as fp:
            emb_all = pickle.load(fp)
        if normalise:
            emb_all = normalise_emb(emb_all, num_dimensions=n_dims)
        # Store the embeddings and emit labels regardless of `normalise`, so
        # that normalise=False no longer yields empty results.
        embs_all[model_name] = emb_all
        for indices in sample_indices:
            # One entry per (model, sample) keeps the three lists parallel;
            # the same index list object is shared across models.
            indx_list.append(indices)
            dataset_list.append(dataset_name)
            model_list.append(model_name)
    return dataset_list, model_list, embs_all, indx_list

def load_doc_data(idx, embs, index_lists, 
                  nnn_samples):
    """Return the embedding selection for global sample number ``idx``.

    Samples are numbered consecutively across groups; group ``g`` owns
    ``nnn_samples[g]`` consecutive sample numbers. The group containing
    ``idx`` is located, and ``embs[g]`` is indexed with
    ``index_lists[g][idx - first_sample_of_group]``.

    Parameters
    ----------
    idx : int
        Global sample number, ``0 <= idx < sum(nnn_samples)``.
    embs : sequence
        One indexable embedding container per group.
    index_lists : sequence
        ``index_lists[g][k]`` is the index (or indices) selecting the k-th
        sample of group ``g`` from ``embs[g]``.
    nnn_samples : sequence of int
        Number of samples in each group.

    Returns
    -------
    The result of ``embs[g][index_lists[g][k]]`` for the owning group.

    Raises
    ------
    IndexError
        If ``idx`` is negative or not smaller than ``sum(nnn_samples)``.
    """
    sample_base = 0  # global number of the first sample of the current group
    for num, group_size in enumerate(nnn_samples):
        if sample_base <= idx < sample_base + group_size:
            # idx falls inside this group: use THIS group's embeddings and
            # index lists, with the offset measured from the group's start.
            return embs[num][index_lists[num][idx - sample_base]]
        sample_base += group_size
    raise IndexError(
        "sample index %r out of range for %d samples" % (idx, sample_base))

def normalise_emb(emb_all, num_dimensions = 100):
    """L2-normalise embeddings, reduce them with PCA, then L2-normalise again.

    Parameters
    ----------
    emb_all : array-like
        Embedding matrix with one row per item (rows are normalised, i.e.
        ``axis=1``).
    num_dimensions : int
        ``n_components`` passed to ``sklearn.decomposition.PCA``.

    Returns
    -------
    The PCA-projected embeddings with each row rescaled to unit L2 norm.
    """
    embedding_matrix_normalized = normalize(emb_all, axis=1, norm='l2')

    pca = PCA(n_components=num_dimensions)
    embedding_matrix_pca = pca.fit_transform(embedding_matrix_normalized)
    # Re-normalise the PCA projection itself; normalising the pre-PCA matrix
    # here would silently throw the dimensionality reduction away.
    embedding_matrix_pca = normalize(embedding_matrix_pca, axis=1, norm='l2')
    return embedding_matrix_pca

def run_doc_metrics(metrics=["cosine"], normalise=True, n_samples = 500, n_size=300, n_dims=384, resolutions=None):
    """Compute and save diversity metrics and prediction scores per dataset.

    For each hard-coded dataset this function:

    1. builds the subsampled embeddings via ``prepare_doc_datasets``,
    2. passes a loader for them to ``calc_metrics_from_embeddings``,
    3. attaches dataset/model labels and the run parameters to the result and
       saves it with ``save_magnitude_results``,
    4. runs ``prediction_task_documents`` (which reads the saved results back
       from disk), stores its scores under ``"prediction_results"``, writes
       them to ``./doc_text/<dataset>_pred_doc.csv`` and saves again.

    Parameters
    ----------
    metrics : list of str
        Metric names forwarded to ``calc_metrics_from_embeddings``. The list
        is copied per dataset and never mutated here. If empty, datasets are
        still loaded but nothing is computed.
    normalise : bool
        Forwarded to ``prepare_doc_datasets``.
    n_samples, n_size, n_dims : int
        Sampling/normalisation settings; also used to build the output file
        prefix and recorded under ``"parameters"``.
    resolutions
        Unused; kept for interface compatibility.

    Returns
    -------
    dict
        Maps each processed dataset name to its results object.
    """
    setting=str(round(n_samples))+"_"+str(round(n_size))+"_"+str(n_dims)+"_"

    datasets=[
        "cnn_dailymail___3.0.0_16384",
        "big_patent___a_16384",
        "EdinburghNLP_-_xsum_16384",
        "gfissore_-_arxiv-abstracts-2021_16384",
        ]

    all_results = {}
    for dataset in datasets:
        # Hand a copy downstream so neither the caller's list nor the shared
        # default argument can be mutated by the metric computation.
        these_metrics = list(metrics)
        print(dataset)

        dat, mod, embs_all, indx_list = prepare_doc_datasets(dataset, n_samples = n_samples, n_size=n_size, n_dims=n_dims, normalise=normalise)
        n_all_samples = len(dat)

        if not these_metrics:
            continue

        output_path = "./doc_text/"+dataset+setting

        def load_data(idx):
            # Select sample `idx` from the embeddings of the model that
            # sample belongs to.
            print(idx)
            return embs_all[mod[idx]][indx_list[idx]]

        mag_results = calc_metrics_from_embeddings(load_data, n_samples=n_all_samples,  n_ts=10, metrics=these_metrics, 
                            reference_summaries = False, reference_scale=0.5, scale=True, absolute_area=True, nearest_k=10, 
                            target_scale=0.99)
        mag_results["dataset"] = dat
        mag_results["model"] = mod
        mag_results["summary_statistics"]["dataset"] = dat
        mag_results["summary_statistics"]["model"] = mod
        mag_results["parameters"] = {"n_samples":n_samples, "n_size":n_size, "n_dims":n_dims}
        # Must be saved before the prediction task: with results=None it
        # reloads the results from disk via read_files.
        save_magnitude_results(mag_results, output_path)

        random.seed(42)
        summary_scores = prediction_task_documents(dataset, results=None, n_samples = n_samples, n_size=n_size, n_dims=n_dims)
        mag_results["prediction_results"] = summary_scores
        pd.DataFrame(mag_results["prediction_results"]).to_csv("./doc_text/"+dataset+"_pred_doc.csv")
        # Save again so the stored results include the prediction scores.
        save_magnitude_results(mag_results, output_path)

        all_results[dataset] = mag_results

    return all_results
                

def read_files(dataset, n_samples = 500, n_size=300, n_dims=384):
    """Load the pickled magnitude results stored for a dataset and setting.

    The file read is
    ``./doc_text/<dataset><n_samples>_<n_size>_<n_dims>__magnitude_results.pkl``
    where ``n_samples`` and ``n_size`` are rounded first.

    Returns the unpickled object.
    """
    # The trailing empty string yields the closing "_" of the setting prefix.
    setting = "_".join([str(round(n_samples)), str(round(n_size)), str(n_dims), ""])
    results_file = "./doc_text/" + dataset + setting + "_magnitude_results" + '.pkl'
    with open(results_file, 'rb') as handle:
        return pickle.load(handle)

def prediction_documents_per_dataset():
    """Run ``prediction_task_documents`` with default settings on each dataset.

    Returns a dict mapping every hard-coded dataset name to the scores
    returned for it, in the order the datasets are listed.
    """
    dataset_ids = (
        "cnn_dailymail___3.0.0_16384",
        "big_patent___a_16384",
        "EdinburghNLP_-_xsum_16384",
        "gfissore_-_arxiv-abstracts-2021_16384",
    )
    return {
        dataset_id: prediction_task_documents(dataset_id)
        for dataset_id in dataset_ids
    }

def get_prediction_results(summary_scores):
    """Summarise prediction scores as mean/std rows for each dataset.

    ``summary_scores[dataset]`` must be convertible to a DataFrame whose
    columns are the individual score series. For every hard-coded dataset a
    two-row frame (index ``"mean"`` and ``"std"``, one column per score) is
    built, tagged with a display name in a ``"dataset"`` column, and all
    frames are stacked vertically.
    """
    dataset_ids = [
        "cnn_dailymail___3.0.0_16384",
        "big_patent___a_16384",
        "EdinburghNLP_-_xsum_16384",
        "gfissore_-_arxiv-abstracts-2021_16384",
    ]
    # Display labels, matched to dataset_ids by position (the extra trailing
    # label has no dataset and is never used).
    display_names = ["CNN Dailymail", "Big Patent", "EdinburghNLP", "Arvix Abstracts", "iclr"]

    per_dataset = []
    for dataset_id, display_name in zip(dataset_ids, display_names):
        scores = pd.DataFrame(summary_scores[dataset_id])
        stats = pd.DataFrame({
            "mean": scores.mean(axis=0),
            "std": scores.std(axis=0),
        }).transpose()
        stats["dataset"] = display_name
        per_dataset.append(stats)
    return pd.concat(per_dataset, axis=0)

def prediction_task_documents(dataset, results=None, n_samples = 500, n_size=300, n_dims=384):
    """Cross-validate classifiers that predict the model label from diversity scores.

    If ``results`` is None it is loaded with ``read_files`` for the given
    dataset and setting. The target is always
    ``results["summary_statistics"]["model"]``. Three families of features
    are evaluated, each with 5 splits and accuracy scoring:

    * four summary-statistic columns, one at a time, with logistic regression
      (key ``<column>``) and k-nearest neighbours (key ``<column>_knn``);
    * each column of ``results["magnitude_function_dfs"]["cosine"]`` with the
      same two classifiers (keys ``<i>`` and ``<i>_knn``);
    * ``results["magnitude_differences"]["cosine"]`` handed to a KNN with
      ``metric="precomputed"`` (key ``mag_diff_knn``).

    Returns a dict mapping those keys to the scores returned by the
    cross-validation helpers.
    """
    if results is None:
        results = read_files(dataset, n_samples = n_samples, n_size=n_size, n_dims=n_dims)
    summary = results["summary_statistics"]

    scoring = "accuracy"
    # (key suffix, classifier factory), in evaluation order.
    classifiers = (("", LogisticRegression), ("_knn", KNeighborsClassifier))
    cv_scores = {}

    for column in ("stds_div_zero", "neg_mean_cosine", "vendi_cosine", "mag_area_cosine"):
        for suffix, make_model in classifiers:
            fold_df, fold_scores = classification_cross_validation_Xy(
                np.array(summary[column]).reshape(-1, 1),
                summary["model"],
                scoring=scoring,
                model=make_model(),
                n_splits=5,
            )
            fold_df["diversity"] = column + suffix
            cv_scores[column + suffix] = fold_scores

    magnitude_functions = results["magnitude_function_dfs"]["cosine"]
    for col_idx in range(magnitude_functions.shape[1]):
        for suffix, make_model in classifiers:
            fold_df, fold_scores = classification_cross_validation_Xy(
                magnitude_functions[:, col_idx].reshape(-1, 1),
                summary["model"],
                scoring=scoring,
                model=make_model(),
                n_splits=5,
            )
            fold_df["diversity"] = col_idx
            cv_scores[str(col_idx) + suffix] = fold_scores

    difference_matrix = results["magnitude_differences"]["cosine"]
    fold_df, fold_scores = classification_cross_validation_Xypre(
        difference_matrix,
        summary["model"],
        scoring='accuracy',
        n_splits=5,
        model=KNeighborsClassifier(metric="precomputed"),
    )
    fold_df["diversity"] = "mag_diff_knn"
    cv_scores["mag_diff_knn"] = fold_scores

    return cv_scores


# Script entry point: run the full metric pipeline twice, with and without
# embedding normalisation.
if __name__ == "__main__":
    # Run 1: normalise=True with n_dims=384 requested as the PCA size.
    mag_results = run_doc_metrics(metrics=["cosine"], normalise=True, n_samples = 200, n_size=300, n_dims=384)
    # Run 2: normalise=False; n_dims=0 still enters the output file prefix
    # built inside run_doc_metrics. mag_results is rebound by this call.
    mag_results = run_doc_metrics(metrics=["cosine"], n_samples = 200, n_size=300, n_dims=0, normalise=False)