import numpy as np
import algorithms.util as util

class StateEnumerator(object):
	"""Exact ("true") evaluator that enumerates every history of a public state.

	On construction the histories returned by
	``state.generate_histories_in_public_state()`` are weighted by
	``h.get_reach_probability(policy)`` and the weights are normalised so that
	they sum to one. Weights are stored in a dict keyed by ``str(history)``.

	Values of individual histories are obtained from
	``util.expected_value(history, policy)``, whose result is indexed by player.
	"""

	def __init__(self, state, policy):
		"""Enumerate the histories of ``state`` and weight them under ``policy``.

		Args:
			state: object providing ``generate_histories_in_public_state()``.
			policy: passed through to ``get_reach_probability`` of each history
				and to ``util.expected_value``.
		"""
		self._histories = state.generate_histories_in_public_state()
		self._policy = policy
		self._weigh_histories()
		self.name = "True"

	def _weigh_histories(self):
		"""Fill ``self._reach_weights`` with normalised reach probabilities."""
		self._reach_weights = {}
		norm_factor = 0.
		for h in self._histories:
			reach_prob = h.get_reach_probability(self._policy)
			self._reach_weights[str(h)] = reach_prob
			norm_factor += reach_prob
		# NOTE(review): if every reach probability is zero (or there are no
		# histories with positive weight) this division cannot produce a valid
		# distribution; with plain Python floats it raises ZeroDivisionError.
		for key in self._reach_weights:
			self._reach_weights[key] /= norm_factor

	def _sampling_probabilities(self):
		"""Return the normalised weights as a list aligned with ``self._histories``."""
		return [self._reach_weights[str(h)] for h in self._histories]

	def get_reach_probability(self, history):
		"""Return the normalised reach weight of ``history``.

		Histories that were not enumerated for this public state have weight 0.
		"""
		# A missing dict key raises KeyError (not ValueError), so use .get()
		# to make the zero fallback actually reachable.
		return self._reach_weights.get(str(history), 0.)

	def get_reach_distribution(self):
		"""Return the internal dict mapping ``str(history)`` to its weight.

		The dict is returned without copying, so mutating it mutates this object.
		"""
		return self._reach_weights

	def sample_history_uniformly(self):
		"""Return one history drawn uniformly at random (ignores reach weights)."""
		return np.random.choice(self._histories)

	def mc_estimate(self, n, eval_every, player=0):
		"""Monte-Carlo estimate of the state value, sampling by reach weight.

		Draws ``n + 1`` histories (indices ``0..n``) according to the normalised
		reach weights and records the running mean of the sampled values each
		time the sample index is a multiple of ``eval_every`` (index 0 included).

		Args:
			n: last sample index; ``n + 1`` samples are drawn in total.
			eval_every: interval, in samples, between recorded estimates.
			player: index into the result of ``util.expected_value``.

		Returns:
			List of running-mean estimates, in the order they were recorded.
		"""
		p = self._sampling_probabilities()
		sample_values = []
		estimates = []
		for i in range(n+1):
			h = np.random.choice(self._histories, p=p)
			sample_values.append(util.expected_value(h, self._policy)[player])
			if i % eval_every == 0:
				estimates.append(np.mean(sample_values))
		return estimates

	def generate_samples(self, num_samples, player=0):
		"""Return ``num_samples`` history values sampled by reach weight.

		Args:
			num_samples: number of histories to draw (with replacement).
			player: index into the result of ``util.expected_value``.

		Returns:
			List with one sampled value per drawn history.
		"""
		p = self._sampling_probabilities()
		sample_histories = np.random.choice(self._histories, num_samples, p=p)
		return [util.expected_value(h, self._policy)[player] for h in sample_histories]

	def variance(self, player=0):
		"""Return the reach-weighted variance of the history values.

		Args:
			player: index into the result of ``util.expected_value``; defaults
				to 0, which matches the previously hard-coded behaviour.
		"""
		v = 0.
		ev = self.expected_value()[player]
		for h in self._histories:
			deviation = util.expected_value(h, self._policy)[player] - ev
			v += self._reach_weights[str(h)] * deviation**2
		return v

	def entropy(self):
		"""Return the Shannon entropy (in bits) of the reach distribution."""
		H = 0.
		for h in self._histories:
			r = self._reach_weights[str(h)]
			# By convention 0 * log2(0) == 0; evaluating it directly would
			# give 0 * -inf == nan and poison the whole sum.
			if r > 0:
				H += r * np.log2(r)
		return -1. * H

	def cost(self):
		"""Return the number of enumerated histories."""
		return len(self._histories)

	def expected_value(self):
		"""Return the reach-weighted sum of ``util.expected_value`` over all histories.

		The result has whatever shape ``util.expected_value`` returns (it is
		indexed by player elsewhere in this class); with no histories it is
		the float ``0.``.
		"""
		v = 0.
		for h in self._histories:
			v += self._reach_weights[str(h)] * util.expected_value(h, self._policy)
		return v



