""" Utility functions for distillation benchmark """
import os, glob, subprocess
from autogluon_utils.benchmarking.distill_benchmark.configs import *

def get_gibbs_name(input_name):
    """ Convert original dataset name to name used in Gibbs augmented samples file.

        The Gibbs name is the lower-cased dataset name truncated at its first underscore.
        If input_name is not one of ALL_DATASETS, a message is printed and
        input_name is returned unchanged.
    """
    # Only the requested name needs converting; no need to build a lookup of every dataset.
    if input_name in ALL_DATASETS:
        return input_name.lower().split('_')[0]
    print(f"{input_name} not found in modified_names")
    return input_name

def filter_gibbs_models(augfile_name_list):
    """ Keep only the augmentation files that belong to the smaller model.

        Each file name is classified by substring search against SMALLER_MODEL_STRINGS
        and BIGGER_MODEL_STRINGS. Raises ValueError when a name matches both
        families, or neither of them. Input order is preserved in the returned list.
    """
    kept = []
    for augfile in augfile_name_list:
        matches_small = any(tag in augfile for tag in SMALLER_MODEL_STRINGS)
        matches_big = any(tag in augfile for tag in BIGGER_MODEL_STRINGS)
        if matches_small and matches_big:
            raise ValueError(f"augfile={augfile} contains both substr for bigger and for smaller transformer")
        if not matches_small and not matches_big:
            raise ValueError(f"augfile={augfile} contains neither substr for bigger or smaller transformer")
        if matches_small:
            kept.append(augfile)
    return kept


def s3_sync_folder(local_folder, s3_bucket, s3_prefix, to_s3=True):
    """ Run 'aws s3 sync' between local_folder (full path ending in slash) and s3://s3_bucket/s3_prefix
        With to_s3 = True the local folder is the source, otherwise the s3 location is the source.
        The outcome is only reported via print; nothing is returned or raised on failure.
    """
    s3_address = f"s3://{s3_bucket}/{s3_prefix}"
    source, destination = (local_folder, s3_address) if to_s3 else (s3_address, local_folder)
    exit_status = os.system(f"aws s3 sync {source} {destination} --quiet")
    if exit_status != 0:
        print(f"Error syncing {local_folder} with {s3_address}")
        return
    arrow = '-->' if to_s3 else '<--'
    print(f"Synced {local_folder} {arrow} {s3_address}")


def s3_transfer_file(local_path, s3_bucket, s3_prefix, to_s3=True):
    """ Copy a single file between local disk and s3 via 'aws s3 cp', preserving the file-name.
        If to_s3, then local_path should end in file-name, s3_prefix should end in '/'
        If not to_s3, then s3_prefix should end in file-name, local_path should end in '/'

        Raises ValueError if the destination (s3_prefix when to_s3, local_path otherwise)
        is empty or does not end in '/'. The outcome of the copy itself is only
        reported via print.
    """
    s3_address = f"s3://{s3_bucket}/{s3_prefix}"
    if to_s3:
        # endswith() also rejects an empty prefix instead of failing with IndexError
        if not s3_prefix.endswith('/'):
            raise ValueError("file might overwrite the s3_prefix if it does not end in '/'")
        cmd_suffix = f"{local_path} {s3_address}"
        direction = 'to'
    else:
        if not local_path.endswith('/'):
            raise ValueError("file might overwrite the local_path if it does not end in '/'")
        cmd_suffix = f"{s3_address} {local_path}"
        direction = 'from'
    cmd = f"aws s3 cp {cmd_suffix}"
    value = os.system(cmd)  # zero means the command succeeded
    if value == 0:
        print(f"Copied file: {local_path} {direction} {s3_address}")
    else:
        # The previous message referenced an undefined name and raised NameError here.
        print(f"Error copying file: {local_path} {direction} {s3_address}")


def get_s3_files(local_folder, s3_bucket, s3_prefix, include=None, exclude=None):
    """ Copy directory from s3 but only keep certain files/folders.

        Everything is excluded first, then each pattern in 'include' is added back and
        each pattern in 'exclude' is removed again (filters are passed to
        'aws s3 cp --recursive' in that order). None or an empty sequence means no patterns.
        local_folder is created (including missing parents) if it does not exist.

        Raises ValueError if the aws command exits with a non-zero status.
    """
    if not os.path.exists(local_folder):
        # exist_ok guards against another process creating the folder in between
        os.makedirs(local_folder, exist_ok=True)
    s3_address = f"s3://{s3_bucket}/{s3_prefix}"
    cmd = f'aws s3 cp {s3_address} {local_folder} --recursive --exclude "*" '
    cmd += "".join(f' --include "{include_str}"' for include_str in (include or ()))
    cmd += "".join(f' --exclude "{exclude_str}"' for exclude_str in (exclude or ()))
    value = os.system(cmd)  # zero means the command succeeded
    if value != 0:
        raise ValueError(f"Error getting files from {s3_address}")


def get_results_files(local_folder, profile_str, tag_str, dataset_list, wait_every=5):
    """ parallelizes fetching of s3 files via subprocess.Popen.
        Downloads are launched in batches of 'wait_every' subprocesses and each batch is
        waited for before the next one starts, so at most 'wait_every' run simultaneously;
        reduce this value to reduce CPU load. A value below 1 disables the throttling.

        local_folder is expected to end in '/': each dataset is copied to
        '{local_folder}{dataset}/' and dataset names are recovered by stripping
        local_folder from the paths found.

        Returns (dataset_names, ldr_files): the leaderboard files found under local_folder
        after all downloads finished, and the dataset folder name of each of them.
        Failed downloads are ignored; such datasets simply do not appear in the result.
    """
    if not os.path.exists(local_folder):
        os.makedirs(local_folder, exist_ok=True)
    distill_ldr_name = 'DistillLeaderboard.csv'
    files_tofetch = [distill_ldr_name, 'metadata.csv'] # 'PredistillLeaderboard.csv',
    running = []  # subprocesses of the current batch
    for dataset in dataset_list:
        s3_prefix = f"s3://{BUCKET}/results/{dataset}/{profile_str}/{tag_str}/"
        target_loc = f"{local_folder}{dataset}/"
        for file_name in files_tofetch:
            # Argument list rather than a split string, so paths containing spaces stay intact.
            cmd = ["aws", "s3", "cp", f"{s3_prefix}{file_name}", target_loc]
            running.append(subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL))
            if wait_every > 0 and len(running) >= wait_every:
                for p in running:
                    p.wait()
                running = []
    for p in running:  # remaining partial batch
        p.wait()
    ldr_files = glob.glob(local_folder +'/**/'+distill_ldr_name, recursive=True)
    dataset_names = [s[len(local_folder):].split('/')[0] for s in ldr_files]
    return (dataset_names, ldr_files)

def SLOWERget_results_files(local_folder, profile_str, tag_str, dataset_list):
    """ Sequential variant: runs one blocking 'aws s3 cp' per dataset and file.
        Every file is copied straight into local_folder (no per-dataset subfolder).
        Returns (dataset_names, ldr_files) for the leaderboard files found under local_folder.
    """
    if not os.path.exists(local_folder):
        os.mkdir(local_folder)
    leaderboard_csv = 'DistillLeaderboard.csv'
    wanted = (leaderboard_csv, 'PredistillLeaderboard.csv', 'metadata.csv')
    for dataset in dataset_list:
        remote_dir = f"s3://{BUCKET}/results/{dataset}/{profile_str}/{tag_str}/"
        for wanted_file in wanted:
            copy_cmd = f"aws s3 cp {remote_dir}{wanted_file} {local_folder}"
            try:
                os.system(copy_cmd)
            except Exception:
                pass
    found = glob.glob(local_folder +'/**/'+leaderboard_csv, recursive=True)
    skip = len(local_folder)
    return ([path[skip:].split('/')[0] for path in found], found)

def OLDERget_results_files(local_folder, profile_str, tag_str, dataset_list):
    """ Single-command variant: one recursive 'aws s3 cp' over the results/ prefix that
        excludes everything and re-includes the wanted files of each dataset.
        Returns (dataset_names, ldr_files) for the leaderboard files found under local_folder.
    """
    if not os.path.exists(local_folder):
        os.mkdir(local_folder)
    leaderboard_csv = 'DistillLeaderboard.csv'
    wanted = (leaderboard_csv, 'metadata.csv')  # 'PredistillLeaderboard.csv',
    results_root = f"s3://{BUCKET}/results/"
    include_flags = "".join(
        f' --include "{dataset}/{profile_str}/{tag_str}/{wanted_file}"'
        for dataset in dataset_list
        for wanted_file in wanted
    )
    os.system(f'aws s3 cp {results_root} {local_folder} --recursive --exclude "*"' + include_flags)
    found = glob.glob(local_folder +'/**/'+leaderboard_csv, recursive=True)
    skip = len(local_folder)
    return ([path[skip:].split('/')[0] for path in found], found)

def OLDESTget_results_files(local_folder, profile_str, tag_str):
    """ Oldest variant: delegates to get_s3_files with wildcard include patterns that
        match the wanted files of any dataset under results/.
        Returns (dataset_names, ldr_files) for the leaderboard files found under local_folder.
    """
    leaderboard_csv = 'DistillLeaderboard.csv'
    wanted = (leaderboard_csv, 'metadata.csv')  # 'PredistillLeaderboard.csv',
    patterns = ["/".join(['*', profile_str, tag_str, wanted_file]) for wanted_file in wanted]
    get_s3_files(local_folder, BUCKET, 'results/', include=patterns)
    found = glob.glob(local_folder +'/**/'+leaderboard_csv, recursive=True)
    skip = len(local_folder)
    return ([path[skip:].split('/')[0] for path in found], found)


def get_results_files_ec2(local_folder, profile_str, tag_str, dataset_list):
    """ Keeps files in s3, just returns their paths which can still be loaded with pd.read_csv()

        For each dataset the leaderboard object is probed with a HEAD request; datasets
        whose probe raises ClientError are skipped. local_folder is accepted but not
        used, since nothing is downloaded.

        Returns (dataset_names, ldr_files) where ldr_files are 's3://...' paths of the
        leaderboard files that exist and dataset_names the matching dataset names.
    """
    import boto3
    from botocore.exceptions import ClientError
    s3 = boto3.client('s3')
    distill_ldr_name = 'DistillLeaderboard.csv'
    prefix_ignore = f"s3://{BUCKET}/results/"
    ldr_files = []
    for dataset in dataset_list:
        s3_prefix = f"results/{dataset}/{profile_str}/{tag_str}/"
        # Only the leaderboard is probed: metadata paths were collected but never returned.
        try:
            s3.head_object(Bucket=BUCKET, Key=s3_prefix + distill_ldr_name)
        except ClientError:
            continue  # best effort: this dataset has no (accessible) leaderboard
        ldr_files.append(f"s3://{BUCKET}/{s3_prefix}{distill_ldr_name}")
    dataset_names = [s[len(prefix_ignore):].split('/')[0] for s in ldr_files]
    return (dataset_names, ldr_files)

def SLOWget_results_files_ec2(local_folder, profile_str, tag_str, dataset_list):
    """ Download the leaderboard and metadata files of each dataset with boto3, one at a time.

        Files are stored as '{local_folder}{dataset}/{file_name}', so local_folder is
        expected to end in '/'. Files whose HEAD probe or download raises ClientError
        are skipped.

        Returns (dataset_names, ldr_files) for the leaderboard files found under local_folder.
    """
    import boto3
    from botocore.exceptions import ClientError
    s3 = boto3.client('s3')
    if not os.path.exists(local_folder):
        os.makedirs(local_folder, exist_ok=True)
    distill_ldr_name = 'DistillLeaderboard.csv'
    files_tofetch = [distill_ldr_name,'metadata.csv']  # 'PredistillLeaderboard.csv'
    for dataset in dataset_list:
        local_subfolder = local_folder + dataset + '/'
        if not os.path.exists(local_subfolder):
            os.makedirs(local_subfolder, exist_ok=True)
        s3_prefix = f"results/{dataset}/{profile_str}/{tag_str}/"
        for file_name in files_tofetch:
            try:
                # boto3 client operations accept keyword arguments only; the positional
                # call raised TypeError, which the ClientError handler did not catch.
                s3.head_object(Bucket=BUCKET, Key=s3_prefix+file_name)
                s3.download_file(BUCKET, s3_prefix+file_name, local_subfolder+file_name)
            except ClientError:
                pass  # best effort: missing or inaccessible files are skipped
    ldr_files = glob.glob(local_folder +'/**/'+distill_ldr_name, recursive=True)
    dataset_names = [s[len(local_folder):].split('/')[0] for s in ldr_files]
    return (dataset_names, ldr_files)


def distill_model_present(model_type, distill_method, models_ran):
    """ Tell whether models_ran contains a model of the given type and distillation method.

        A model counts when its name starts with model_type and either ends with
        distill_method, or both names contain 'dstl_GIB' and share the same last
        '_'-separated token.
    """
    gibbs_tag = 'dstl_GIB'

    def _matches(candidate):
        if not candidate.startswith(model_type):
            return False
        if candidate[-len(distill_method):] == distill_method:
            return True
        return (
            gibbs_tag in distill_method
            and gibbs_tag in candidate
            and candidate.split("_")[-1] == distill_method.split("_")[-1]
        )

    return any(_matches(candidate) for candidate in models_ran)

def get_model_type(model_name_df, model_types):
    """ Return the entry of model_types that model_name_df starts with, or None if there is none.
        When several entries are prefixes of the name, the one listed last wins.
    """
    prefixes = [candidate for candidate in model_types if model_name_df.startswith(candidate)]
    return prefixes[-1] if prefixes else None

def get_distill_type(model_name_df, distill_methods):
    """ Return the entry of distill_methods that model_name_df was produced with, or None.

        A Gibbs method ('dstl_GIB_' in both names, same last '_'-separated token) only
        counts when the model name also contains one of SMALLER_MODEL_STRINGS; any other
        method counts when the model name ends with it. When several methods match,
        the one listed last wins.
    """
    gibbs_marker = 'dstl_GIB_'
    small_tags = SMALLER_MODEL_STRINGS  # tags unique to the smaller transformer model
    matched = None
    for candidate in distill_methods:
        is_gibbs_pair = (
            gibbs_marker in candidate
            and gibbs_marker in model_name_df
            and model_name_df.split("_")[-1] == candidate.split("_")[-1]
        )
        if is_gibbs_pair:
            # Gibbs results of the bigger transformer are deliberately not matched.
            if any(tag in model_name_df for tag in small_tags):
                matched = candidate
            continue
        if model_name_df[-len(candidate):] == candidate:
            matched = candidate
    return matched

def get_distill_name(model_name_df, distill_methods, model_types):
    """ Build '<distill_method>_<model_type>' for a model name.
        Returns None when either the model type or the distillation method cannot be identified.
    """
    model_part = get_model_type(model_name_df, model_types)
    distill_part = get_distill_type(model_name_df, distill_methods)
    if distill_part is None or model_part is None:
        return None
    return "_".join((distill_part, model_part))

# Unused
def model_name_conversion(model_names):  # converts
    """ Convert the first name in model_names to its 'cpy1'/'cpy2' equivalent.

        'mhz12x6' .. 'mhz12x11' are treated as the SMALL model and get their 'mhz12x'
        prefix replaced by 'cpy1'; 'mhz12x0' .. 'mhz12x5' and 'cpp' are treated as the
        LARGE model and get 'mhz12x' replaced by 'cpy2' (so 'cpp' is returned as is).
        A message naming the class is printed.

        Returns the converted name, '' if the first name is not recognised, and None
        if model_names is empty. Only the first element is examined.
    """
    # NOTE(review): every branch returns, so names after the first are never looked at
    # -- confirm whether converting the whole list was intended.
    small_ids = {'mhz12x' + str(i) for i in range(6, 12)}
    large_ids = {'mhz12x' + str(i) for i in range(6)}
    large_ids.add('cpp')
    for mid in model_names:
        if mid in small_ids:
            print('It belongs to class of SMALL model')
            return mid.replace('mhz12x', 'cpy1')
        if mid in large_ids:
            print('It belongs to class of LARGE model')
            return mid.replace('mhz12x', 'cpy2')
        print('Cant map it to anything, Ask Rasool.')
        return ''
    return None







