""" some evaluation measures using classifiers for the synthetic dataset

i.e. suppose we have a dataset generated from 
    W1 ~ P(W1)
    W2 ~ P(W2)
    XY ~ f(W1,W2) approx P(X,Y | W1,W2)

then we want to evaluate the following queries:

Since:
X = [w2.digit, w2.color, w1.thick] (with noise 1-p)
Y = [x.digit,  COLOR.color, x.thick] (with noise 1-p)

(for latents, the formula for equality is 
q(p,N) := (1-p + p/N)**2 + (p/N)**2 * (N-1)
)

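Expected agreement rates, as computed by get_true() with
mixprob(p, N) := 1-p + p/N (roughly 1-p when p/N is small):

P[x.digit=w2.digit] = mixprob(p, 10)
P[x.color=w2.color] = mixprob(p, 2)
P[x.thick=w1.thick] = q(p, 3)
P[y.digit=x.digit] = mixprob(p, 10)
P[y.color=w1.color] = q(p, 6)
P[y.thick=x.thick] = mixprob(p, 3)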


"""

import torch
from collections import defaultdict

from cfg.dataloader_pickle import PickleDataset
from torch.utils.data import DataLoader
from napkin_mnist4 import train_classifiers as tc
import argparse
from tqdm.auto import tqdm

# ================================================
# =           Helper functions                   =
# ================================================

def q(p, N):
    """Return ``(1-p + p/N)**2 + (N-1) * (p/N)**2``.

    This is the quantity the module docstring calls "the formula for
    equality" for latents; ``get_true`` uses it for the attribute pairs
    listed under ``q_probs``.

    Args:
        p: noise level.
        N: number of classes of the attribute (must be non-zero, since
            the formula divides by it).
    """
    on_term = 1 - p + p / N
    off_term = p / N
    return on_term ** 2 + off_term ** 2 * (N - 1)

def mixprob(p, N):
    """Return ``1 - p + p/N``.

    ``get_true`` uses this as the expected agreement rate for the
    attribute pairs listed under ``p_comp_probs``.

    Args:
        p: noise level.
        N: number of classes of the attribute (must be non-zero, since
            the formula divides by it).
    """
    complement = 1 - p
    per_class = p / N
    return complement + per_class

@torch.no_grad()
def get_discrete_data(dataloader, classifiers, device):
    """Label every W1/X/Y entry of a dataset with the given classifiers.

    Args:
        dataloader: iterable of batches; each batch is indexed with the
            keys 'W1', 'X', 'Y', 'W2a' and 'W2b'.
        classifiers: dict keyed by '<var>_<attr>' for var in W1/X/Y and
            attr in digit/color/thickness. Every value is switched to
            eval mode and moved to ``device``.
        device: device the classifiers and their inputs are moved to.

    Returns:
        dict with keys 'W2a', 'W2b' (batch values passed through
        unchanged) and one '<var>_<attr>' key per classifier used
        (index of the maximum classifier output along dim 1), each
        concatenated over the whole dataloader.
    """
    image_vars = ('W1', 'X', 'Y')
    attributes = ('digit', 'color', 'thickness')

    # Pass-through columns first, then one column per (variable, attribute).
    collected = {'W2a': [], 'W2b': []}
    for var in image_vars:
        for attr in attributes:
            collected['%s_%s' % (var, attr)] = []

    classifiers = {name: model.eval().to(device)
                   for name, model in classifiers.items()}

    for batch in tqdm(dataloader):
        for passthrough in ('W2a', 'W2b'):
            collected[passthrough].append(batch[passthrough])

        for var in image_vars:
            inputs = batch[var].to(device)
            for attr in attributes:
                column = '%s_%s' % (var, attr)
                scores = classifiers[column](inputs)
                # assumes scores is (batch, n_classes) -- TODO confirm
                _, labels = scores.max(dim=1)
                collected[column].append(labels.cpu().detach())

    return {column: torch.cat(chunks) for column, chunks in collected.items()}


def get_discrete_base_data(dataloader):
    """Collect the label columns stored in the base dataset itself.

    Args:
        dataloader: iterable of batches; each batch is indexed with the
            eleven keys listed in ``label_columns``.

    Returns:
        dict mapping each of those keys to its per-batch values
        concatenated over the whole dataloader.
    """
    label_columns = ('W1_digit', 'W1_thickness', 'W1_color',
                     'W2a', 'W2b',
                     'X_digit', 'X_color', 'X_thickness',
                     'Y_digit', 'Y_color', 'Y_thickness')
    gathered = {column: [] for column in label_columns}
    for batch in tqdm(dataloader):
        for column, chunks in gathered.items():
            chunks.append(batch[column])
    return {column: torch.cat(chunks) for column, chunks in gathered.items()}



def eq_check(datadict, k1, k2):
    """Return the fraction of elements where two columns agree.

    Args:
        datadict: mapping from column name to a tensor.
        k1: key of the first column; its element count is the denominator.
        k2: key of the second column.

    Returns:
        0-dim tensor holding (number of equal elements) / (number of
        elements in ``datadict[k1]``).

    Raises:
        ValueError: if the two columns differ in shape. Without this
            check ``==`` would broadcast (e.g. (N,) against (N, 1) gives
            an N x N comparison) and the result would no longer be a
            per-sample agreement rate.
    """
    left = datadict[k1]
    right = datadict[k2]
    if left.shape != right.shape:
        raise ValueError(
            'cannot compare %r with shape %s to %r with shape %s'
            % (k1, tuple(left.shape), k2, tuple(right.shape)))
    return (left == right).sum() / left.numel()


def get_empirical(datadict):
    """Compute the empirical agreement rate for each evaluated pair.

    Args:
        datadict: mapping from column name to tensor; must contain every
            column named in ``compared_pairs``.

    Returns:
        dict keyed by ``(k1, k2)`` with the value of
        ``eq_check(datadict, k1, k2)`` for each pair.
    """
    compared_pairs = (
        ('X_digit', 'W2a'),
        ('X_color', 'W2b'),
        ('X_thickness', 'W1_thickness'),
        ('Y_digit', 'X_digit'),
        ('Y_color', 'W1_color'),
        ('Y_thickness', 'X_thickness'),
    )
    return {pair: eq_check(datadict, *pair) for pair in compared_pairs}

def get_true(p):
    """Return the analytic agreement rate for each evaluated pair.

    Args:
        p: noise level passed to ``mixprob`` and ``q``.

    Returns:
        dict keyed by ``(k1, k2)``. Pairs in ``mixprob_pairs`` map to
        ``mixprob(p, n)``; pairs in ``q_pairs`` map to ``q(p, n)``, where
        ``n`` is the third entry of each tuple.
    """
    mixprob_pairs = (
        ('X_digit', 'W2a', 10),
        ('X_color', 'W2b', 2),
        ('Y_digit', 'X_digit', 10),
        ('Y_thickness', 'X_thickness', 3),
    )
    q_pairs = (
        ('X_thickness', 'W1_thickness', 3),
        ('Y_color', 'W1_color', 6),
    )

    expected = {(first, second): mixprob(p, n)
                for first, second, n in mixprob_pairs}
    expected.update({(first, second): q(p, n)
                     for first, second, n in q_pairs})
    return expected



# ================================================
# =           Main block                         =
# ================================================

def main():
    """Print empirical and analytic agreement rates side by side.

    Command-line arguments:
        --synth_data_pkl: path handed to ``PickleDataset`` for the
            synthetic data.
        --base_data_pkl: path handed to ``PickleDataset`` for the base data.
        --cls_loc: location handed to ``tc.load_models``.
        --p: noise level for the analytic values (default 0.1).
        --device: CUDA device index.

    For every evaluated pair one row is printed with four columns:
    classifier predictions on the synthetic data, classifier predictions
    on the base data, labels stored in the base data, and the analytic
    value from ``get_true``.
    """
    parser = argparse.ArgumentParser()
    for path_flag in ('--synth_data_pkl', '--base_data_pkl', '--cls_loc'):
        parser.add_argument(path_flag, type=str, required=True)
    parser.add_argument('--p', type=float, default=0.1)
    parser.add_argument('--device', type=int, required=True)
    args = parser.parse_args()

    device = 'cuda:%s' % args.device

    # Both loaders share the same settings; order is kept (shuffle=False)
    # and no trailing partial batch is dropped.
    loader_options = dict(num_workers=8, batch_size=512,
                          shuffle=False, drop_last=False)
    synth_dataloader = DataLoader(PickleDataset(args.synth_data_pkl),
                                  **loader_options)
    base_dataloader = DataLoader(PickleDataset(args.base_data_pkl),
                                 **loader_options)

    classifiers = tc.load_models('', args.cls_loc)

    # Gather all label columns first, then reduce each to agreement rates.
    label_tables = [
        get_discrete_data(synth_dataloader, classifiers, device),
        get_discrete_data(base_dataloader, classifiers, device),
        get_discrete_base_data(base_dataloader),
    ]
    synth_pred, base_pred, base_true = [get_empirical(table)
                                        for table in label_tables]
    analytic = get_true(args.p)

    print('SynthPred | BasePred | BaseTrue | True')
    for pair, synth_value in synth_pred.items():
        row = (pair, synth_value, base_pred[pair],
               base_true[pair], analytic[pair])
        print('%s: %.03f | %.03f | %.03f | %.03f' % row)

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()