import numpy as np
from sklearn.feature_selection import mutual_info_classif
import torch
from torch import softmax, nn
from utils.metrics import get_kernel, get_bandwidth
from methods.calibration import NadarayaWatson

def discrete_mi_est(xs, ys, nx=2, ny=2):
    """Plug-in estimate (in nats) of the mutual information between two discrete samples.

    Args:
        xs: integer codes in ``[0, nx)``, one per observation.
        ys: integer codes in ``[0, ny)``, paired element-wise with ``xs``.
        nx: number of possible values of ``xs``.
        ny: number of possible values of ``ys``.

    Returns:
        The empirical MI of the joint histogram, clipped below at 0.0.
        Joint cells with mass below 1e-9 are skipped.
    """
    joint = np.zeros((nx, ny))
    for x, y in zip(xs, ys):
        # Each observation contributes equal mass to its joint cell.
        joint[x, y] += 1.0 / len(xs)
    marg_x = joint.sum(axis=1)
    marg_y = joint.sum(axis=0)
    total = sum(
        joint[a, b] * np.log(joint[a, b] / (marg_x[a] * marg_y[b]))
        for a, b in np.ndindex(nx, ny)
        if joint[a, b] >= 1e-9
    )
    return max(0.0, total)

def estimate_fcmi(masks, preds, num_examples, verbose=False):
    """Estimate a per-example functional CMI as in Harutyunyan et al. (NeurIPS 2021).

    For example ``idx``, every model contributes its mask value ``masks[k][idx]``
    and the argmax classes of its two prediction rows ``preds[k][2*idx]`` and
    ``preds[k][2*idx + 1]``. The pair of argmax classes is encoded as a single
    integer in ``[0, num_classes**2)``. The discrete MI between mask values and
    these codes is then estimated with ``discrete_mi_est``.

    Args:
        masks: per-model sequences of 0/1 values indexed by example
            (0/1 presumed from ``nx=2`` below -- confirm).
        preds: per-model 2-D tensors (rows x classes).
        num_examples: number of example indices to process.
        verbose: print diagnostics for the first 10 examples.

    Returns:
        numpy array with one MI estimate per example.
    """
    num_classes = preds[0].shape[1]
    n_pair_codes = num_classes ** 2
    per_example = []
    for idx in range(num_examples):
        ms = [m[idx] for m in masks]
        ps = []
        for model_preds in preds:
            pair = torch.argmax(model_preds[2*idx:2*idx+2], dim=1)
            # Encode (class of first row, class of second row) as one integer.
            ps.append((num_classes * pair[0] + pair[1]).item())
        cur_mi = discrete_mi_est(ms, ps, nx=2, ny=n_pair_codes)
        per_example.append(cur_mi)
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    return np.array(per_example)

def estimate_cmi_ece_bound(train_eces, val_eces, masks):
    """Estimate, for each run, the MI between its train/val ECE gap and its masks.

    Args:
        train_eces: sequence indexed by run; each entry is averaged with ``np.mean``.
        val_eces: sequence indexed by run, aligned with ``train_eces``.
        masks: sequence indexed by run; ``masks[j]`` is the discrete target
            passed to ``mutual_info_classif``.

    Returns:
        list with one non-negative MI estimate per run.
    """
    gaps = np.array([
        np.abs(np.mean(tr) - np.mean(val_eces[i]))
        for i, tr in enumerate(train_eces)
    ])
    estimates = []
    for j, gap in enumerate(gaps):
        mask = masks[j]
        # NOTE(review): this repeats one scalar per run, so the feature column is
        # constant within a run and carries no information about ``mask``; the
        # estimate is therefore expected to be ~0. Confirm this is intended.
        feature = np.tile(gap, len(mask))[:, None]
        mi = mutual_info_classif(feature, mask, discrete_features=[False]).sum()
        estimates.append(max(0., mi))

    return estimates

def estimate_cmi_recal_bound(ls_loss_list, masks):
    """Estimate, for each run, the MI between two recalibration losses and its masks.

    Args:
        ls_loss_list: sequence indexed by run. Entry ``[0]`` and entry ``[1]``
            are each repeated ``len(masks[j])`` times to form a single feature
            column. Presumably both are scalars; otherwise the column length
            would not match the mask -- TODO confirm.
        masks: sequence indexed by run; ``masks[j]`` is the discrete target
            passed to ``mutual_info_classif``.

    Returns:
        (mis, mis_bin): two lists of non-negative MI estimates, one per run,
        for entry ``[0]`` and entry ``[1]`` respectively.
    """
    mis, mis_bin = [], []
    for j, run_losses in enumerate(ls_loss_list):
        n = len(masks[j])
        raw_feat = np.tile(run_losses[0], n)[:, None]
        bin_feat = np.tile(run_losses[1], n)[:, None]
        mis.append(max(0., mutual_info_classif(raw_feat, masks[j], discrete_features=[False]).sum()))
        mis_bin.append(max(0., mutual_info_classif(bin_feat, masks[j], discrete_features=[False]).sum()))

    return mis, mis_bin


def _max_predictive_prob(scores):
    """Return the row-wise maximum class probability of a 2-D score tensor.

    ``scores`` is used as-is when every row already sums to one (within a
    dtype-aware tolerance). Otherwise it is treated as logits and passed
    through a softmax over dim 1 first.
    """
    # A fixed 1e-10 tolerance is below float32 resolution (eps ~1.2e-7), so
    # float32 probability rows would be mistaken for logits and softmaxed twice.
    # 1e-10 is kept as a floor so float64 inputs are classified exactly as before.
    eps = torch.finfo(scores.dtype).eps if scores.is_floating_point() else 0.0
    tol = max(1e-10, 10 * eps)
    row_sums = torch.sum(scores, dim=1)
    if torch.all(torch.abs(row_sums - 1) < tol):
        probs = scores
    else:
        probs = softmax(scores, 1)
    return torch.max(probs, dim=1).values


def estimate_cmi_bound(masks, preds, labels, bins, n_bins, num_examples, loss='diff', verbose=False):
    """
    Estimating our bound's value.

    For every example index ``idx``, per-model quantities are gathered: the
    mask value ``masks[k][idx]``, the prediction rows ``preds[k][2*idx:2*idx+2]``,
    the label rows ``labels[k][2*idx:2*idx+2]`` and the bin index
    ``bins[k][idx]``. A continuous feature is built according to ``loss``, and
    its mutual information with the mask values is estimated with
    ``mutual_info_classif``.

    Args:
        masks: per-model sequences indexed by example.
        preds: per-model 2-D tensors (rows x classes); rows ``2*idx`` and
            ``2*idx + 1`` belong to example ``idx``. Rows may be logits or
            probabilities.
        labels: per-model label tensors aligned row-for-row with ``preds``.
        bins: per-model sequences; ``bins[k][idx]`` selects the bin slot that
            example ``idx`` is written to ('diff' and 'reuse' only).
        n_bins: number of bins ``B``.
        num_examples: number of example indices to process.
        loss: one of three feature types.
            - 'diff': labels minus the per-row max probability, flipped to
              ``1 - p`` where the label is 0; binned.
            - 'reuse': raw labels; binned.
            - 'fcmi': the flipped probabilities, unbinned.
        verbose: print diagnostics for the first 10 examples.

    Returns:
        (bound, list_of_mis). ``bound`` is
        ``sqrt(c * (sum(list_of_mis) + B * log 2) / num_examples)``, with
        ``c = 2`` for 'reuse' and ``c = 8`` otherwise. ``list_of_mis`` holds
        the per-example MI estimates.

    Raises:
        ValueError: if ``loss`` is not 'diff', 'reuse' or 'fcmi'.
    """
    # Validate once up front, so an invalid value fails even when there are no models.
    if loss not in ('diff', 'reuse', 'fcmi'):
        raise ValueError(f"Unexpected loss type: {loss}")

    ## number of bins
    B = n_bins

    list_of_mis = []
    bound = 0.0
    for idx in range(num_examples):
        ms = np.array([m[idx] for m in masks])
        ps = [p[2*idx:2*idx+2] for p in preds]
        ls = [l[2*idx:2*idx+2] for l in labels]
        bs = np.array([b[idx] for b in bins])

        if loss == 'fcmi':
            # Allocated once per example (not per model) so that every model's row is kept.
            f_out = torch.zeros(len(ps), 2)
        else:
            loss_bin = torch.zeros(B, len(ps), 2)

        for i in range(len(ps)):
            if loss == 'reuse':
                loss_bin[bs[i]][i] = ls[i]
                continue
            ps[i] = _max_predictive_prob(ps[i])  ## predictive values
            is_zero = ls[i] == 0
            ps[i][is_zero] = 1 - ps[i][is_zero]  ## f(x) is the predictive prob. for y=1.
            if loss == 'diff':
                loss_bin[bs[i]][i] = ls[i] - ps[i]
            else:
                f_out[i] = ps[i]

        if loss == 'fcmi':
            cur_mi = max(0., mutual_info_classif(f_out, ms, discrete_features=[False, False]).sum())
        else:
            # Flatten bin-major: (B, models, 2) -> (B*models, 2), paired with masks tiled B times.
            cur_mi = max(0., mutual_info_classif(loss_bin.reshape(-1, 2), np.tile(ms, B), discrete_features=[False, False]).sum())
        list_of_mis.append(cur_mi)
        bound += cur_mi

        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    if loss == 'reuse':
        bound = np.sqrt((2*(bound + B*np.log(2))) / num_examples)
    else:
        bound = np.sqrt((8*(bound + B*np.log(2))) / num_examples)

    return bound, list_of_mis

def estimate_semipara_bound(preds, labels, masks, num_examples, kernel='gaussian'):
    """Per-example semi-parametric bound combining a loss-gap MI and a functional MI.

    Nadaraya-Watson predictions are obtained from ``get_NWpreds_for_bound``.
    For each example index ``idx``, two MI estimates against the mask values
    are computed:
      * ``mi_gap``: MI of ``||y - nw_tr|| - ||y - nw_te||`` (one value per model);
      * ``mi_fcmi``: MI of the concatenated ``[pred, nw_tr, nw_te]`` vectors.
    The example's value is ``sqrt(2 * mi_gap) + 4 * sqrt(2 * mi_fcmi)``.

    NOTE(review): rows here are taken at position ``idx``. Elsewhere in this
    module, example ``idx`` owns rows ``2*idx`` and ``2*idx + 1``, and
    ``get_NWpreds_for_bound`` returns its outputs ordered as
    [train-selected rows..., test-selected rows...]. Confirm that these
    row positions line up as intended.

    Returns:
        list with one bound value per example.
    """
    preds, nw_tr_all, nw_te_all = get_NWpreds_for_bound(preds, labels, masks, kernel=kernel)
    num_classes = len(labels[0].unique())
    feat_dim = 3 * num_classes

    per_example = []
    for idx in range(num_examples):
        ms = np.array([m[idx] for m in masks])
        ps = [p[idx] for p in preds]
        ps_tr = [p[idx] for p in nw_tr_all]
        ps_te = [p[idx] for p in nw_te_all]
        ls = [l[idx] for l in labels]

        n_models = len(ps)
        features = torch.concat(
            [torch.concat([ps[i], ps_tr[i], ps_te[i]]) for i in range(n_models)]
        ).reshape(-1, feat_dim)
        gap = torch.zeros(n_models, 1)
        for i in range(n_models):
            gap[i] = torch.norm(ls[i] - ps_tr[i]) - torch.norm(ls[i] - ps_te[i])

        mi_gap = max(0., mutual_info_classif(gap, ms, discrete_features=[False]).sum())
        mi_fcmi = max(0., mutual_info_classif(features, ms, discrete_features=[False] * feat_dim).sum())
        per_example.append(np.sqrt(2 * mi_gap) + 4 * np.sqrt(2 * mi_fcmi))

    return per_example

        
def get_NWpreds_for_bound(preds, labels, masks, kernel='gaussian'):
    """Fit two Nadaraya-Watson models per run and collect their outputs.

    For run ``i``, rows are taken in pairs. Within pair ``j``, the row at
    ``2*j + masks[i][j]`` goes to the "train" split and the other row goes
    to the "test" split. One model is fitted on the train-split predictions
    and one on the test-split predictions. Each model's ``predict_proba`` is
    evaluated on the train-split labels and then on the test-split labels,
    and the two results are concatenated in that order.

    NOTE(review): ``fit`` receives predictions and ``predict_proba`` receives
    labels, exactly as written; confirm the argument roles against
    ``methods.calibration.NadarayaWatson``.

    Returns:
        (preds, nw_tr_preds, nw_te_preds): the input ``preds`` unchanged, plus
        one concatenated output per run for each of the two models.
    """
    nw_tr_preds, nw_te_preds = [], []
    for i in range(len(preds)):
        models = (NadarayaWatson(kernel), NadarayaWatson(kernel))
        mask = masks[i]
        flipped = 1 - mask
        pair_base = 2 * np.arange(len(mask))
        train_rows = pair_base + mask
        test_rows = 2 * np.arange(len(flipped)) + flipped
        split_preds = (preds[i][train_rows], preds[i][test_rows])
        lab_tr, lab_te = labels[i][train_rows], labels[i][test_rows]
        # Fit & predict: first the train-fitted model, then the test-fitted one.
        for model, fit_data, sink in zip(models, split_preds, (nw_tr_preds, nw_te_preds)):
            model.fit(fit_data)
            on_tr, on_te = model.predict_proba(lab_tr), model.predict_proba(lab_te)
            sink.append(torch.concat([on_tr, on_te]))

    return preds, nw_tr_preds, nw_te_preds