import numpy as np
import random
import torch

import sys
import os
from copy import deepcopy
import random
from pathlib import Path

# Add the directory two levels above this file to the module search path
# (presumably so the project-local imports that follow can be resolved).
# `append` is required: `sys.path.extend(<str>)` would iterate over the string
# and add every single character as its own search-path entry.
sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))

from params import get_params
from trainer import TrainerFS
from experiments.utils.data_loader_wrapper import get_dataset_wrap

# Alphabet used to draw random characters for output-file name suffixes:
# lowercase letters, uppercase letters, then digits.
letters = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "1234567890"
)


def run_multiple(datasets, modified_params, output_file, n_runs=3):
    """Train ``n_runs`` times with overridden parameters and append a report.

    A deep copy of the module-level ``params`` is updated with
    ``modified_params`` and passed, together with ``datasets``, to a fresh
    ``TrainerFS`` for every run. The three values returned by each
    ``train()`` call (best validation value, best test value, best epoch)
    are collected; their mean and standard deviation, followed by the raw
    per-run values, are appended to ``output_file``.

    Args:
        datasets: Forwarded unchanged to ``TrainerFS``.
        modified_params: Parameter overrides applied on top of the global
            ``params``.
        output_file: Path of the text file the report is appended to.
        n_runs: Number of training runs to perform.
    """
    run_config = deepcopy(params)
    run_config.update(modified_params)

    val_scores = []
    test_scores = []
    best_epochs = []
    for run_idx in range(n_runs):
        trainer = TrainerFS(datasets, run_config)
        best_val, best_test, best_epoch = trainer.train()
        print(f"Run {run_idx}: best val: {best_val}, best test: {best_test}, best epoch: {best_epoch}")
        val_scores.append(best_val)
        test_scores.append(best_test)
        best_epochs.append(best_epoch)
        # Drop the reference to the trainer before the next one is created.
        del trainer

    val_arr = np.array(val_scores)
    test_arr = np.array(test_scores)
    epoch_arr = np.array(best_epochs)

    # Append mode: repeated calls with the same file accumulate reports.
    with open(output_file, "a") as report:
        report.writelines((
            f"Params: {modified_params}\n",
            f"full_params: {run_config}\n",
            f"Val: {val_arr.mean()} +/- {val_arr.std()}\n",
            f"Test: {test_arr.mean()} +/- {test_arr.std()}\n",
            f"Epoch: {epoch_arr.mean()} +/- {epoch_arr.std()}\n",
            "----------------------------\n",
            "all results: \n",
            f"allVal: {val_arr}\n",
            f"allTest: {test_arr}\n",
            f"allEpoch: {epoch_arr}\n",
        ))


if __name__ == '__main__':
    # `params` and `exp_type` are intentionally module-level: the functions
    # in this file read them as globals.
    params = get_params()

    exp_type = params["experiment_type"]
    # Validate with an explicit raise instead of `assert`: assertions are
    # removed under `python -O`, which would let an unsupported experiment
    # type pass through unnoticed.
    supported_exp_types = ["graphclip", "metagraph", "metagraphpluslabels", "metagraphgat3", "metagraphpluslabelsgat3"]
    if exp_type not in supported_exp_types:
        raise ValueError(
            f"Unsupported experiment_type {exp_type!r}; expected one of {supported_exp_types}"
        )

    # Echo the full configuration so it is captured in the run log.
    print("---------Parameters---------")
    for k, v in params.items():
        print(k + ': ' + str(v))
    print("----------------------------")

    # Control the random seed: when one is configured, seed every RNG used
    # here (torch CPU and CUDA, NumPy, Python's `random`) and enable cuDNN's
    # deterministic mode.
    if params['seed'] is not None:
        SEED = params['seed']
        torch.manual_seed(SEED)
        torch.cuda.manual_seed(SEED)
        torch.backends.cudnn.deterministic = True
        np.random.seed(SEED)
        random.seed(SEED)


def run_multiple_dataset_len_cap():
    """Run the repeated-training experiment once per dataset length cap.

    Model flags are derived from the global ``exp_type``:

    * ``second_gnn`` is ``"gat3"`` when ``exp_type`` contains ``"gat3"``,
      otherwise ``"gat"``.
    * ``zero_shot`` is true only for ``"graphclip"``.
    * label embeddings are ignored unless ``exp_type`` contains
      ``"pluslabels"`` or the run is zero-shot.

    The datasets are built once from the global ``params``; then, for each
    cap in a fixed list, ``run_multiple`` is called with 6 runs and a result
    file under ``outputs/`` whose name encodes the configuration plus a
    random 10-character suffix.
    """
    uses_gat3 = "gat3" in exp_type
    second_gnn = "gat3" if uses_gat3 else "gat"
    zero_shot = exp_type == "graphclip"
    # Zero-shot always keeps the label embeddings, whatever the name says.
    ignore_label_emb = not zero_shot and "pluslabels" not in exp_type

    # These settings are forwarded under their own names; only `n_shots` is
    # renamed (to `n_shot`) for the loader.
    passthrough_keys = (
        "root",
        "dataset",
        "force_cache",
        "small_dataset",
        "invalidate_cache",
        "original_features",
        "n_query",
        "val_len_cap",
        "test_len_cap",
        "dataset_len_cap",
        "n_way",
    )
    dataset_kwargs = {key: params[key] for key in passthrough_keys}
    dataset_kwargs["n_shot"] = params["n_shots"]
    # NOTE(review): the datasets are built once with params["dataset_len_cap"];
    # the per-iteration cap below only reaches the trainer as a parameter
    # override - confirm that TrainerFS actually applies it.
    datasets = get_dataset_wrap(**dataset_kwargs)

    Path("outputs").mkdir(parents=True, exist_ok=True)  # save results to 'outputs'

    # File-name fragments that do not change between iterations.
    gat3_text = "plusgat3" if uses_gat3 else ""
    ignore_label_emb_text = "IGNORELABELEMB" if ignore_label_emb else ""
    zero_shot_text = "ZEROSHOT" if zero_shot else ""
    name_template = "outputs/results_sage{}_256_arxiv_original_features_DSLENCAP_{}_{}_{}_____{}.txt"

    for dslcap in (2000, 1000, 200, 100, 50, 20):
        print("##### dslcap ", dslcap)
        # Fresh random suffix per cap so each result file gets its own name.
        random_suffix = "".join(random.choice(letters) for _ in range(10))

        overrides = {
            "gnn_type": "sage",
            "emb_dim": 256,
            "early_stopping_patience": 30,
            "n_layer": 1,
            "second_gnn": second_gnn,
            "dataset_len_cap": dslcap,
            "ignore_label_embeddings": ignore_label_emb,
            "zero_shot": zero_shot,
            "n_shots": 3,
            "n_query": 24,
            "epochs": 400,
        }
        result_path = name_template.format(
            gat3_text, dslcap, ignore_label_emb_text, zero_shot_text, random_suffix
        )
        run_multiple(datasets, overrides, n_runs=6, output_file=result_path)


# Script entry point: run the dataset-length-cap sweep. This relies on the
# globals (`params`, `exp_type`) set up by the earlier `__main__` block.
if __name__ == "__main__":
    run_multiple_dataset_len_cap()
