""" Script that takes a dataset of (X, Y) pairs and estimates p(Y.color | X.color).
	If the base dataset is fed, this is P(Y|X).
	If the generated (final) dataset is fed, this is P_X(Y).

"""


import torch
import torch.nn as nn
from cfg.dataloader_pickle import PickleDataset
from torch.utils.data import DataLoader
from collections import defaultdict
import napkin_mnist4.train_classifiers as tc

import os
import argparse
from tqdm.auto import tqdm


@torch.no_grad()
def get_pyx(eval_dataloader, x_color_cls, y_color_cls, device):
	""" Estimate the empirical conditional distribution p(Y.color | X.color).

	For every batch, X and Y are run through their respective color
	classifiers, the predicted label is taken as the index of the maximum
	along dim 1 of each classifier output, and (x_color, y_color) pairs are
	counted. Each row of counts is then normalized to sum to 1.

	Args:
		eval_dataloader: iterable of batches; each batch is indexed with 'X'
			and 'Y', and both values are moved to `device`.
		x_color_cls: classifier applied to batch['X']; switched to eval mode
			and moved to `device`.
		y_color_cls: classifier applied to batch['Y']; same treatment.
		device: target passed to `.to(...)` (e.g. 'cuda:0').

	Returns:
		dict mapping each observed predicted X color (int) to a dict mapping
		each observed predicted Y color (int) to its empirical frequency given
		that X color. Colors that were never predicted do not appear.
	"""
	counts = defaultdict(lambda: defaultdict(int))  # x.color -> y.color -> count
	x_color_cls = x_color_cls.eval().to(device)
	y_color_cls = y_color_cls.eval().to(device)
	for batch in tqdm(eval_dataloader):
		x, y = batch['X'].to(device), batch['Y'].to(device)

		# Predicted label = index of the max along dim 1. Transfer the labels
		# to the host once per batch with .tolist() rather than calling
		# .item() per element, which forces one device sync per sample.
		x_colors = x_color_cls(x).max(dim=1)[1].tolist()
		y_colors = y_color_cls(y).max(dim=1)[1].tolist()
		for x_c, y_c in zip(x_colors, y_colors):
			counts[x_c][y_c] += 1

	# Normalize each row into a conditional distribution. A row only exists
	# once it has been incremented, so its total is never zero.
	pyx = {}
	for x_c, row in counts.items():
		total = sum(row.values())
		pyx[x_c] = {y_c: n / total for y_c, n in row.items()}
	return pyx


# ===========================================
# =           MAIN BLOCK                    =
# ===========================================

def main():
	""" Parse CLI arguments, estimate p(Y.color | X.color) on a pickled dataset, and print it.

	Command-line arguments:
		--data_pkl: path handed to PickleDataset.
		--batch_size / --num_workers: DataLoader settings.
		--cls_loc: location handed to tc.load_models; the returned mapping
			must contain 'X_color' and 'Y_color' classifiers.
		--device: CUDA device index; a negative value runs on CPU.

	Returns:
		the dict produced by get_pyx (X color -> {Y color -> probability}).
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument('--data_pkl', type=str, required=True)
	parser.add_argument('--batch_size', type=int, default=512)
	parser.add_argument('--num_workers', type=int, default=8)
	parser.add_argument('--cls_loc', type=str, required=True)
	parser.add_argument('--device', type=int, required=True,
						help='CUDA device index; pass a negative value to run on CPU')

	args = parser.parse_args()
	# A negative index previously produced the invalid device 'cuda:-1';
	# treat it as a request for CPU instead.
	device = 'cpu' if args.device < 0 else 'cuda:%s' % args.device

	dataset = PickleDataset(args.data_pkl)
	dataloader = DataLoader(dataset, num_workers=args.num_workers,
							batch_size=args.batch_size,
							shuffle=False, drop_last=False)

	# NOTE(review): the meaning of the empty first argument to load_models is
	# not visible from this file -- confirm against train_classifiers.
	classifiers = tc.load_models('', args.cls_loc)
	x_color_cls = classifiers['X_color']
	y_color_cls = classifiers['Y_color']

	output = get_pyx(dataloader, x_color_cls, y_color_cls, device)
	# Sort both levels so the report is stable regardless of data order.
	for k in sorted(output):
		print('X.color = %s' % k)
		for k_, v_ in sorted(output[k].items()):
			print('\t %s:%.03f' % (k_, v_))
	return output


if __name__ == '__main__':
	# Run the evaluation only when executed as a script; importing this
	# module (e.g. to reuse get_pyx) has no side effects.
	main()

