"""
Author: Anonymous
Code for uncertainty-aware self-training for few-label learning.
"""

import numpy as np
from sklearn.utils import shuffle
import random
from numpy.random import seed
import os

def sample_by_bald_difficulty(tokenizer, x, y_mean, y_var, y, num_samples, num_classes, weight=False, y_T=None):
	"""Sample examples with probability proportional to their BALD score.

	Favours "difficult" examples, i.e. those on which the repeated stochastic
	predictions in ``y_T`` disagree the most (presumably MC-dropout passes —
	confirm with the caller).

	Args:
		tokenizer, y_mean, num_classes, weight: unused here; kept so all
			samplers in this module share a call signature.
		x: candidate inputs, indexable by an integer index array.
		y_var: per-example values aligned with ``x``; column 0 of the selected
			rows is returned as the sample weights.
		y: (pseudo-)labels aligned with ``x``.
		num_samples: number of examples to draw without replacement.
		y_T: class probabilities indexed as [pass, example, class] (passes are
			averaged over axis 0, classes summed over the last axis).

	Returns:
		Tuple ``(x_s, y_s, w_s)`` of numpy arrays in a random order.

	Raises:
		ValueError: from ``np.random.choice`` if fewer than ``num_samples``
			examples have a positive score.
	"""
	print ("Sampling by difficulty BALD acquisition function")

	expected_entropy = - np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)  # [batch size]
	expected_p = np.mean(y_T, axis=0)
	entropy_expected_p = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)  # [batch size]
	BALD_acq = entropy_expected_p - expected_entropy

	# BALD (a mutual information) is non-negative in exact arithmetic, but the
	# difference above can come out slightly negative through rounding, and
	# np.random.choice rejects negative probabilities.
	p_norm = np.maximum(0., BALD_acq)
	total = np.sum(p_norm)
	if total > 0:
		p_norm = p_norm / total
	else:
		# Every score is zero: nothing to prefer, so sample uniformly instead of
		# dividing by zero (which would produce NaN probabilities).
		p_norm = np.ones(len(p_norm)) / len(p_norm)

	indices = np.random.choice(len(x), num_samples, p=p_norm, replace=False)
	# Randomise the order of the drawn examples (numpy's global RNG).
	indices = indices[np.random.permutation(len(indices))]

	return np.array(x[indices]), np.array(y[indices]), np.array(y_var[indices][:,0])


def sample_by_bald_easiness(tokenizer, x, y_mean, y_var, y, num_samples, num_classes, weight=False, var=False, y_T=None):
	"""Sample examples with probability proportional to ``max(0, 1 - BALD)``.

	Favours "easy" examples, i.e. those on which the repeated stochastic
	predictions in ``y_T`` agree (presumably MC-dropout passes — confirm with
	the caller). Prints the first ten sampling probabilities for inspection.

	Args:
		tokenizer, y_mean, num_classes, weight, var: unused here; kept so all
			samplers in this module share a call signature.
		x: candidate inputs, indexable by an integer index array.
		y_var: per-example values aligned with ``x``; column 0 of the selected
			rows is returned as the sample weights.
		y: (pseudo-)labels aligned with ``x``.
		num_samples: number of examples to draw without replacement.
		y_T: class probabilities indexed as [pass, example, class].

	Returns:
		Tuple ``(x_s, y_s, w_s)`` of numpy arrays in a random order.

	Raises:
		ValueError: from ``np.random.choice`` if fewer than ``num_samples``
			examples have a positive weight.
	"""
	print ("Sampling by easy BALD acquisition function")

	expected_entropy = - np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)  # [batch size]
	expected_p = np.mean(y_T, axis=0)
	entropy_expected_p = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)  # [batch size]
	BALD_acq = entropy_expected_p - expected_entropy

	# Clip *before* normalising. With 3+ classes BALD can exceed 1. If the total
	# of (1 - BALD) were negative, dividing by it first would flip every sign
	# and turn the hardest examples into the most likely ones.
	p_norm = np.maximum(0., 1. - BALD_acq)
	total = np.sum(p_norm)
	if total > 0:
		p_norm = p_norm / total
	else:
		# No example has positive easiness: sample uniformly rather than
		# dividing by zero (which would produce NaN probabilities).
		p_norm = np.ones(len(p_norm)) / len(p_norm)

	print (p_norm[:10])

	indices = np.random.choice(len(x), num_samples, p=p_norm, replace=False)
	# Randomise the order of the drawn examples (numpy's global RNG).
	indices = indices[np.random.permutation(len(indices))]

	return np.array(x[indices]), np.array(y[indices]), np.array(y_var[indices][:,0])


def sample_by_bald_class_easiness(tokenizer, x, y_mean, y_var, y, num_samples, num_classes, weight=False, var=False, y_T=None):
	"""Draw a class-balanced sample that favours "easy" (low-BALD) examples.

	For every label ``0 .. num_classes - 1``, up to ``num_samples // num_classes``
	examples whose entry in ``y`` equals that label are drawn without
	replacement, with probability proportional to ``max(0, 1 - BALD)``.

	Args:
		tokenizer, y_mean, weight, var: unused here; kept so all samplers in
			this module share a call signature.
		x: candidate inputs, indexable by an integer index array.
		y_var: per-example values aligned with ``x``; column 0 of the selected
			rows is returned as the sample weights.
		y: integer (pseudo-)labels aligned with ``x``, as a numpy array.
		num_samples: total sampling budget, split evenly across classes.
		num_classes: number of labels to iterate over.
		y_T: class probabilities indexed as [pass, example, class].

	Returns:
		Tuple ``(x_s, y_s, w_s)`` of numpy arrays in a random order. A class
		with fewer eligible examples than its share contributes only what it
		has (a message is printed); a class with no examples contributes
		nothing.
	"""
	print ("Sampling by easy BALD acquisition function per class")

	expected_entropy = - np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)  # [batch size]
	expected_p = np.mean(y_T, axis=0)
	entropy_expected_p = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)  # [batch size]
	BALD_acq = entropy_expected_p - expected_entropy

	# Clip *before* normalising. With 3+ classes BALD can exceed 1. If the total
	# of (1 - BALD) were negative, dividing by it first would flip every sign
	# and turn the hardest examples into the most likely ones.
	easiness = np.maximum(0., 1. - BALD_acq)

	samples_per_class = num_samples // num_classes
	chosen = []

	for label in range(num_classes):
		members = np.flatnonzero(y == label)
		if len(members) == 0:
			# np.random.choice cannot draw from an empty population.
			continue

		p_norm = easiness[members]
		total = np.sum(p_norm)
		if total > 0:
			p_norm = p_norm / total
		else:
			# No member has positive easiness: sample uniformly instead of
			# producing NaN probabilities.
			p_norm = np.ones(len(members)) / len(members)

		# Sampling without replacement cannot draw more items than have a
		# non-zero probability.
		n_take = min(samples_per_class, np.count_nonzero(p_norm))
		if n_take < samples_per_class:
			print ("Class %d: only %d of %d requested samples available" % (label, n_take, samples_per_class))

		picked = np.random.choice(len(members), n_take, p=p_norm, replace=False)
		chosen.append(members[picked])

	indices = np.concatenate(chosen) if chosen else np.zeros(0, dtype=int)
	# Interleave the classes in a random order (numpy's global RNG).
	indices = indices[np.random.permutation(len(indices))]

	return np.array(x[indices]), np.array(y[indices]), np.array(y_var[indices][:,0])


def sample_by_bald_class_difficulty(tokenizer, x, y_mean, y_var, y, num_samples, num_classes, weight=False, var=False, y_T=None):
	"""Draw a class-balanced sample that favours "difficult" (high-BALD) examples.

	For every label ``0 .. num_classes - 1``, up to ``num_samples // num_classes``
	examples whose entry in ``y`` equals that label are drawn without
	replacement, with probability proportional to ``max(0, BALD)``.

	Args:
		tokenizer, y_mean, weight, var: unused here; kept so all samplers in
			this module share a call signature.
		x: candidate inputs, indexable by an integer index array.
		y_var: per-example values aligned with ``x``; column 0 of the selected
			rows is returned as the sample weights.
		y: integer (pseudo-)labels aligned with ``x``, as a numpy array.
		num_samples: total sampling budget, split evenly across classes.
		num_classes: number of labels to iterate over.
		y_T: class probabilities indexed as [pass, example, class].

	Returns:
		Tuple ``(x_s, y_s, w_s)`` of numpy arrays in a random order. A class
		with fewer eligible examples than its share contributes only what it
		has (a message is printed); a class with no examples contributes
		nothing.
	"""
	print ("Sampling by difficulty BALD acquisition function per class")

	expected_entropy = - np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)  # [batch size]
	expected_p = np.mean(y_T, axis=0)
	entropy_expected_p = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)  # [batch size]
	BALD_acq = entropy_expected_p - expected_entropy

	# Rounding can make BALD slightly negative; probabilities must be >= 0.
	difficulty = np.maximum(0., BALD_acq)

	samples_per_class = num_samples // num_classes
	chosen = []

	for label in range(num_classes):
		members = np.flatnonzero(y == label)
		if len(members) == 0:
			# np.random.choice cannot draw from an empty population.
			continue

		p_norm = difficulty[members]
		total = np.sum(p_norm)
		if total > 0:
			p_norm = p_norm / total
		else:
			# Every member has zero BALD: sample uniformly instead of producing
			# NaN probabilities.
			p_norm = np.ones(len(members)) / len(members)

		# Sampling without replacement cannot draw more items than have a
		# non-zero probability.
		n_take = min(samples_per_class, np.count_nonzero(p_norm))
		if n_take < samples_per_class:
			print ("Class %d: only %d of %d requested samples available" % (label, n_take, samples_per_class))

		picked = np.random.choice(len(members), n_take, p=p_norm, replace=False)
		chosen.append(members[picked])

	indices = np.concatenate(chosen) if chosen else np.zeros(0, dtype=int)
	# Interleave the classes in a random order (numpy's global RNG).
	indices = indices[np.random.permutation(len(indices))]

	return np.array(x[indices]), np.array(y[indices]), np.array(y_var[indices][:,0])
