""" 2. Run causal discovery experiment based on data from DataGenerator. """
import os
import sys
import numpy as np
import pandas as pd
import pickle as pk
from datetime import datetime
from tabulate import tabulate
import gc
import time
from pathlib import Path
from . import utils

# Global NumPy print settings: keep scientific notation available and never
# truncate large arrays with "..." when they are printed or stringified.
np.set_printoptions(suppress=False, threshold=sys.maxsize)
N_EPISODES = 10_000


class ExperimentRunner:
    """Run DAG-estimation baselines over pickled datasets and save one CSV per run.

    Each algorithm method wrapped by ``_algo_decorator`` is applied to every
    ``*.pk`` file found under ``<exp_dir>/_data``. Results are written below
    ``<exp_dir>/_res``.
    """

    def __init__(self, opt):
        """
        Set up the input/output folders and collect the dataset files to process.

        Args:
            opt: options object. This constructor reads ``opt.base_dir``,
                ``opt.exp_name`` and ``opt.overwrite``:
                - experiment dir: ``<base_dir>/<basename(base_dir) + exp_name>``
                - datasets are read from ``<exp_dir>/_data`` (recursively, ``*.pk``)
                - results go to ``<exp_dir>/_res``; ``opt.overwrite`` is passed
                  to ``utils.create_folder`` for that folder.
        """
        self.opt = opt

        # parent dir of all the experiments
        self.exp_dir = os.path.join(
            opt.base_dir, os.path.basename(opt.base_dir) + opt.exp_name
        )
        self.input_folder = os.path.join(self.exp_dir, "_data")
        self.output_folder = os.path.join(self.exp_dir, "_res")
        utils.create_folder(self.output_folder, opt.overwrite)

        # Gather all data files; each is only loaded when it is processed.
        # Sorted so the processing order is deterministic (rglob order is
        # filesystem-dependent).
        self.data_files = sorted(Path(self.input_folder).rglob("*.pk"))
        # NOTE(review): placeholder value, not the real repository commit hash.
        self.git_hash = "123"

    def get_empty_results(self):
        """Return a fresh results table: one empty list per logged column."""
        return {
            "algorithm": [],
            "hyperparams": [],
            "scaler": [],
            "scaling_factors": [],
            "dataset_parameters": [],
            "dataset_hash": [],
            "edge_weight_range": [],
            "n_nodes": [],
            "n_obs": [],
            "random_seed": [],
            "graph_type": [],
            "edge_type": [],
            "x": [],
            "noise_dist": [],
            "noise_sigma_dist": [],
            "noise_sigma_lims": [],
            "start_time": [],
            "runtime": [],
            "git_hash": [],
            "W_true": [],
            "W_est": [],
            "vars": [],
            "scaled_vars": [],
            "varsortability": [],
            "R2": [],
            "R2sortability": [],
            "CEVsortability": [],
        }

    def _log_results_local(self, func, data_file, *args, **kwargs):
        """Run ``func`` on one dataset file in this process (delegates to ``_log_results``).

        NOTE(review): presumably kept as a separate hook so a non-local
        (e.g. distributed) runner could be swapped in — confirm.
        """
        self._log_results(func, data_file, *args, **kwargs)

    def _log_results(self, func, data_file, *args, **kwargs):
        """
        Run one DAG-fitting method on a single dataset file and save the result.

        Loads the pickled dataset, times ``func(self, dataset, *args, **kwargs)``,
        stores the estimate together with the dataset metadata as a one-row
        results table, and writes it to
        ``<output_folder>/<data_file.parent.name>/<func name>_<data_file.stem>_results.csv``.

        Args:
            func: undecorated algorithm method returning an array-like ``W_est``.
            data_file (pathlib.Path): path to a pickled dataset.
            *args, **kwargs: forwarded to ``func``; ``kwargs`` is also logged as
                the run's hyperparameters.
        """
        # Pickle can execute arbitrary code: only load files produced by this pipeline.
        with open(data_file, "rb") as fh:
            dataset = pk.load(fh)
        # create results folder (one per input sub-folder / experiment combination)
        exp_combo_name = data_file.parent.name
        exp_combo_folder = os.path.join(self.output_folder, exp_combo_name)
        utils.create_folder(exp_combo_folder)

        results = self.get_empty_results()
        algo_name = func.__name__
        # Only print for the first seed to keep console output short.
        if dataset.parameters.random_seed == 0:
            print(algo_name.upper(), dataset.parameters)

        # actual function; perf_counter is monotonic, unlike time.time()
        start_time = str(datetime.now())
        start = time.perf_counter()
        W_est = func(self, dataset, *args, **kwargs)
        runtime = time.perf_counter() - start
        W_est = W_est.astype("float64")

        # decoration
        results["algorithm"].append(algo_name)
        results["hyperparams"].append(kwargs)
        results["scaler"].append(dataset.scaler)
        results["scaling_factors"].append([dataset.scaling_factors])
        results["dataset_parameters"].append(utils.dataset_parameters(dataset))
        results["dataset_hash"].append(dataset.hash)
        results["start_time"].append(start_time)
        results["runtime"].append(np.round(runtime, 0))  # whole seconds
        results["git_hash"].append(self.git_hash)
        # Arrays are wrapped in a list so each one occupies a single table cell.
        results["W_true"].append([np.copy(dataset.W_true)])
        results["W_est"].append([np.copy(W_est)])
        results["vars"].append([dataset.vars])
        results["scaled_vars"].append([dataset.scaled_vars])
        results["varsortability"].append(dataset.varsortability)
        results["R2"].append([dataset.R2])
        results["R2sortability"].append(dataset.R2sortability)
        results["CEVsortability"].append(dataset.CEVsortability)
        # Every parameter field must also be a column of get_empty_results().
        for k, v in dataset.parameters._asdict().items():
            results[k].append(v)

        # write file
        input_fname = data_file.stem
        fname = f"{func.__name__}_{input_fname}_results.csv"
        fpath = os.path.join(exp_combo_folder, fname)
        self.collect_results(
            results, fpath, display=dataset.parameters.random_seed == 0
        )

        del dataset
        gc.collect()

    def _algo_decorator(func):
        """Apply a DAG estimation method to all ``.pk`` files in ``self.data_files``.

        Used as a decorator in the class body, hence no ``self`` parameter. The
        resulting method iterates over every data file, runs ``func`` on it and
        saves the result via ``_log_results_local``; it returns None.
        """
        import functools

        @functools.wraps(func)  # keep the algorithm's name and docstring
        def wrapper(self, *args, **kwargs):
            for data_file in self.data_files:
                print(data_file)
                self._log_results_local(func, data_file, *args, **kwargs)

        return wrapper

    @_algo_decorator
    def sortnregressIC(self, dataset):
        """Estimate ``W`` with ``CausalDisco.baselines.var_sort_regress`` on ``dataset.data``."""
        from CausalDisco.baselines import var_sort_regress

        return var_sort_regress(dataset.data)

    @_algo_decorator
    def sortnregressIC_R2(self, dataset):
        """Estimate ``W`` with ``CausalDisco.baselines.r2_sort_regress`` on ``dataset.data``."""
        from CausalDisco.baselines import r2_sort_regress

        return r2_sort_regress(dataset.data)

    @_algo_decorator
    def randomregressIC(self, dataset):
        """Estimate ``W`` with ``CausalDisco.baselines.random_sort_regress`` on ``dataset.data``."""
        from CausalDisco.baselines import random_sort_regress

        return random_sort_regress(dataset.data)

    def keep_newest_only(self, df):
        """Keep only the latest run per (algorithm, dataset_hash).

        Sorts ``df`` in place by algorithm and ``start_time`` (ascending), then
        drops duplicate (algorithm, dataset_hash) rows, keeping the last (newest)
        one. Mutates and returns ``df``.
        """
        df.sort_values(by=["algorithm", "start_time"], ascending=True, inplace=True)
        df.drop_duplicates(
            subset=["algorithm", "dataset_hash"], keep="last", inplace=True
        )
        return df

    def collect_results(self, res_dict, path, save=True, display=False):
        """Build a results DataFrame, optionally save it as CSV and print a summary.

        Args:
            res_dict (dict): column name -> list of values.
            path (str): CSV destination used when ``save`` is true.
            save (bool): write the deduplicated table to ``path`` (no index).
            display (bool): print algorithm/parameters/start time/runtime as a table.
        """
        res_df = pd.DataFrame(res_dict)
        res_df = self.keep_newest_only(res_df)
        if save:
            # Arrays stored in object cells are stringified on write; lift the
            # threshold so they are written in full rather than truncated.
            with np.printoptions(threshold=sys.maxsize, suppress=False):
                res_df.to_csv(path, index=False)
        if display:
            display_cols = ["algorithm", "dataset_parameters", "start_time", "runtime"]
            print(tabulate(res_df[display_cols], headers="keys", tablefmt="psql"))
