import collections as col
import glob
import os
import sys

import scripts.eval as ev
import numpy as np


def summarize(rep_dir):
    """Gather one evaluation summary per model directory under ``rep_dir``.

    A model directory is any immediate subdirectory of ``rep_dir`` that
    contains a ``train.json`` file.  Each one is passed to ``ev.summarize``.

    Args:
        rep_dir: path of the directory holding the model subdirectories.

    Returns:
        A list of ``(model_name, summary)`` tuples ordered by model name,
        where ``model_name`` is the subdirectory's basename and ``summary``
        is whatever ``ev.summarize`` returned for that subdirectory.
    """
    pattern = rep_dir + '/*/train.json'
    found = []
    for train_file in glob.glob(pattern):
        directory = os.path.dirname(train_file)
        found.append((os.path.basename(directory), directory))
    # Order by model name only, so the output does not depend on glob order.
    found.sort(key=lambda pair: pair[0])
    return [(name, ev.summarize(directory)) for name, directory in found]


def format_report(rep_summary):
    """Build the per-repetition results table as a list of text lines.

    Args:
        rep_summary: non-empty list of ``(model, one_model_summary)`` tuples.
            Each ``one_model_summary`` maps ``'val_loss'`` to a number and
            every other key (a test name) to a container that is indexed
            with ``20`` to obtain a mapping of measurement name -> value.

    Returns:
        A list of strings: the format legend, the column header (one column
        per repetition plus mean +/- std of the validation losses), and then
        for each reported test either a data row or a warning line.

    Raises:
        ValueError: if ``rep_summary`` is empty.
        KeyError: if a model summary lacks ``'val_loss'`` or a selected test
            summary lacks one of the expected measurement names.
    """
    if not rep_summary:
        raise ValueError("rep_summary must contain at least one repetition")

    measurement_names = (
        'median_mean_aucs', 'median_mean_aps', 'median_mean_accs',
        'full_mean_aucs', 'full_mean_aps', 'full_mean_accs',
        'full_aucs', 'full_aps', 'full_accs')
    # Column order matches the "(ROC AUC/AVG PRE/ACC)" legend.
    of_interest = ('median_mean_aucs', 'median_mean_aps', 'median_mean_accs')
    reported_tests = ('DB5-update-unbound-heavy6',
                      'DB5-update-bound-heavy6',
                      'DB4-small-cleaned-unbound-heavy6',
                      'DB4-small-cleaned-bound-heavy6')
    # NOTE(review): every test summary is indexed with 20 before reading the
    # measurements; what 20 denotes is not visible in this file -- confirm
    # against scripts.eval.
    summary_key = 20
    num_reps = len(rep_summary)

    # tests[test_name][measurement_name] -> list of values, one per model
    # that reported the test, in rep_summary order.
    tests = {}
    for _model, one_model_summary in rep_summary:
        for test_name, one_test_summary in one_model_summary.items():
            if test_name == 'val_loss':
                continue
            if test_name not in tests:
                tests[test_name] = dict((mn, []) for mn in measurement_names)
            selected = one_test_summary[summary_key]
            for mn in measurement_names:
                tests[test_name][mn].append(selected[mn])

    lines = ["Format is (ROC AUC/AVG PRE/ACC)"]

    vlosses = [summary['val_loss'] for _model, summary in rep_summary]
    header = "TEST NAME" + " " * 31 + "--"
    for rep, vloss in enumerate(vlosses):
        header += " REP{:} (vloss: {:4.3f})".format(rep, vloss)
        header += " "
        header += " - "
    header += " {:4.3f} +/- {:4.3f} ".format(np.mean(vlosses), np.std(vlosses))
    lines.append(header)

    for test_name in reported_tests:
        if test_name not in tests:
            # Nothing to aggregate: warn once and skip the row instead of
            # failing on a missing key when computing the mean.
            lines.append("{:} not in tests".format(test_name))
            continue
        measurements = tests[test_name]
        aucs = measurements['median_mean_aucs']
        # All measurement lists of a test grow together, so checking one
        # length is enough to know whether every repetition reported it.
        complete = len(aucs) == num_reps
        if not complete:
            lines.append("not enough reps in tests[{:}]".format(test_name))
        row = "{:<40}--".format(test_name)
        for rep in range(num_reps):
            # With missing repetitions the list positions no longer map to
            # repetition indices, so the per-rep cells are left blank.
            if complete:
                for mn in of_interest:
                    row += " {:4.3f} ".format(measurements[mn][rep])
            row += " - "
        row += " {:4.3f} +/- {:4.3f} ".format(np.mean(aucs), np.std(aucs))
        lines.append(row)
    return lines


def main(argv):
    """Print the results table for the directory named in ``argv[1]``.

    Args:
        argv: command-line arguments, program name first.

    Returns:
        Process exit status: 0 on success, 1 on a usage error or when no
        model summaries were found.
    """
    if len(argv) < 2:
        print("Usage: python complex_auc.py REP_DIR")
        return 1

    rep_dir = argv[1]
    rep_summary = summarize(rep_dir)
    if not rep_summary:
        print("No */train.json found under {:}".format(rep_dir))
        return 1

    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    for line in format_report(rep_summary):
        print(line)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
