from __future__ import division
import numpy as np
import pandas as pd
from sklearn.calibration import calibration_curve
from sklearn.metrics import (brier_score_loss, precision_score, recall_score,
                             f1_score)
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
from sklearn.calibration import calibration_curve
from sklearn.metrics import (brier_score_loss, precision_score, recall_score,
                             f1_score)
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.model_selection import train_test_split
import scipy.stats
from scipy import stats
import pickle
import seaborn as sns
import imp
imp.reload(plt); imp.reload(sns)
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import roc_curve, auc
import copy

'''
Helpers for assessing AUC, XAUCS
'''


'''
# from
#https://www.ibm.com/developerworks/community/blogs/jfp/entry/Fast_Computation_of_AUC_ROC_score?lang=en
# AUC-ROC = | {(i,j), i in pos, j in neg, p(i) > p(j)} | / (| pos | x | neg |)
# The equivalent version of this is, Pr [ LEFT > RIGHT ]
# now Y_true is group membership (of positive examples) , not positive level
'''
def fast_auc(y_true, y_prob):
    """Return Pr[score of a y_true == 1 example > score of a y_true == 0 example].

    The value is computed from midranks (the Mann-Whitney U statistic), so a
    tied (1, 0) pair contributes 1/2 and the result no longer depends on the
    order in which tied examples are passed in.

    Args:
        y_true: 1D array-like of 0/1 indicators (1 marks the LEFT group).
        y_prob: 1D array-like of scores, same length as y_true.
    Returns:
        The AUC as a float; nan when either group is empty.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob)
    n_pos = y_true.sum()
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        # AUC is undefined without at least one example on each side.
        return float('nan')
    # 1-based average ranks: tied scores share the mean of their positions.
    ranks = stats.rankdata(y_prob)
    # Rank sum of the positives minus its smallest possible value counts the
    # (positive, negative) pairs ordered correctly, with ties weighted 1/2.
    u_stat = np.dot(y_true, ranks) - n_pos * (n_pos + 1) / 2.0
    return u_stat / (n_pos * n_neg)
'''
cross_auc for the Ra0 > Rb1 error
function takes in scores for (a,0), (b,1)
'''
def cross_auc(R_a_0, R_b_1):
    """Return Pr[LEFT > RIGHT] for two score samples.

    Args:
        R_a_0: 1D scores of the LEFT sample (labelled 1 for fast_auc).
        R_b_1: 1D scores of the RIGHT sample (labelled 0 for fast_auc).
    Returns:
        The value computed by fast_auc on the pooled scores.
    """
    # Concatenate the samples directly: wrapping two unequal-length arrays in
    # np.asarray first builds a ragged array, which NumPy >= 1.24 rejects.
    scores = np.hstack([R_a_0, R_b_1])
    y_true = np.zeros(len(R_a_0) + len(R_b_1))
    y_true[0:len(R_a_0)] = 1  # Pr[LEFT > RIGHT]; Y = 1 is the left sample
    return fast_auc(y_true, scores)

'''
Use the Delong method to compute conf intervals on AUC
'''
def cross_auc_delong(R_a_0, R_b_1, alpha=0.95):
    """Return [auc, ci] for Pr[LEFT > RIGHT] with a DeLong normal interval.

    Args:
        R_a_0: 1D scores of the LEFT sample (labelled 1).
        R_b_1: 1D scores of the RIGHT sample (labelled 0).
        alpha: two-sided coverage level of the interval.
    Returns:
        [auc, ci] where ci is a length-2 array (lower, upper) truncated to
        the unit interval.
    """
    # Concatenate the samples directly: wrapping two unequal-length arrays in
    # np.asarray first builds a ragged array, which NumPy >= 1.24 rejects.
    scores = np.hstack([R_a_0, R_b_1])
    y_true = np.zeros(len(R_a_0) + len(R_b_1))
    y_true[0:len(R_a_0)] = 1  # Pr[LEFT > RIGHT]; Y = 1 is the left sample
    auc, auc_cov = delong_roc_variance(y_true, scores)

    auc_std = np.sqrt(auc_cov)
    # Quantiles (1 - alpha) / 2 and 1 - (1 - alpha) / 2 of the normal approximation.
    lower_upper_q = np.abs(np.array([0, 1]) - (1 - alpha) / 2.0)
    ci = stats.norm.ppf(lower_upper_q, loc=auc, scale=auc_std)
    ci[ci > 1] = 1; ci[ci < 0] = 0  # truncate interval
    return [auc, ci]

'''
Get the cross AUCs (assuming A \in \{ 0,1 \} )
'''
def get_cross_auc_delong(Rhat, Y, A, alpha=0.95):
    """Return both cross-group AUCs with DeLong confidence intervals.

    Group a is A == 0 and group b is A == 1.

    Args:
        Rhat: scores.
        Y: 0/1 outcomes aligned with Rhat.
        A: 0/1 group membership aligned with Rhat.
        alpha: coverage level forwarded to cross_auc_delong.
    Returns:
        [xauc_b1_a0, ci_b1_a0, xauc_a1_b0, ci_a1_b0].
    """
    Rhat_a_0 = Rhat[(A==0)&(Y==0)]
    Rhat_b_1 = Rhat[(A==1)&(Y==1)]
    Rhat_b_0 = Rhat[(A==1)&(Y==0)]
    Rhat_a_1 = Rhat[(A==0)&(Y==1)]
    # alpha used to be dropped here, so every caller silently got the
    # default interval width regardless of what was requested.
    # Pr[ group-b positive scored above group-a negative ]
    [xauc_b1_a0, ci_b1_a0] = cross_auc_delong(Rhat_b_1, Rhat_a_0, alpha=alpha)
    # Pr[ group-a positive scored above group-b negative ]
    [xauc_a1_b0, ci_a1_b0] = cross_auc_delong(Rhat_a_1, Rhat_b_0, alpha=alpha)
    return [xauc_b1_a0, ci_b1_a0, xauc_a1_b0, ci_a1_b0]


''' Report AUCs for each level of class
'''
def get_AUCs(Rhat, Y, A):
    """Return an array holding the ROC AUC of Rhat against Y within each level of A.

    Entries follow the order of np.unique(A).
    """
    groups = np.unique(A)
    per_group_auc = np.zeros(len(groups))
    for pos, level in enumerate(groups):
        in_group = (A == level)
        curve = metrics.roc_curve(Y[in_group], Rhat[in_group], pos_label=1)
        # roc_curve returns (fpr, tpr, thresholds); only the first two are needed.
        per_group_auc[pos] = metrics.auc(curve[0], curve[1])
    return per_group_auc
''' Report AUCs + finite conf interval for each level of class
'''
def get_AUCs_delong(Rhat,Y,A, alpha = 0.95):
    """Return [AUCs, CIs]: per-group AUC and its DeLong normal interval.

    Entries follow the order of np.unique(A); each CI is a length-2 array
    (lower, upper) truncated to the unit interval.
    """
    groups = np.unique(A)
    n_groups = len(groups)
    point_estimates = np.zeros(n_groups)
    intervals = [None] * n_groups
    for pos, level in enumerate(groups):
        in_group = (A == level)
        estimate, variance = delong_roc_variance(Y[in_group], Rhat[in_group])
        spread = np.sqrt(variance)
        # Lower and upper tail quantiles for the requested coverage.
        quantiles = np.abs(np.array([0, 1]) - (1 - alpha) / 2)
        bounds = stats.norm.ppf(quantiles, loc=estimate, scale=spread)
        bounds[bounds > 1] = 1
        bounds[bounds < 0] = 0
        point_estimates[pos] = estimate
        intervals[pos] = bounds
    return [point_estimates, intervals]
'''
Get the cross AUCs (assuming A \in \{ 0,1 \} )
'''
def get_cross_aucs(Rhat, Y, A, quiet=True, stump="def", save=False):
    """Return the two cross-group AUCs, optionally drawing KDE plots.

    Group a is A == 0 and group b is A == 1. The values are reported as
    accuracies (positive of one group ranked above negative of the other),
    not as misranking errors.

    Args:
        Rhat: scores.
        Y: 0/1 outcomes aligned with Rhat.
        A: 0/1 group membership aligned with Rhat.
        quiet: when False, draw KDEs of the four (group, outcome) score samples.
        stump: filename prefix used when saving the figure.
        save: when True (and quiet is False), write 'figs/<stump>KDEs.pdf'.
    Returns:
        [cross_auc(R_b^1, R_a^0), cross_auc(R_a^1, R_b^0)].
    """
    Rhat_a_0 = Rhat[(A==0)&(Y==0)]
    Rhat_b_1 = Rhat[(A==1)&(Y==1)]
    Rhat_b_0 = Rhat[(A==1)&(Y==0)]
    Rhat_a_1 = Rhat[(A==0)&(Y==1)]
    if not quiet:
        # Plotting is best-effort: a failure here must not block the metrics.
        try:
            plt.figure(figsize=(7,3))
            # Left panel: densities entering the first cross AUC.
            ax1 = plt.subplot(121)
            sns.set_style("white")
            sns.kdeplot(Rhat_a_0, shade=True, color = 'r', label='A=a, Y=0', clip = (0,1))
            sns.kdeplot(Rhat_b_1, shade=True, color = 'b', label='A=b, Y=1', clip = (0,1))
            plt.xlim((0,1))
            plt.title(r'KDEs of $R_A^Y$ for XAUC')
            plt.legend()
            # Right panel: densities entering the second cross AUC.
            plt.subplot(122, sharey = ax1)
            sns.kdeplot(Rhat_b_0, shade=True, color = 'r', label='A=b, Y=0', clip = (0,1))
            sns.kdeplot(Rhat_a_1, shade=True, color = 'b', label='A=a, Y=1', clip = (0,1))
            plt.xlim((0,1))
            plt.title(r'KDEs of $R_A^Y$ for XAUC')
            plt.legend()
            if save:
                plt.savefig('figs/'+stump+'KDEs.pdf')
            plt.close('all')
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate.
            print("exception getting kde plots")
    Rhatb1_cross_Rhata0 = cross_auc(Rhat_b_1, Rhat_a_0)
    Rhata1_cross_Rhatb0 = cross_auc(Rhat_a_1, Rhat_b_0)
    return [Rhatb1_cross_Rhata0,Rhata1_cross_Rhatb0]

'''
Get the balanced-cross AUCs (assuming A \in \{ 0,1 \} )
'''
def get_balanced_cross_aucs(Rhat, Y, A):
    """Return the four balanced cross AUCs, reported as accuracies.

    Group a is A == 0 and group b is A == 1. Each value compares one
    group's sample against the pooled sample of the opposite outcome:
    [R_b^1 vs R^0, R_a^1 vs R^0, R^1 vs R_a^0, R^1 vs R_b^0].
    """
    is_pos = (Y == 1)
    is_neg = (Y == 0)
    in_a = (A == 0)
    in_b = (A == 1)

    pooled_pos = Rhat[is_pos]
    pooled_neg = Rhat[is_neg]
    a_neg = Rhat[in_a & is_neg]
    b_pos = Rhat[in_b & is_pos]
    b_neg = Rhat[in_b & is_neg]
    a_pos = Rhat[in_a & is_pos]

    group_pos_vs_pooled_neg = [cross_auc(b_pos, pooled_neg),
                               cross_auc(a_pos, pooled_neg)]
    pooled_pos_vs_group_neg = [cross_auc(pooled_pos, a_neg),
                               cross_auc(pooled_pos, b_neg)]
    return group_pos_vs_pooled_neg + pooled_pos_vs_group_neg



''' # Get the calibration curves
'''
def get_calib_curves(Rhat, Y, A, A_labels, stump = 'def', quiet = True, save = False):
    """Compute per-group Brier scores and (optionally) plot calibration curves.

    Returns [clf_scores, mean_predicted_value, fraction_of_positives]; the
    two curve arrays are those of the last group visited in np.unique(A).
    When quiet is False a figure is drawn, saved to
    'figs/<stump>-calibration-curves.pdf' if save is set, and then closed.
    """
    verbose = not quiet
    if verbose:
        plt.figure(figsize=(3,3))
    groups = np.unique(A)
    clf_scores = np.zeros(len(groups))
    for ind, a in enumerate(groups):
        members = (A == a)
        clf_scores[ind] = brier_score_loss(Y[members], Rhat[members], pos_label=Y.max())
        curve = calibration_curve(Y[members], Rhat[members], n_bins = 10)
        fraction_of_positives, mean_predicted_value = curve
        if verbose:
            legend_text = "A=%s (%1.3f)" % (A_labels[a], clf_scores[ind])
            plt.plot(mean_predicted_value, fraction_of_positives, "s-",
                     label=legend_text)
    if verbose:
        plt.legend()
        plt.title('Calibration curves')
        if save:
            plt.savefig('figs/'+stump+'-calibration-curves.pdf')
        plt.close('all')

    return [clf_scores, mean_predicted_value, fraction_of_positives]


def get_roc(n_thresh, Rhat, Y, compute_thresh=True, precomputed_thresh = None):
    """Tabulate an ROC curve on a grid of thresholds.

    Args:
        n_thresh: number of thresholds (rows of the result).
        Rhat: scores.
        Y: 0/1 outcomes aligned with Rhat.
        compute_thresh: when True, thresholds run evenly from max(Rhat)
            down to min(Rhat); otherwise precomputed_thresh is used.
        precomputed_thresh: thresholds to use when compute_thresh is False.
    Returns:
        Array of shape (n_thresh, 2): column 0 is FPR, column 1 is TPR.
    """
    if not compute_thresh:
        cutoffs = precomputed_thresh
    else:
        highest = max(Rhat)
        lowest = min(Rhat)
        cutoffs = np.linspace(highest, lowest, n_thresh)
    is_pos = (Y == 1)
    is_neg = (Y == 0)
    curve = np.zeros((n_thresh, 2))
    for row in range(n_thresh):
        cut = cutoffs[row]
        flagged = Rhat > cut
        cleared = Rhat <= cut
        # Confusion-matrix counts at this cutoff.
        tp = (flagged & is_pos).sum()
        tn = (cleared & is_neg).sum()
        fp = (flagged & is_neg).sum()
        fn = (cleared & is_pos).sum()
        curve[row, 0] = fp / float(fp + tn)
        curve[row, 1] = tp / float(tp + fn)
    return curve

''' Cross ROC defined for Ra0 > R1b
(permute identity of a,b, to compute the other way)
Returns an XROC array with FPR (group a) in column 0 (x axis) and TPR (group b) in column 1 (y axis)
Assume Rhat_a, Rhat_b are already separate subsets of A=a, A=b
'''
def get_cross_roc(n_thresh, Rhat_a, Rhat_b, Y_a, Y_b, A, compute_thresh=True, precomputed_thresh = None):
    """Tabulate a cross ROC: FPR of the first sample vs TPR of the second.

    Args:
        n_thresh: number of thresholds (rows of the result).
        Rhat_a, Y_a: scores and 0/1 outcomes of the sample supplying FPR.
        Rhat_b, Y_b: scores and 0/1 outcomes of the sample supplying TPR.
        A: not used by this function.
        compute_thresh: when True, thresholds run evenly from the largest
            to the smallest score across both samples; otherwise
            precomputed_thresh is used.
        precomputed_thresh: thresholds to use when compute_thresh is False.
    Returns:
        Array of shape (n_thresh, 2): column 0 is FPR, column 1 is TPR.
    """
    if not compute_thresh:
        cutoffs = precomputed_thresh
    else:
        highest = max(max(Rhat_a), max(Rhat_b))
        lowest = min(min(Rhat_a), min(Rhat_b))
        cutoffs = np.linspace(highest, lowest, n_thresh)
    b_is_pos = (Y_b == 1)
    a_is_neg = (Y_a == 0)
    curve = np.zeros((n_thresh, 2))
    for row in range(n_thresh):
        cut = cutoffs[row]
        tp_b = ((Rhat_b > cut) & b_is_pos).sum()
        tn_a = ((Rhat_a <= cut) & a_is_neg).sum()
        fp_a = ((Rhat_a > cut) & a_is_neg).sum()
        fn_b = ((Rhat_b <= cut) & b_is_pos).sum()
        curve[row, 1] = tp_b*1.0 / (tp_b + fn_b)
        curve[row, 0] = fp_a*1.0 / (fp_a + tn_a)
    return curve



''' For now assumes that A \in \{0,1\}
'''
def get_rocs_xrocs(Rhat, Y, A, classes, n_thresh, compute_thresh=True, precomputed_thresh = None):
    """Return [ROCs, XROC, XROC_backwards] for a binary group attribute.

    ROCs holds one get_roc table per level of np.unique(A). XROC pairs
    group 0's FPR with group 1's TPR; XROC_backwards swaps the groups.
    classes is not used by this function.
    """
    thresh_args = (compute_thresh, precomputed_thresh)
    ROCs = []
    for level in np.unique(A):
        members = (A == level)
        ROCs.append(get_roc(n_thresh, Rhat[members], Y[members], *thresh_args))
    grp0 = (A == 0)
    grp1 = (A == 1)
    XROC = get_cross_roc(n_thresh, Rhat[grp0], Rhat[grp1], Y[grp0], Y[grp1], A, *thresh_args)
    XROC_backwards = get_cross_roc(n_thresh, Rhat[grp1], Rhat[grp0], Y[grp1], Y[grp0], A, *thresh_args)
    return [ROCs, XROC, XROC_backwards]

''' Modify to handle partition well
'''
def get_balanced_cross_roc(n_thresh, Rhat_a, Rhat_b, Y_a, Y_b, A):
    """Tabulate a cross ROC on n_thresh thresholds evenly spaced from 1 down to 0.

    Column 0 is the FPR of (Rhat_a, Y_a); column 1 is the TPR of
    (Rhat_b, Y_b). A is not used by this function.
    """
    cutoffs = np.linspace(1, 0, n_thresh)
    b_is_pos = (Y_b == 1)
    a_is_neg = (Y_a == 0)
    curve = np.zeros((n_thresh, 2))
    for row in range(n_thresh):
        cut = cutoffs[row]
        tp_b = ((Rhat_b > cut) & b_is_pos).sum()
        tn_a = ((Rhat_a <= cut) & a_is_neg).sum()
        fp_a = ((Rhat_a > cut) & a_is_neg).sum()
        fn_b = ((Rhat_b <= cut) & b_is_pos).sum()
        curve[row, 1] = tp_b*1.0 / (tp_b + fn_b)
        curve[row, 0] = fp_a*1.0 / (fp_a + tn_a)
    return curve

def plot_ROCS(ROCs, XROC, XROC_backwards, classes, A, stump = 'def', ROCs_only=False, save = False):
    """Draw per-group ROC curves and, unless ROCs_only, the two XROC curves.

    When save is set the current figure is written twice: to
    'figs/<stump>ROC.pdf' after the ROC panel is drawn and to
    'figs/<stump>XROC.pdf' at the end. The figure is left open.
    """
    with_xroc = not ROCs_only
    if with_xroc:
        plt.figure(figsize=(6,3))
        plt.subplot(121)
        plt.tight_layout()
    else:
        plt.figure(figsize=(3,3))
    for a in range(len(np.unique(A))):
        plt.plot(ROCs[a][:,0], ROCs[a][:,1], label = classes[a])
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.legend()
    if save:
        plt.savefig('figs/'+stump+'ROC.pdf')
    plt.title('ROC curve')
    if with_xroc:
        plt.subplot(122)
        xroc_series = (
            (XROC, r'$R_b^1 > R_a^0$', 'blue'),
            (XROC_backwards, r'$R_a^1 > R_b^0$', 'red'),
        )
        for curve, legend_text, line_color in xroc_series:
            plt.plot(curve[:,0], curve[:,1], label = legend_text, color = line_color)
        plt.xlabel('FPR')
        plt.ylabel('TPR')
        plt.title(r'XROC curve')
    plt.legend()
    if save:
        plt.savefig('figs/'+stump+'XROC.pdf')

def plot_ROC_comparison(ROC1, ROC2, stump = 'def', save = False):
    """Overlay an unadjusted (solid) and adjusted (dashed) ROC curve.

    Each ROC is read as column 0 on the x axis (labelled FPR) and column 1
    on the y axis (labelled TPR). Saves to
    'figs/<stump>ROC-adjustment-comparison.pdf' when save is set, then
    closes all figures.
    """
    plt.figure(figsize=(3,3))
    x_before, y_before = ROC1[:,0], ROC1[:,1]
    x_after, y_after = ROC2[:,0], ROC2[:,1]
    plt.plot(x_before, y_before, label = 'unadjusted' , color = 'purple')
    plt.plot(x_after, y_after, label = 'adjusted', linestyle = '--', color = 'purple', alpha = 0.5)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    if save:
        plt.savefig('figs/'+stump+'ROC-adjustment-comparison.pdf')
    plt.close('all')

def plot_XROCs_comparison(XROC, XROC_backwards,XROC_adj, XROC_backwards_adj, classes, A=None,type_='XROC', stump = 'def', save = False):
    """Overlay two curves (solid) with their adjusted versions (dashed).

    type_ selects the legend labels ('XROC' for cross-group labels, anything
    else for within-group labels) and is part of the title and filename.
    Saves to 'figs/<stump><type_>s-comparison.pdf' when save is set, then
    closes all figures. classes and A are not used.
    """
    plt.figure(figsize=(3,3))
    plt.tight_layout()
    if type_ != 'XROC':
        label_1, label_2 = r'$R_a^1 > R_a^0$', r'$R_b^1 > R_b^0$'
    else:
        label_1, label_2 = r'$R_b^1 > R_a^0$', r'$R_a^1 > R_b^0$'
    series = [
        (XROC, label_1, 'blue', {}),
        (XROC_backwards, label_2, 'red', {}),
        (XROC_adj, label_1+', adj', 'blue', {'linestyle': '--'}),
        (XROC_backwards_adj, label_2+', adj', 'red', {'linestyle': '--'}),
    ]
    for curve, legend_text, line_color, extra in series:
        plt.plot(curve[:,0], curve[:,1], label = legend_text, color = line_color, **extra)

    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title(type_ + r' curve adjustment')
    plt.legend()
    if save:
        plt.savefig('figs/'+stump+type_+'s-comparison.pdf', bbox_inches='tight')
    plt.close('all')

''' Get balanced xroc curves
'''
def get_balanced_xrocs(Rhat, Y, A, classes, n_thresh):
    """Return the four balanced cross ROC tables built with get_cross_roc.

    Each table pairs one group (A == 0 or A == 1) with the full sample:
    the first two use the group for FPR and the full sample for TPR, the
    last two the reverse. classes is not used by this function.
    """
    in_a = (A == 0)
    in_b = (A == 1)
    XROC_1part_a = get_cross_roc(n_thresh, Rhat[in_a], Rhat, Y[in_a], Y, A)
    XROC_1part_b_bwds = get_cross_roc(n_thresh, Rhat[in_b], Rhat, Y[in_b], Y, A)
    XROC_0part_a = get_cross_roc(n_thresh, Rhat, Rhat[in_a], Y, Y[in_a], A)
    XROC_0part_b_bwds = get_cross_roc(n_thresh, Rhat, Rhat[in_b], Y, Y[in_b], A)
    return [XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds]

''' Plot balanced ROC curves
'''
def plot_balanced_ROCS(XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds, classes, A, stump = 'def', save = False):
    """Draw the balanced cross ROC curves in two side-by-side panels.

    Blue is used for the group-a curve and red for the group-b curve in
    each panel. Saves to 'figs/<stump>XROC01.pdf' when save is set, then
    closes all figures. classes and A are not used.
    """
    plt.figure(figsize=(6.5,3))
    panels = (
        (121, r'XROC1 curve',
         ((XROC_1part_a, r'$R^1 > R_a^0$', 'blue'),
          (XROC_1part_b_bwds, r'$R^1 > R_b^0$', 'red'))),
        (122, r'XROC0 curve',
         ((XROC_0part_a, r'$R_a^1 > R^0$', 'blue'),
          (XROC_0part_b_bwds, r'$R_b^1 > R^0$', 'red'))),
    )
    for position, panel_title, curves in panels:
        plt.subplot(position)
        for curve, legend_text, line_color in curves:
            plt.plot(curve[:,0], curve[:,1], label = legend_text, color = line_color)
        plt.xlabel('FPR')
        plt.ylabel('TPR')
        plt.title(panel_title)
        plt.legend()
    if save:
        plt.savefig('figs/'+stump+'XROC01.pdf')
    plt.close('all')

'''return LR model
'''
def get_lr(X,Y):
    """Fit a default LogisticRegression on (X, Y).

    Returns [clf, Rhat] where Rhat is column 1 of predict_proba on the
    same X used for fitting.
    """
    model = LogisticRegression()
    model.fit(X, Y)
    in_sample_proba = model.predict_proba(X)
    return [model, in_sample_proba[:, 1]]

'''
!!! Main helper function
Print diagnostics for given score Rhat,
Get AUCs, XAUCs; ROC curves
'''
def get_diagnostics(Rhat, X, A, Y,labels, n_thresh, save=False,stump="default", calib=True, quiet = False):
    """Print and return AUC / XAUC / ROC diagnostics for a score Rhat.

    Args:
        Rhat: scores.
        X: not used by this function.
        A: 0/1 group membership aligned with Rhat.
        Y: 0/1 outcomes aligned with Rhat.
        labels: group labels, passed to the plotting helpers and printed.
        n_thresh: number of thresholds for the ROC tables.
        save: forwarded to the plotting helpers.
        stump: filename prefix forwarded to the plotting helpers.
        calib: when True, compute Brier scores via get_calib_curves;
            otherwise briers is returned as 0.
        quiet: when True, skip the ROC plots and forward quiet=True to helpers.
    Returns:
        [AUCs, AUCs_CIs, briers, ROCs, XROC, XROC_backwards, XAUCs, XCIs,
         balanced_XAUCs].
    """
    if calib:
        # Forward the caller's stump; a literal "stump" used to be passed here,
        # so every run wrote the calibration figure to the same file.
        [briers, mean_predicted_value, fraction_of_positives] = get_calib_curves(Rhat, Y, A, labels, stump=stump, quiet=quiet, save=save)
    else:
        [briers, mean_predicted_value, fraction_of_positives] = [0,0,0]
    [AUCs, AUCs_CIs] = get_AUCs_delong(Rhat, Y, A)
    # Single-string print() calls produce the same text on Python 2 and 3.
    print('AUCs %s' % ([ (AUCs[i], AUCs_CIs[i], labels[i]) for i in range(len(np.unique(A))) ],))
    [ROCs, XROC, XROC_backwards] = get_rocs_xrocs(Rhat, Y, A, labels, n_thresh)
    if (not quiet):
        plot_ROCS(ROCs, XROC, XROC_backwards, labels, A, stump = stump, save=save)
    [Rhatb1_cross_Rhata0,Rhata1_cross_Rhatb0] = get_cross_aucs(Rhat, Y,A, quiet=quiet, stump = stump, save=save)
    print('XAUCs %s' % ([Rhatb1_cross_Rhata0,Rhata1_cross_Rhatb0],))
    # Order: [Rhatb1_cross_Rhat0, Rhata1_cross_Rhat0, Rhat1_cross_Rhata0, Rhat1_cross_Rhatb0]
    balanced_XAUCs = get_balanced_cross_aucs(Rhat, Y,A)
    [XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds] = get_balanced_xrocs(Rhat, Y, A, labels, n_thresh)
    if (not quiet):
        plot_balanced_ROCS(XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds, labels, A, stump = stump, save = save)

    print('balanced xaucs Rb1>R0,Ra1>R0; R1>Ra0, R1>Rb0  %s' % (balanced_XAUCs,))
    [xauc_b1_a0, ci_b1_a0, xauc_a1_b0, ci_a1_b0] = get_cross_auc_delong(Rhat, Y, A)
    print('xauc fwds from delong', xauc_b1_a0, ci_b1_a0)
    print('xauc bwds from delong', xauc_a1_b0, ci_a1_b0)
    XAUCs=[Rhatb1_cross_Rhata0,Rhata1_cross_Rhatb0]; XCIs = [ci_b1_a0, ci_a1_b0]

    return [AUCs, AUCs_CIs, briers, ROCs, XROC, XROC_backwards, XAUCs, XCIs, balanced_XAUCs]

def get_rocs_xrocs_disc(Rhat, Y, A, labels, n_thresh, compute_thresh=True, precomputed_thresh=None):
    """Return ROC / XROC / balanced XROC curves interpolated onto a fixed FPR grid.

    The grid is 1000 points evenly spaced from 1 down to 0. Every returned
    curve is a 2 x 1000 array: row 0 is the grid, row 1 is the curve's
    column 1 interpolated against its column 0.

    Returns:
        [ROCss, XROC_, XROC_backwards_, balanced_XROCs].
    """
    mean_fpr = np.linspace(1,0,1000)

    def on_grid(curve):
        # Resample a (n_thresh, 2) table onto the shared FPR grid.
        return np.vstack([mean_fpr, np.interp(mean_fpr, curve[:,0], curve[:,1])])

    levels = np.unique(A)
    ROCs = []
    for a in levels:
        members = (A == a)
        ROCs.append(get_roc(n_thresh, Rhat[members], Y[members], compute_thresh, precomputed_thresh))
    grp0 = (A == 0)
    grp1 = (A == 1)
    XROC = get_cross_roc(n_thresh, Rhat[grp0], Rhat[grp1], Y[grp0], Y[grp1], A, compute_thresh, precomputed_thresh)
    XROC_backwards = get_cross_roc(n_thresh, Rhat[grp1], Rhat[grp0], Y[grp1], Y[grp0], A, compute_thresh, precomputed_thresh)
    # The group value itself is used as the index into ROCs.
    ROCss = [on_grid(ROCs[a]) for a in levels]

    XROC_ = on_grid(XROC)
    XROC_backwards_ = on_grid(XROC_backwards)
    balanced_raw = get_balanced_xrocs(Rhat, Y, A, labels, n_thresh)
    [XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds] = balanced_raw
    balanced_XROCs = [
        on_grid(XROC_1part_a),
        on_grid(XROC_1part_b_bwds),
        on_grid(XROC_0part_a),
        on_grid(XROC_0part_b_bwds),
    ]
    return [ROCss, XROC_, XROC_backwards_, balanced_XROCs]

def get_metrics_quiet(Rhat, X, A, Y,labels, n_thresh, save=False,stump="default", calib=True, quiet = True):
    """Compute the diagnostics without printing, with curves on a fixed FPR grid.

    X and calib are not used by this function.

    Returns:
        [AUCs, AUCs_CIs, briers, ROCs, XROC, XROC_backwards, XAUCs, XCIs,
         balanced_XAUCs, balanced_XROCs].
    """
    groups = np.unique(A)
    briers = np.zeros(len(groups))
    for ind, a in enumerate(groups):
        members = (A == a)
        briers[ind] = brier_score_loss(Y[members], Rhat[members], pos_label=Y.max())
    [AUCs, AUCs_CIs] = get_AUCs_delong(Rhat, Y, A)
    # Curves interpolated at a fixed FPR discretization.
    [ROCs, XROC, XROC_backwards, balanced_XROCs] = get_rocs_xrocs_disc(Rhat, Y, A, labels, n_thresh)
    XAUCs = get_cross_aucs(Rhat, Y, A, quiet=quiet, stump=stump, save=save)
    [Rhatb1_cross_Rhata0, Rhata1_cross_Rhatb0] = XAUCs
    balanced_XAUCs = get_balanced_cross_aucs(Rhat, Y, A)

    delong = get_cross_auc_delong(Rhat, Y, A)
    [xauc_b1_a0, ci_b1_a0, xauc_a1_b0, ci_a1_b0] = delong
    XAUCs = [Rhatb1_cross_Rhata0, Rhata1_cross_Rhatb0]
    XCIs = [ci_b1_a0, ci_a1_b0]

    return [AUCs, AUCs_CIs, briers, ROCs, XROC, XROC_backwards, XAUCs, XCIs, balanced_XAUCs, balanced_XROCs]

'''assume ROCS_ is n_samp x N_THRESH x 2; [fpr,tpr] '''
def get_bootstrapped_rocs(ROCS_,label, color='blue', stump='def'):
    """Plot the mean curve of a stack of ROC samples with a +/- 1 std band.

    ROCS_ is indexed as [sample, threshold, (fpr, tpr)]. The first entry of
    the mean TPR is overwritten with 1.0 before plotting. stump is not used.

    Returns:
        [mean_fpr, mean_tpr, std_tpr].
    """
    fpr_samples = ROCS_[:,:,0]
    tpr_samples = ROCS_[:,:,1]
    mean_fpr = np.mean(fpr_samples, axis=0)
    mean_tpr = np.mean(tpr_samples, axis=0)
    mean_tpr[0] = 1.0
    std_tpr = np.std(tpr_samples, axis=0)
    # Band of one standard deviation, kept inside [0, 1].
    band_top = np.minimum(mean_tpr + std_tpr, 1)
    band_bottom = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, band_bottom, band_top, color='grey', alpha=.2)
    plt.plot(mean_fpr, mean_tpr, color=color, label=label, lw=1)

    return [mean_fpr, mean_tpr, std_tpr]


def get_calibrated_isotonic(clf, X_train,X_test, y_train):
    """Wrap clf in a 2-fold isotonic CalibratedClassifierCV and fit it.

    Returns [calibrated_classifier, column 1 of predict_proba(X_test)].
    """
    calibrated = CalibratedClassifierCV(clf, cv=2, method='isotonic')
    calibrated.fit(X_train, y_train)
    test_proba = calibrated.predict_proba(X_test)
    return [calibrated, test_proba[:, 1]]

def get_calirated_sigmoid(clf, X_train,X_test, y_train):
    """Wrap clf in a 2-fold sigmoid CalibratedClassifierCV and fit it.

    Returns [calibrated_classifier, column 1 of predict_proba(X_test)].
    """
    calibrated = CalibratedClassifierCV(clf, cv=2, method='sigmoid')
    calibrated.fit(X_train, y_train)
    test_proba = calibrated.predict_proba(X_test)
    return [calibrated, test_proba[:, 1]]

'''
'''


# AUC comparison adapted from
# https://github.com/Netflix/vmaf/
def compute_midrank(x):
    """Computes midranks.

    Tied values share the average of the 1-based positions they occupy in
    sorted order.

    Args:
       x - a 1D numpy array (any 1D array-like is accepted)
    Returns:
       array of midranks, aligned with x
    """
    x = np.asarray(x)
    J = np.argsort(x)
    Z = x[J]
    N = len(x)
    # Builtin float: the np.float alias was removed in NumPy 1.24.
    T = np.zeros(N, dtype=float)
    i = 0
    while i < N:
        # [i, j) is the run of sorted positions holding the same value.
        j = i
        while j < N and Z[j] == Z[i]:
            j += 1
        T[i:j] = 0.5*(i + j - 1)
        i = j
    T2 = np.empty(N, dtype=float)
    # Note(kazeevn) +1 is due to Python using 0-based indexing
    # instead of 1-based in the AUC formula in the paper
    T2[J] = T + 1
    return T2


def compute_midrank_weight(x, sample_weight):
    """Computes weighted midranks.

    Each value gets the cumulative weight up to its sorted position; tied
    values share the mean of their cumulative weights.

    Args:
       x - a 1D numpy array
       sample_weight - a 1D numpy array of weights aligned with x
    Returns:
       array of weighted midranks, aligned with x
    """
    J = np.argsort(x)
    Z = x[J]
    cumulative_weight = np.cumsum(sample_weight[J])
    N = len(x)
    # Builtin float: the np.float alias was removed in NumPy 1.24.
    T = np.zeros(N, dtype=float)
    i = 0
    while i < N:
        # [i, j) is the run of sorted positions holding the same value.
        j = i
        while j < N and Z[j] == Z[i]:
            j += 1
        T[i:j] = cumulative_weight[i:j].mean()
        i = j
    T2 = np.empty(N, dtype=float)
    T2[J] = T
    return T2


def fastDeLong(predictions_sorted_transposed, label_1_count, sample_weight):
    """Dispatch to the weighted or unweighted fast DeLong implementation.

    The unweighted version is used when sample_weight is None.
    """
    if sample_weight is not None:
        return fastDeLong_weights(predictions_sorted_transposed, label_1_count, sample_weight)
    return fastDeLong_no_weights(predictions_sorted_transposed, label_1_count)


def fastDeLong_weights(predictions_sorted_transposed, label_1_count, sample_weight):
    """
    The fast version of DeLong's method for computing the covariance of
    unadjusted AUC, with per-example weights.
    Args:
       predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
          sorted such as the examples with label "1" are first
       label_1_count: number of leading columns that carry label "1"
       sample_weight: 1D numpy.array of weights in the same column order
    Returns:
       (AUC value, DeLong covariance)
    Reference:
     @article{sun2014fast,
       title={Fast Implementation of DeLong's Algorithm for
              Comparing the Areas Under Correlated Receiver Operating Characteristic Curves},
       author={Xu Sun and Weichao Xu},
       journal={IEEE Signal Processing Letters},
       volume={21},
       number={11},
       pages={1389--1393},
       year={2014},
       publisher={IEEE}
     }
    """
    # Short variables are named as they are in the paper
    m = label_1_count
    n = predictions_sorted_transposed.shape[1] - m
    positive_examples = predictions_sorted_transposed[:, :m]
    negative_examples = predictions_sorted_transposed[:, m:]
    k = predictions_sorted_transposed.shape[0]

    # Builtin float: the np.float alias was removed in NumPy 1.24.
    tx = np.empty([k, m], dtype=float)
    ty = np.empty([k, n], dtype=float)
    tz = np.empty([k, m + n], dtype=float)
    for r in range(k):
        # Weighted midranks within positives, within negatives, and pooled.
        tx[r, :] = compute_midrank_weight(positive_examples[r, :], sample_weight[:m])
        ty[r, :] = compute_midrank_weight(negative_examples[r, :], sample_weight[m:])
        tz[r, :] = compute_midrank_weight(predictions_sorted_transposed[r, :], sample_weight)
    total_positive_weights = sample_weight[:m].sum()
    total_negative_weights = sample_weight[m:].sum()
    pair_weights = np.dot(sample_weight[:m, np.newaxis], sample_weight[np.newaxis, m:])
    total_pair_weights = pair_weights.sum()
    aucs = (sample_weight[:m]*(tz[:, :m] - tx)).sum(axis=1) / total_pair_weights
    v01 = (tz[:, :m] - tx[:, :]) / total_negative_weights
    v10 = 1. - (tz[:, m:] - ty[:, :]) / total_positive_weights
    sx = np.cov(v01)
    sy = np.cov(v10)
    delongcov = sx / m + sy / n
    return aucs, delongcov


def fastDeLong_no_weights(predictions_sorted_transposed, label_1_count):
    """
    The fast version of DeLong's method for computing the covariance of
    unadjusted AUC.
    Args:
       predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples]
          sorted such as the examples with label "1" are first
       label_1_count: number of leading columns that carry label "1"
    Returns:
       (AUC value, DeLong covariance)
    Reference:
     @article{sun2014fast,
       title={Fast Implementation of DeLong's Algorithm for
              Comparing the Areas Under Correlated Receiver Operating
              Characteristic Curves},
       author={Xu Sun and Weichao Xu},
       journal={IEEE Signal Processing Letters},
       volume={21},
       number={11},
       pages={1389--1393},
       year={2014},
       publisher={IEEE}
     }
    """
    # Short variables are named as they are in the paper
    m = label_1_count
    n = predictions_sorted_transposed.shape[1] - m
    positive_examples = predictions_sorted_transposed[:, :m]
    negative_examples = predictions_sorted_transposed[:, m:]
    k = predictions_sorted_transposed.shape[0]

    # Builtin float: the np.float alias was removed in NumPy 1.24.
    tx = np.empty([k, m], dtype=float)
    ty = np.empty([k, n], dtype=float)
    tz = np.empty([k, m + n], dtype=float)
    for r in range(k):
        # Midranks within positives, within negatives, and pooled.
        tx[r, :] = compute_midrank(positive_examples[r, :])
        ty[r, :] = compute_midrank(negative_examples[r, :])
        tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :])
    aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n
    v01 = (tz[:, :m] - tx[:, :]) / n
    v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m
    sx = np.cov(v01)
    sy = np.cov(v10)
    delongcov = sx / m + sy / n
    return aucs, delongcov


def calc_pvalue(aucs, sigma):
    """Computes log(10) of the two-sided p-value for a difference of two AUCs.
    Args:
       aucs: 1D array of AUCs
       sigma: AUC DeLong covariances
    Returns:
       log10(pvalue)
    """
    contrast = np.array([[1, -1]])
    # Variance of (auc_0 - auc_1) under the given covariance.
    variance_of_diff = np.dot(np.dot(contrast, sigma), contrast.T)
    z = np.abs(np.diff(aucs)) / np.sqrt(variance_of_diff)
    log_tail = scipy.stats.norm.logsf(z, loc=0, scale=1)
    # Natural log of the one-sided tail -> log10 of the two-sided p-value.
    return np.log10(2) + log_tail / np.log(10)


def compute_ground_truth_statistics(ground_truth, sample_weight):
    """Return (order, label_1_count, ordered_sample_weight) for DeLong.

    order sorts the examples so those labelled 1 come first;
    ordered_sample_weight is sample_weight in that order, or None when no
    weights are given. ground_truth must contain both 0 and 1 and nothing else.
    """
    assert np.array_equal(np.unique(ground_truth), [0, 1])
    # Sorting the negated labels ascending puts the 1s ahead of the 0s.
    order = np.argsort(-ground_truth)
    label_1_count = int(ground_truth.sum())
    ordered_sample_weight = None if sample_weight is None else sample_weight[order]
    return order, label_1_count, ordered_sample_weight


def delong_roc_variance(ground_truth, predictions, sample_weight=None):
    """
    Computes ROC AUC and its DeLong variance for a single set of predictions
    Args:
       ground_truth: np.array of 0 and 1
       predictions: np.array of floats of the probability of being class 1
       sample_weight: optional np.array of per-example weights
    Returns:
       (AUC, DeLong covariance)
    """
    ordering = compute_ground_truth_statistics(ground_truth, sample_weight)
    order, label_1_count, ordered_sample_weight = ordering
    # A single classifier: one row, columns reordered with the 1s first.
    single_row = predictions[np.newaxis, order]
    aucs, delongcov = fastDeLong(single_row, label_1_count, ordered_sample_weight)
    assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers"
    return aucs[0], delongcov

'''plotting helpers '''
''' assume reversed order
'''
def plot_thresholds(t_as, t_bs, G_a0_tilde, eq_op_threshs, stump = 'def', save = False):
    """Plot eq_op_threshs against the reversed t_as, with the identity line.

    Saves to 'figs/<stump>threshold_eqop_comparison.pdf' when save is set,
    then closes all figures. G_a0_tilde is not used; t_bs is reversed but
    not plotted.
    """
    plt.figure(figsize=(3,3))
    t_as, t_bs = t_as[::-1], t_bs[::-1]
    plt.plot(t_as, eq_op_threshs)
    plt.plot(t_as, t_as, alpha =0.5)
    plt.xlabel(r'$\theta_b$')
    plt.ylabel(r'$\theta_a$')
    if save:
        plt.savefig('figs/'+stump+'threshold_eqop_comparison.pdf')
    plt.close('all')

# given ROCs: plot TPR, FPR over thresholds for each group
def plot_G_thresholds(ROC, Rhat, n_thresh, classes, save=False, stump='def'):
    """Plot each group's ROC columns as functions of the threshold.

    Thresholds run evenly from max(Rhat) down to min(Rhat). Column 1 of
    each ROC is drawn dashed (labelled G_1), column 0 solid (labelled G_0).
    Colors and group letters are defined for two groups only. Saves to
    'figs/<stump>-G-comparison.pdf' when save is set; the figure is left open.
    """
    plt.figure(figsize=(3,3))
    highest = max(Rhat)
    lowest = min(Rhat)
    thresholds = np.linspace(highest, lowest, n_thresh)
    colors = ['green', 'purple']
    groups = ['a', 'b']
    n_classes = len(classes)
    for ind in range(n_classes):
        plt.plot(thresholds, ROC[ind][:,1], color = colors[ind], linestyle='--', alpha = 1, label = r'$G_1^'+groups[ind]+'$')
    for ind in range(n_classes):
        plt.plot(thresholds, ROC[ind][:,0], color = colors[ind], alpha = 0.75, label = r'$G_0^'+groups[ind]+'$')
    plt.legend()
    plt.xlabel('thresholds')
    plt.ylabel(r'$G_{0}^A(t),G_{1}^A(t)$')
    plt.tight_layout()
    if save:
        plt.savefig('figs/'+stump+"-G-comparison.pdf", bbox_inches='tight')


# a is descending order; call binary search to find index of insertion on a
def searchsorted_rev(a, t):
    """Binary-search t in the descending array a and return a valid index.

    The search runs on the negated (ascending) values; the insertion point
    is then counted from the end of a and clipped to [0, len(a) - 1].
    """
    size = len(a)
    ascending = -1 * a
    insertion = np.searchsorted(ascending, -1 * t)
    return np.clip(size - insertion, 0, size - 1)


'''
# Adjusting group b with equality of opportunity adjustment
# Return Gtilde of the transformed score (corresponding to eqop adjustment)
# Gtilde( tb ) for group b thresholds
# Also return theta_b(theta_a), for all group a thresholds
'''
def get_G_b0_tilde(ROC_a, ROC_b, Rhat, A, Y, n_thresh, classes):
    """Equality-of-opportunity adjustment of the A == 1 scores.

    Args:
        ROC_a, ROC_b: 2-column arrays; column 0 is read as G^0 and column 1
            as G^1 for each group (presumably FPR/TPR tables as produced by
            get_roc -- verify against caller).
        Rhat: scores; a deep copy is modified and returned, Rhat itself is not.
        A: group membership; entries with A == 1 are the ones transformed.
        Y: not used by this function.
        n_thresh: number of thresholds generated per class.
        classes: only its length is used (thresholds are built for
            ind in range(len(classes)); indices 0 and 1 are read).
    Returns:
        [G_b0_tilde, eq_op_tbs, Rhat_transformed_total, t_as, t_bs]
    """
    # Ga0 is assigned but not read below.
    Ga1 = ROC_a[:,1]; Ga0 = ROC_a[:,0]; Gb1 = ROC_b[:,1]; Gb0 = ROC_b[:,0];
    Rhat_transformed_total = copy.deepcopy(Rhat)
    G_b0_tilde = np.zeros(len(Gb0)) # Range over clf_scores
    # Per-class thresholds, evenly spaced from that class's max score down to its min.
    threshes_byclass = [ np.linspace(max(Rhat[A==ind]),min(Rhat[A==ind]),n_thresh) for ind in range(len(classes)) ]
    t_as = threshes_byclass[0]; t_bs = threshes_byclass[1];
    # NOTE(review): keyed by float threshold value; repeated thresholds
    # (e.g. when max == min) would collapse to the last index -- confirm.
    t_as_to_ind = dict(zip(t_as, range(len(t_as)) ))#reverse lookup for convenience
    # Note that generally thresholds might be defined differently for Ga, Gb
    # thresholds in reverse order
    for ind,t in enumerate(t_bs):
        #\tilde{G}_b^0(t) = G_b^0( G_b^{-1}( G_a^1(t)) )
        # Map the group-b threshold onto an entry of the group-a grid.
        t_a_closest_tb = t_as[searchsorted_rev(t_as, t)]
        # Index where Gb1 is closest to Ga1 at that grid entry.
        G_b1_inv_of_Ga = np.argmin( np.abs(Gb1 - Ga1[t_as_to_ind[t_a_closest_tb] ] ) )
        G_b0_tilde[ind] = Gb0[ G_b1_inv_of_Ga ]
    G_b0_tilde = G_b0_tilde[::-1] # reverse to account for reverse enumeration of thresholds
    # Despite the name, these are the scores with A == 1.
    R_a1 = copy.deepcopy(Rhat[A==1])
    Rhat_transformed = np.zeros(len(R_a1));
    # get corresponding transformed score h(r, A=b) = (G_a^{1})^{-1}(G_b^1 (r))
    for i in range(len(Rhat_transformed)):
        # NOTE(review): positional indexing -- assumes Rhat is a numpy array,
        # not a pandas Series with a non-default index; confirm.
        r = R_a1[i]
        # Stored value is 1 minus the group-a threshold whose Ga1 best matches Gb1 at r.
        Rhat_transformed[i] = 1-t_as[ np.argmin( np.abs( Ga1 - Gb1[ searchsorted_rev(t_bs, r) ]  )) ]
    Rhat_transformed_total[A==1] = Rhat_transformed
    # get t_bs corresponding to eqop adjustment of t_as
    eq_op_tbs = np.zeros(len(t_as))
    for ind, t_a in enumerate(t_as):
        # get Gb1 inverse Ga1[t_a]
        eq_op_tbs[ind] = t_bs[ np.argmin( np.abs(Gb1-Ga1[ ind ] )) ]
    eq_op_tbs = eq_op_tbs[::-1]
    return [G_b0_tilde, eq_op_tbs, Rhat_transformed_total, t_as, t_bs]

'''assume ta is unadjusted threshold; eq opt thresh is eqopt theta_b
Expect R_a > ta = R_b > tb
'''
def check_same_tpr_behavior(t_as, eq_op_thresh, Rhat, Rhat_transformed, A, Y, n_thresh, classes):
    '''Plot column 1 of three ROC arrays against one shared threshold grid.

    The curves are: Rhat for A == 0, Rhat for A == 1, and Rhat_transformed for
    A == 1, all computed by get_roc on the grid built from Rhat[A == 0].
    The incoming t_as is replaced by that rebuilt grid; eq_op_thresh is not used.
    '''
    grids = [np.linspace(max(Rhat[A == g]), min(Rhat[A == g]), n_thresh)
             for g in range(len(classes))]
    t_as = grids[0]
    _t_bs = grids[1]  # never read; indexing kept so a missing second group still fails here

    in_a = (A == 0)
    in_b = (A == 1)
    roc_kwargs = dict(compute_thresh=True, precomputed_thresh=t_as)
    # compute every curve before opening the figure
    rocs = [
        get_roc(n_thresh, Rhat[in_a], Y[in_a], **roc_kwargs),
        get_roc(n_thresh, Rhat[in_b], Y[in_b], **roc_kwargs),
        get_roc(n_thresh, Rhat_transformed[in_b], Y[in_b], **roc_kwargs),
    ]
    styles = [
        (r'$R_a>\theta_a$', {'color': 'blue'}),
        (r'$R_b>\theta_a$', {'color': 'purple'}),
        (r'$h(R_b)>\theta_a$', {'linestyle': '--', 'color': 'r'}),
    ]

    plt.figure(figsize=(3, 3))
    for roc, (label, style) in zip(rocs, styles):
        plt.plot(t_as, roc[:, 1], label=label, **style)
    plt.title('Check tpr behavior')
    plt.legend()

def get_eq_op_adjustments(data, Rhat_train, Rhat_test, config, quiet = False, save = False):
    '''Apply the get_G_b0_tilde adjustment to the train and test scores.

    data : unpacked as [x_train, x_test, y_train, y_test, A_train, A_test].
    Rhat_train, Rhat_test : scores for the two splits.
    config : unpacked as [labels, N_THRESH, name, transform_name];
        transform_name is not used in this function.
    quiet : when False, diagnostic figures are drawn (and closed) along the way.
    save : forwarded to most plotting helpers and guards the savefig calls here.

    Returns [Rhat_train_transformed, Rhat_test_transformed, t_as, t_bs,
    XAUCs_test, G_a0_tilde, eq_op_threshs]. Note that t_as / t_bs are the ones
    from the test-set call (they are rebound near the end), while G_a0_tilde and
    eq_op_threshs come from the training-set call.
    '''
    [x_train, x_test, y_train, y_test, A_train, A_test] = data
    [labels, N_THRESH, name, transform_name] = config
    # common threshold grid, descending from 1 to 0
    threshs = np.linspace(1,0, N_THRESH)
    [ROCs, XROC, XROC_backwards] = get_rocs_xrocs(Rhat_train, y_train, A_train, labels, N_THRESH, compute_thresh=False, precomputed_thresh=threshs)
    # get xauc and curves before and after transformation
    # Adjust group a: get_G_b0_tilde transforms the entries whose group argument
    # equals 1, so the ROC arguments are swapped and the labels flipped (1-A_train)
    # to make the A_train == 0 entries the adjusted ones.
    [G_a0_tilde, eq_op_threshs, Rhat_train_transformed, t_as, t_bs] = get_G_b0_tilde(ROCs[1], ROCs[0], Rhat_train, (1-A_train), y_train, N_THRESH, labels)
    # assume thresholds on 0,1 for now for equal opportunity adjustment
    # NOTE(review): this repeats the get_rocs_xrocs call above with identical arguments.
    [ROCs, XROC, XROC_backwards] = get_rocs_xrocs(Rhat_train, y_train, A_train, labels, N_THRESH, compute_thresh=False, precomputed_thresh=threshs)
    [ROCs_adj, XROC_adj, XROC_backwards_adj] = get_rocs_xrocs(Rhat_train_transformed, y_train, A_train, labels, N_THRESH, compute_thresh=False, precomputed_thresh=threshs)

    # pooled ROC before / after the adjustment
    unadj_ROC = get_roc(N_THRESH, Rhat_train, y_train); adj_ROC = get_roc(N_THRESH, Rhat_train_transformed, y_train)

    if not quiet:
        fig = plt.figure(figsize=(3,3))
        # NOTE(review): save=True is hard-coded here rather than forwarding `save` -- confirm intended.
        plot_G_thresholds(ROCs, Rhat_train, N_THRESH, labels, save=True)
        # Recomputed without the precomputed grid; this rebinds ROCs_adj / XROC_adj /
        # XROC_backwards_adj, which are the values plot_XROCs_comparison receives below.
        [ROCs_adj, XROC_adj, XROC_backwards_adj] = get_rocs_xrocs(Rhat_train_transformed, y_train, A_train, labels, N_THRESH)
        plt.tight_layout()
        plt.plot( t_bs, ROCs_adj[0][:,0], color = 'purple', alpha = 0.85, label = r'$\tilde{G}_a^0$');
        plt.plot( t_bs, ROCs_adj[0][:,1], color = 'purple', alpha = 0.5, label = r'$\tilde{G}_a^1$'); plt.legend()
        if save:
            plt.savefig('figs/'+name+"--LR--eqop-G-comparison.pdf", bbox_inches='tight')
        plot_thresholds(t_as, t_bs, G_a0_tilde, eq_op_threshs, save=save, stump=name+str("--LR--eqop") )
        # Plot XROC, ROC comparison of the eq op transformation for all thresholds
        plt.close('all')
        fig = plt.figure(figsize=(3,3));
        # Histogram of the score shift for entries with y_train == 0 and A_train == 0.
        # NOTE(review): `normed` was removed from Matplotlib's hist in 3.1 in favour of
        # `density` -- confirm the Matplotlib version in use. plt.legend() is called
        # although no artist on this figure has a label.
        plt.hist(Rhat_train[(y_train == 0)&(A_train==0)] - Rhat_train_transformed[(y_train == 0)&(A_train==0)], normed = True); plt.legend(); plt.title(r'$R_a - h(R_a)$');
        if save:
            plt.savefig('figs/'+name+"--LR--eqop--adjusted-delta.pdf")
        plt.close('all')
### Plotting diagnostics
        # plt.figure(figsize=(3,3));
        # plt.hist(Rhat_train_transformed[(y_train == 0)&(A_train==0)] , alpha = 0.5)
        # plt.hist(Rhat_train[(y_train == 0)&(A_train==0)] , alpha = 0.5,label=r'unadj, $R\mid A=0,Y=0$'); plt.legend()
        # Group labels are flipped here as well, matching the get_G_b0_tilde call above.
        check_same_tpr_behavior(t_as, eq_op_threshs, Rhat_train, Rhat_train_transformed, 1-A_train, y_train, N_THRESH, labels)
        plot_XROCs_comparison(XROC, XROC_backwards,XROC_adj, XROC_backwards_adj, labels, stump=name+str("--LR--eqop") , save = save)
        plot_ROC_comparison(unadj_ROC, adj_ROC, stump = name+'--LR--eqop', save = save)
    # get statistics on the test set (this rebinds ROCs, XROC, XROC_backwards, t_as and t_bs)
    [ROCs, XROC, XROC_backwards] = get_rocs_xrocs(Rhat_test, y_test, A_test, labels, N_THRESH)
    [G_a0_tilde_test, eq_op_threshs_test, Rhat_test_transformed, t_as, t_bs] = get_G_b0_tilde(ROCs[1], ROCs[0], Rhat_test, (1-A_test), y_test, N_THRESH, labels)
    XAUCs_test=get_cross_aucs(Rhat_test_transformed, y_test, A_test, quiet = True)
    return [ Rhat_train_transformed, Rhat_test_transformed, t_as, t_bs, XAUCs_test, G_a0_tilde, eq_op_threshs  ]


def transform_logistic(Rhat, A, a_group, alpha, beta):
    '''Return a copy of Rhat in which the entries with A == a_group are
    replaced by 1 / (1 + exp(-(alpha * r + beta))); other entries are kept.'''
    result = copy.deepcopy(Rhat)
    in_group = (A == a_group)
    linear_part = alpha * result[in_group] + beta
    result[in_group] = 1. / (1 + np.exp(-1 * linear_part))
    return result

def transform_logistic_fn(Rhat, A, a_group, alpha, beta):
    '''Return 1 / (1 + exp(-(alpha * r + beta))) for the entries r of Rhat with
    A == a_group only (the result has one value per selected entry).'''
    selected = Rhat[A == a_group]
    linear_part = alpha * selected + beta
    return 1. / (1 + np.exp(-1 * linear_part))

''' Given data (X,Y,A) and score Rhat;
find the logistic transform optimizing the XAUC disparity
and return the optimal parameters
equalize balance 1: equalize balanced xauc that partitions on 1
equalize balance 0: equalize balanced xauc that partitions on 0
'''
def opt_xauc_disparity_logistic_transform(train_data, Rhat_train, equalize_balance_1=False, equalize_balance_0=False):
    '''Grid-search a logistic transform of the A == 0 scores that minimises a disparity.

    train_data : unpacked as [x_train, y_train, A_train]; x_train is not used.
    Rhat_train : training scores; only a deep copy is modified.
    equalize_balance_1, equalize_balance_0 : choose which pair of values is compared:
        neither -> the two values returned by get_cross_aucs;
        only _1 -> first two values of get_balanced_cross_aucs;
        only _0 -> last two values of get_balanced_cross_aucs;
        both    -> (|gap of the first pair|, -|gap of the last pair|), so the
                   absolute difference of the stored pair is the sum of both gaps.

    The grid is 50 alphas in [2, 7] with beta fixed at -2. The selected point
    minimises |pair[0] - pair[1]|.

    Returns [transformed_score, alphastar, betastar, xaucs_, alphas, betas], where
    transformed_score is the full score vector with the A == 0 entries transformed.
    '''
    [_, y_train, A_train] = train_data
    # parameter space to iterate over
    n_param = 50
    alphas = np.linspace(2, 7, n_param)
    betas = [-2]
    aucs_ = np.zeros([n_param, n_param])
    xaucs_ = np.zeros([n_param, len(betas), 2])
    Rhat_train_copy = copy.deepcopy(Rhat_train)
    for ind_a, alpha in enumerate(alphas):
        for ind_b, beta in enumerate(betas):
            # change the score of group A == 0 only
            Rhat_train_copy[A_train == 0] = transform_logistic_fn(Rhat_train, A_train, 0, alpha, beta)
            [Rhatb1_cross_Rhat0, Rhata1_cross_Rhat0,
             Rhat1_cross_Rhata0, Rhat1_cross_Rhatb0] = get_balanced_cross_aucs(Rhat_train_copy, y_train, A_train)
            if equalize_balance_1 and equalize_balance_0:
                pair = [np.abs(Rhatb1_cross_Rhat0 - Rhata1_cross_Rhat0),
                        -1 * np.abs(Rhat1_cross_Rhata0 - Rhat1_cross_Rhatb0)]
            elif equalize_balance_1:
                pair = [Rhatb1_cross_Rhat0, Rhata1_cross_Rhat0]
            elif equalize_balance_0:
                pair = [Rhat1_cross_Rhata0, Rhat1_cross_Rhatb0]
            else:
                pair = get_cross_aucs(Rhat_train_copy, y_train, A_train)
            xaucs_[ind_a, ind_b, :] = pair
            # first element of delong_roc_variance's result (presumably the AUC)
            aucs_[ind_a, ind_b] = delong_roc_variance(y_train, Rhat_train_copy)[0]

    disparity = np.abs(xaucs_[:, :, 0] - xaucs_[:, :, 1])
    # Unravel over the grid's own shape so the lookup stays correct if more
    # than one beta is ever searched (with a single beta this is [row, 0]).
    indices = list(np.unravel_index(disparity.argmin(), disparity.shape))
    alphastar = alphas[indices[0]]
    betastar = betas[indices[1]]
    transformed_score = transform_logistic(Rhat_train, A_train, 0, alphastar, betastar)
    # Single-argument print calls behave the same on Python 2 and Python 3.
    print('%s %s' % (indices, alphastar))
    print(np.min(disparity))
    print(aucs_[indices[0], indices[1]])
    return [transformed_score, alphastar, betastar, xaucs_, alphas, betas]

''' Apply the transform
expit( a (logit(p)) + b )
which transforms from [0,1] -> [0,1]
'''
def transform_logit_fn(Rhat, A, a_group, alpha, beta):
    '''Return expit(alpha * logit(p) + beta) for the entries p of Rhat with
    A == a_group only (the result has one value per selected entry).'''
    selected = Rhat[A == a_group]
    log_odds = np.log(selected / (1 - selected))
    return 1. / (1 + np.exp(-1 * (alpha * log_odds + beta)))
def transform_logit(Rhat, A, a_group, alpha, beta):
    '''Return a copy of Rhat in which the entries with A == a_group are
    replaced by expit(alpha * logit(p) + beta); other entries are kept.'''
    result = copy.deepcopy(Rhat)
    in_group = (A == a_group)
    selected = result[in_group]
    log_odds = np.log(selected / (1 - selected))
    result[in_group] = 1. / (1 + np.exp(-1 * (alpha * log_odds + beta)))
    return result

''' Given data (X,Y,A) and score Rhat;
find the logit transform optimizing the XAUC disparity
and return the optimal parameters
'''
def opt_xauc_disparity_logit_transform(train_data, Rhat_train):
    '''Grid-search a logit-scale transform of the A == 0 scores that minimises
    the absolute difference between the two values returned by get_cross_aucs.

    train_data : unpacked as [x_train, y_train, A_train]; x_train is not used.
    Rhat_train : training scores; only a deep copy is modified.

    The grid is 50 alphas in [0.1, 5] by 50 betas in [-2, 2].

    Returns [transformed_score, alphastar, betastar, xaucs_, alphas, betas].
    NOTE(review): transformed_score comes from transform_logit_fn, so it holds
    only the transformed A == 0 scores, whereas
    opt_xauc_disparity_logistic_transform returns the full score vector.
    Left unchanged here; confirm which shape callers expect.
    '''
    [_, y_train, A_train] = train_data
    # parameter space to iterate over
    n_param = 50
    alphas = np.linspace(0.1, 5, n_param)
    betas = np.linspace(-2, 2, n_param)
    aucs_ = np.zeros([n_param, n_param])
    xaucs_ = np.zeros([n_param, len(betas), 2])
    Rhat_train_copy = copy.deepcopy(Rhat_train)
    for ind_a, alpha in enumerate(alphas):
        for ind_b, beta in enumerate(betas):
            # change the score of group A == 0 only
            Rhat_train_copy[A_train == 0] = transform_logit_fn(Rhat_train, A_train, 0, alpha, beta)
            xaucs_[ind_a, ind_b, :] = get_cross_aucs(Rhat_train_copy, y_train, A_train)
            # first element of delong_roc_variance's result (presumably the AUC)
            aucs_[ind_a, ind_b] = delong_roc_variance(y_train, Rhat_train_copy)[0]

    # locate the grid point with the smallest gap between the two stored values
    disparity = np.abs(xaucs_[:, :, 0] - xaucs_[:, :, 1])
    indices = np.unravel_index(disparity.argmin(), disparity.shape)
    alphastar = alphas[indices[0]]
    betastar = betas[indices[1]]
    transformed_score = transform_logit_fn(Rhat_train, A_train, 0, alphastar, betastar)
    # Single-argument print calls behave the same on Python 2 and Python 3.
    print('%s %s %s' % (indices, alphastar, betastar))
    print(np.min(disparity))
    print(aucs_[indices[0], indices[1]])
    return [transformed_score, alphastar, betastar, xaucs_, alphas, betas]

''' Given data (X,Y,A), score Rhat, and transformed score;
plot comparison plots
'''
def plot_comparisons_of_transformed_score(data, Rhat_train, Rhat_transformed_train,Rhat_test, Rhat_transformed_test, config, transform_name='def', save = False):
    '''Draw comparison figures for a score before and after a transformation.

    data : unpacked as [x_train, x_test, y_train, y_test, A_train, A_test].
    Rhat_train, Rhat_test : original scores for the two splits.
    Rhat_transformed_train, Rhat_transformed_test : transformed scores.
    config : unpacked as [labels, N_THRESH, name, transform_name]. The value
        from config replaces the transform_name keyword argument, which is
        therefore effectively ignored.
    save : forwarded to plot_ROC_comparison and plot_balanced_ROCS.
        NOTE(review): the two plt.savefig calls below run regardless of `save`,
        and plot_XROCs_comparison is called with save=True -- confirm intended.

    Prints the balanced xAUCs of the transformed test score and returns
    [ROCs, XROC, XROC_backwards, ROCs_adj, XROC_adj, XROC_backwards_adj],
    all computed on the test split (original score, then transformed score).
    '''
    [x_train, x_test, y_train, y_test, A_train, A_test] = data
    [labels, N_THRESH, name, transform_name] = config
    [ROCs_orig, XROC_orig, XROC_backwards_orig] = get_rocs_xrocs(Rhat_train, y_train, A_train, labels, N_THRESH)
    [ROCs, XROC, XROC_backwards] = get_rocs_xrocs(Rhat_transformed_train, y_train, A_train, labels, N_THRESH)
    plot_ROCS(ROCs, XROC, XROC_backwards, labels, A_train)

    ### Plot the transformed test-split curves over the original training-split G curves
    [ROCs_adj, XROC_adj, XROC_backwards_adj] = get_rocs_xrocs(Rhat_transformed_test, y_test, A_test, labels, N_THRESH)
    plot_G_thresholds(ROCs_orig, Rhat_train, N_THRESH, labels, save=False)
    # x-axis: descending grid over the training scores with A_train == 0
    group_a_train = Rhat_train[(A_train==0)]
    thresh_grid = np.linspace(max(group_a_train), min(group_a_train), N_THRESH)
    plt.plot(thresh_grid, ROCs_adj[0][:,0], color = 'red', alpha = 1, label = r'$\tilde{G}_0^a$')
    plt.legend()
    plt.plot(thresh_grid, ROCs_adj[0][:,1], color = 'red', linestyle='--', alpha = 1, label = r'$\tilde{G}_1^a$')
    plt.legend()
    plt.savefig('figs/'+name+"--LR--"+transform_name+"-G-comparison.pdf")

    # Histogram of the score shift for test entries with y_test == 0 and A_test == 0.
    # NOTE(review): `normed` was removed from Matplotlib's hist in 3.1 in favour of
    # `density` -- confirm the Matplotlib version in use before changing it.
    plt.figure(figsize=(3,3))
    neg_a_test = (y_test == 0)&(A_test==0)
    plt.hist(Rhat_test[neg_a_test] - Rhat_transformed_test[neg_a_test], normed = True)
    plt.legend()
    plt.title(r'$R_a - h(R_a)$')
    plt.savefig('figs/'+name+"--LR--"+transform_name+"--adjusted-delta.pdf")
    plt.close('all')

    # Test-split curves for the original and the transformed score
    [ROCs, XROC, XROC_backwards] = get_rocs_xrocs(Rhat_test, y_test, A_test, labels, N_THRESH)
    [ROCs_adj, XROC_adj, XROC_backwards_adj] = get_rocs_xrocs(Rhat_transformed_test, y_test, A_test, labels, N_THRESH)
    unadj_ROC = get_roc(N_THRESH, Rhat_test, y_test)
    adj_ROC = get_roc(N_THRESH, Rhat_transformed_test, y_test)
    plot_XROCs_comparison(XROC, XROC_backwards,XROC_adj, XROC_backwards_adj, labels, stump=name+str("--LR-"+transform_name) , save = True)
    plot_ROC_comparison(unadj_ROC, adj_ROC, stump = name+'--LR-'+transform_name, save = save)

    balanced_XAUCs = get_balanced_cross_aucs(Rhat_transformed_test, y_test, A_test)
    # plot balanced rocs
    [XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds] = get_balanced_xrocs(Rhat_transformed_test, y_test, A_test, labels, N_THRESH)
    plot_balanced_ROCS(XROC_1part_a, XROC_1part_b_bwds, XROC_0part_a, XROC_0part_b_bwds, labels, A_test, stump=name+str("--LR-"+transform_name), save = save)

    # Single-argument print call: same output on Python 2 and Python 3.
    print('%s %s' % ('balanced xaucs Rb1>R0,Ra1>R0; R1>Ra0, R1>Rb0 ', balanced_XAUCs))

    plt.close('all')
    return [ROCs, XROC, XROC_backwards, ROCs_adj, XROC_adj, XROC_backwards_adj ]
