import os
import cPickle as pkl
import sys

import h5py
import numpy as np
import sklearn.metrics as sm


def _roc_ap_acc(y_true, y_score):
    """Return (ROC AUC, average precision, accuracy) for one set of scores.

    Accuracy is computed on the scores rounded to the nearest integer.
    """
    return (sm.roc_auc_score(y_true, y_score),
            sm.average_precision_score(y_true, y_score),
            sm.accuracy_score(y_true, y_score.round()))


def summarize(roc_hdf5, list):
    """Compute, or load from cache, ROC/AP/accuracy summaries of predictions.

    The ``uids`` (first column), ``labels`` and ``predictions`` datasets of
    ``roc_hdf5`` are regrouped into rows of 20 consecutive entries.  Each row
    is treated as one example: its uid and label are taken from the first
    entry of the row, and its score is the mean or the max over the first
    ``num_samples`` predictions of the row.
    (Presumably each row holds 20 sampled predictions of the same example --
    confirm against the code that writes the HDF5 file.)

    Args:
        roc_hdf5: Path to the HDF5 file with ``uids``, ``labels`` and
            ``predictions`` datasets.
        list: Path to a text file with one complex name per line.  An example
            belongs to a complex when the name is contained in its uid.
            Blank lines are ignored.  (The parameter name shadows the builtin
            but is kept so existing keyword callers keep working.)

    Returns:
        dict mapping ``num_samples`` (1, 2, 4, 8, 16, 20) to a dict with:
          * ``median_{max,mean}_{aucs,aps,accs}``: median over complexes of
            the per-complex metric (complexes matching no uid are skipped);
          * ``full_{max,mean}_{aucs,aps,accs}``: the metric over all examples;
          * ``full_{aucs,aps,accs}``: the metric over every individual
            prediction, identical for all ``num_samples``.

    Raises:
        ValueError: if the number of entries is not a multiple of 20.
        Errors raised by the sklearn metrics also propagate; depending on the
        sklearn version this may include a complex whose examples all share
        one label.

    Note:
        The result is cached in ``cached_roc_stats.pkl`` next to ``roc_hdf5``.
        The cache is looked up by location only, so an existing cache is
        returned even if ``roc_hdf5`` or ``list`` changed; delete the file to
        force a recomputation.
    """
    samples_per_example = 20

    # os.path.join keeps the cache beside the input even when roc_hdf5 has no
    # directory component; "dirname + '/...'" would point at the root then.
    cache_path = os.path.join(os.path.dirname(roc_hdf5),
                              "cached_roc_stats.pkl")
    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code -- only run this on
        # cache files you trust.  Pickles must be read in binary mode.
        with open(cache_path, 'rb') as cache_file:
            return pkl.load(cache_file)

    with open(list) as list_file:
        stripped = (line.strip() for line in list_file)
        # An empty name is a substring of every uid and would count the whole
        # data set as one extra "complex", so blank lines are dropped.
        complexes = sorted(name for name in stripped if name)

    # Open read-only explicitly: summarising must never create or modify
    # the predictions file.
    with h5py.File(roc_hdf5, 'r') as f:
        uids = f['uids'][:, 0]
        labels = f['labels'][:]
        pred = f['predictions'][:]

    # Floor division keeps the row count an integer on Python 2 and 3 alike.
    rpred = pred.reshape(pred.shape[0] // samples_per_example,
                         samples_per_example)
    rlabels = labels.reshape(labels.shape[0] // samples_per_example,
                             samples_per_example)
    ruids = uids.reshape(uids.shape[0] // samples_per_example,
                         samples_per_example)

    # The first entry of each row stands for the whole example.
    example_uids = ruids[:, 0]
    example_labels = rlabels[:, 0]

    # Row indices per complex do not depend on num_samples: compute once.
    rows_by_complex = []
    for name in complexes:
        rows = [i for i, uid in enumerate(example_uids) if name in uid]
        if rows:
            rows_by_complex.append(np.array(rows))

    # Metrics over every individual prediction are loop-invariant as well.
    overall_auc, overall_ap, overall_acc = _roc_ap_acc(labels, pred)

    final = {}
    for num_samples in [2 ** x for x in range(5)] + [samples_per_example]:
        spred = rpred[:, 0:num_samples]
        avg_pred = np.mean(spred, axis=1)
        max_pred = np.max(spred, axis=1)

        max_stats = [_roc_ap_acc(example_labels[rows], max_pred[rows])
                     for rows in rows_by_complex]
        mean_stats = [_roc_ap_acc(example_labels[rows], avg_pred[rows])
                      for rows in rows_by_complex]

        curr = {}
        curr['median_max_aucs'] = np.median([s[0] for s in max_stats])
        curr['median_max_aps'] = np.median([s[1] for s in max_stats])
        curr['median_max_accs'] = np.median([s[2] for s in max_stats])
        curr['median_mean_aucs'] = np.median([s[0] for s in mean_stats])
        curr['median_mean_aps'] = np.median([s[1] for s in mean_stats])
        curr['median_mean_accs'] = np.median([s[2] for s in mean_stats])

        (curr['full_max_aucs'],
         curr['full_max_aps'],
         curr['full_max_accs']) = _roc_ap_acc(example_labels, max_pred)
        (curr['full_mean_aucs'],
         curr['full_mean_aps'],
         curr['full_mean_accs']) = _roc_ap_acc(example_labels, avg_pred)

        curr['full_aucs'] = overall_auc
        curr['full_aps'] = overall_ap
        curr['full_accs'] = overall_acc
        final[num_samples] = curr

    # Binary mode is required for pickles; "with" closes the file handle.
    with open(cache_path, 'wb') as cache_file:
        pkl.dump(final, cache_file)
    return final


# Command-line entry point: print the summary tables for one predictions file.
# Every print below uses a single parenthesised argument so the output is the
# same whether "print" is a statement or a function.
if __name__ == '__main__':
    if len(sys.argv) < 3:
        # A usage error is a failure: sys.exit with a string writes it to
        # stderr and exits with status 1 instead of reporting success.
        sys.exit("Usage: python complex_auc.py roc.h5 list.txt")

    roc_hdf5 = sys.argv[1]
    # Named list_path so the builtin "list" is not shadowed at module level.
    list_path = sys.argv[2]

    final = summarize(roc_hdf5, list_path)

    # The full_aucs/full_aps/full_accs entries are read from the 20-sample
    # summary for the overall line.
    curr = final[20]
    print("ALL")
    print("TOTAL   (ROC AUC/AVG PRE/ACC) -- reg: ({:4.3f}, {:4.3f}, {:4.3f}) "
          .format(curr['full_aucs'], curr['full_aps'], curr['full_accs']))

    for num_samples in sorted(final.keys()):
        # NOTE(review): summarize as visible here only produces integer keys;
        # this guard is kept in case a cache file holds an 'all' entry --
        # confirm whether it can be removed.
        if num_samples == 'all':
            continue
        curr = final[num_samples]
        print("{:02d}-sample".format(num_samples))
        print("COMPLEX (ROC AUC/AVG PRE/ACC) -- "
              "max: ({:4.3f}, {:4.3f}, {:4.3f}) "
              "mean: ({:4.3f}, {:4.3f}, {:4.3f})".format(
                  curr['median_max_aucs'], curr['median_max_aps'],
                  curr['median_max_accs'], curr['median_mean_aucs'],
                  curr['median_mean_aps'], curr['median_mean_accs']))
        print("FULL    (ROC AUC/AVG PRE/ACC) -- "
              "max: ({:4.3f}, {:4.3f}, {:4.3f}) "
              "mean: ({:4.3f}, {:4.3f}, {:4.3f})".format(
                  curr['full_max_aucs'], curr['full_max_aps'],
                  curr['full_max_accs'], curr['full_mean_aucs'],
                  curr['full_mean_aps'], curr['full_mean_accs']))
