import argparse
import json
import numpy as np
import torch
from torch import device
import itertools 

def _parse_log_line(line):
    """Convert one raw log line into the Python object it spells out.

    The line is rewritten so that it becomes a valid Python expression:
    ``<``/``>`` are turned into double quotes, the bare tokens `` inf,`` and
    `` nan,`` become ``float('inf'),`` / ``float('nan'),`` and the word
    ``tensor`` is dropped (so ``tensor(1.5)`` reads as ``(1.5)``).
    """
    line = line.replace("<", "\"")
    line = line.replace(">", "\"")
    line = line.replace(" inf,", "float('inf'),")
    line = line.replace(" nan,", "float('nan'),")
    line = line.replace("tensor", "")
    # NOTE(security): eval executes whatever the log line contains, so only
    # feed this script log files you produced yourself. The expression is
    # evaluated against this module's globals, so names imported at file
    # level (e.g. ``device``) are visible to it.
    return eval(line)


def _load_results(log_path):
    """Read ``log_path`` and return the list of parsed entries, one per non-blank line."""
    results = []
    with open(log_path) as f:
        for line in f:
            # A blank line is not an expression; eval would raise SyntaxError on it.
            if not line.strip():
                continue
            results.append(_parse_log_line(line))
    return results


def _matches_condition(item, condition):
    """Return True when ``item`` satisfies every entry of ``condition``.

    For each key of ``condition`` the key must be present in ``item``. A list
    value means "item value must be one of these"; any other value must
    compare equal to the item value.
    """
    for key, wanted in condition.items():
        if key not in item:
            return False
        if isinstance(wanted, list):
            if item[key] not in wanted:
                return False
        elif item[key] != wanted:
            return False
    return True


def _report_missing_configs(seed_result, args_to_care):
    """Print, per seed, which combinations of ``args_to_care`` values have no run.

    The grid is the Cartesian product of every value observed (across all
    seeds in ``seed_result``) for each argument in ``args_to_care``. After the
    per-seed listing, one ``name=(v1 v2 ...)`` line is printed for the seeds
    and for each argument, with one entry per missing combination.
    """
    # Lists (not sets) keep first-seen order and tolerate unhashable values.
    unique_arg_values = {arg: [] for arg in args_to_care}

    # Collect, per seed, the value tuples that were actually run.
    arg_configs = {}
    for seed, runs in seed_result.items():
        seed_configs = []
        for config in runs:
            values = []
            for arg in args_to_care:
                value = config[arg]
                if value not in unique_arg_values[arg]:
                    unique_arg_values[arg].append(value)
                values.append(value)
            seed_configs.append(tuple(values))
        arg_configs[seed] = seed_configs

    # Compare against every possible combination.
    missing_configs = {'seeds': []}
    for arg in args_to_care:
        missing_configs[f'{arg}s'] = []

    num_missing = 0
    for seed in seed_result:
        print(f'Missing configs for seed {seed}')
        for config in itertools.product(*unique_arg_values.values()):
            if config in arg_configs[seed]:
                continue
            missing_configs['seeds'].append(seed.split('-')[-1])
            print('\t', end='')
            for arg, value in zip(args_to_care, config):
                print(f'{arg}: {value}', end=' ')
                missing_configs[f'{arg}s'].append(str(value))
            print()
            num_missing += 1

    print(f'Rerun {num_missing} configs')
    for key, missed in missing_configs.items():
        missed_str = ' '.join(missed)
        print(f'{key}=({missed_str})')


def main():
    """Summarise trace / op_norm / effective_rank from a log file, per tag and K.

    Command-line options:
        --condition     Python dict literal; only log entries matching it are
                        used. When omitted, every entry is used.
        --args_to_care  Python list literal of argument names. When given, the
                        untested combinations of their values are printed and
                        these names replace the default hyper-parameter list
                        shown under each seed.
        --log           Path of the log file (default ``log_hessian``).

    Entries are grouped by a key built from the last ``-``-separated piece of
    ``data_dir`` and the ``seed`` field. For every (tag, num_k) pair the first
    result of each seed is printed, followed by the mean and standard
    deviation of the three metrics over seeds.

    Raises:
        ValueError: if the reported result of a seed has a missing or
            non-numeric ``trace``, ``op_norm`` or ``effective_rank``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
    parser.add_argument("--args_to_care", default=None, type=str, help="A list of args to care about. If provided, outputs which configs were not tested in your grid search")

    # These options should be kept as their default values
    parser.add_argument("--log", type=str, default="log_hessian", help="Log path.")

    args = parser.parse_args()

    # NOTE(security): both options are passed to eval, i.e. they are executed
    # as Python code. Only run this script with arguments you trust.
    # Without --condition there is nothing to filter on, so keep every entry
    # instead of failing on eval(None).
    condition = eval(args.condition) if args.condition is not None else {}
    args_to_care = eval(args.args_to_care) if args.args_to_care is not None else None

    result_list = _load_results(args.log)

    # Group the matching entries by "<data_dir suffix>-<seed>".
    all_seed_result = {}
    for item in result_list:
        if not _matches_condition(item, condition):
            continue
        seed = item['data_dir'].split('-')[-1] + '-' + str(item['seed'])
        all_seed_result.setdefault(seed, []).append(item)

    matched = list(itertools.chain.from_iterable(all_seed_result.values()))
    all_tags = sorted(set(x['tag'] for x in matched))
    all_k = sorted(set(x['num_k'] for x in matched))

    # Hyper-parameters echoed under each seed; identical for every (tag, k).
    if args_to_care is None:
        hp_to_care_about = [
            'per_device_train_batch_size',
            'gradient_accumulation_steps',
            'learning_rate',
            'zero_order_eps',
            'zero_order_sample',
            'zero_order_sample_scheduler',
            'scale_lr_with_samples',
            'lr_scheduler_type',
            'weight_decay',
        ]
    else:
        hp_to_care_about = args_to_care

    for tag in all_tags:
        for k in all_k:
            print("Tag: {}, K: {}".format(tag, k))
            seed_result_with_duplicates = {
                s: [x for x in v if x['tag'] == tag and x['num_k'] == k]
                for s, v in all_seed_result.items()
            }
            # Keep one entry per output_dir (the last one logged wins).
            seed_result = {
                s: list({x['output_dir']: x for x in v}.values())
                for s, v in seed_result_with_duplicates.items()
            }

            # Check if all possible configs were run or not.
            if args_to_care is not None:
                _report_missing_configs(seed_result, args_to_care)

            for seed, runs in seed_result.items():
                if not runs:
                    continue
                first = runs[0]
                try:
                    summary = "%s: trace (%.4f), op_norm (%.4f), effective_rank (%.4f) | total trials: %d (ignored %d)" % (
                        seed,
                        first['trace'],
                        first['op_norm'],
                        first['effective_rank'],
                        len(runs),
                        len(seed_result_with_duplicates[seed]) - len(runs),
                    )
                except (KeyError, TypeError) as err:
                    # Fail with context instead of dropping into a debugger.
                    raise ValueError(
                        f"Result for seed {seed!r} (output_dir={first.get('output_dir')!r}) "
                        "has a missing or non-numeric trace/op_norm/effective_rank"
                    ) from err
                print(summary)
                if len(runs) > 1:
                    print("WARNING! Multiple results for the same seed. Only reporting the first one. You should set a more detailed --condition.")
                hp_line = ''.join(
                    '| {}: {} '.format(hp_name, first.get(hp_name, ""))
                    for hp_name in hp_to_care_about
                )
                print('    ' + hp_line)

            # Aggregate over the first (reported) result of every seed that has one.
            reported = [runs[0] for runs in seed_result.values() if runs]
            trace = [r['trace'] for r in reported]
            op_norm = [r['op_norm'] for r in reported]
            effective_rank = [r['effective_rank'] for r in reported]

            s = "mean +- std (num of seeds: %d)\n" % (len(trace))
            s += "trace: %.4f (%.4f)\nop_norm: %.4f (%.4f)\neffective_rank: %.4f (%.4f)" % (
                np.mean(trace), np.std(trace),
                np.mean(op_norm), np.std(op_norm),
                np.mean(effective_rank), np.std(effective_rank),
            )
            print(s)
            print("")

# Script entry point: run the aggregation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
