import numpy as np
import gc
import scipy.stats
import copy
from absl import app, flags

from games.oh_hell import OhHellGame
from algorithms.enumerate import StateEnumerator
from algorithms.gibbs import GibbsSampler
from algorithms.util import RandomJointPolicy

FLAGS = flags.FLAGS

# Experiment configuration. Counts are validated at flag-parse time so that
# nonsensical values (e.g. negative tricks, a bias outside [0, 1]) are rejected
# with a clear message instead of failing deep inside the game code.
flags.DEFINE_integer('num_tricks', 2, 'number of tricks in the game (cards per player)', lower_bound=1)
flags.DEFINE_integer('num_tricks_played', 1, 'number of tricks played before the experiment', lower_bound=0)
flags.DEFINE_integer('cards_in_current_trick', 0, 'number of cards played in current trick before the experiment', lower_bound=0)
flags.DEFINE_integer('num_suits', 2, 'number of suits in the game', lower_bound=1)
flags.DEFINE_integer('num_ranks', 4, 'number of cards per suit', lower_bound=1)
flags.DEFINE_integer('num_repeats', 1, 'number of times to repeat the experiment', lower_bound=0)
# policy_bias is described as a probability, so it must lie in [0, 1].
flags.DEFINE_float('policy_bias', 0.8, 'probability of the most favored action in each infostate for the joint policy', lower_bound=0.0, upper_bound=1.0)
flags.DEFINE_boolean('header_only', False, 'print the tsv header and exit')
		

def entropy_and_variance(state):
	"""Play `state` forward at random and measure it at the configured point.

	The first move (made by the player to move at the initial state, the chance
	player) sets the number of tricks to FLAGS.num_tricks. The game is then
	advanced one action at a time:
	- non-negative player ids sample from a RandomJointPolicy built with
	  FLAGS.policy_bias;
	- other ids get a uniformly random legal action.

	The measurement point is the first non-negative-player decision where both
	of these hold:
	- state.num_tricks_played() >= FLAGS.num_tricks_played;
	- state.num_cards_played_current_trick() >= FLAGS.cards_in_current_trick.

	At that point a StateEnumerator is built over the state and the same policy.

	Args:
		state: a freshly created game state; it is mutated in place.

	Returns:
		A tuple (entropy, variance) as reported by the StateEnumerator.

	Raises:
		ValueError: if the game reaches a terminal state before the measurement
			point, e.g. when the flags request more tricks or cards than exist.
	"""
	policy = RandomJointPolicy(FLAGS.policy_bias)
	# set the number of tricks
	chance_player = state.get_player_to_move()
	state.play(chance_player, FLAGS.num_tricks)
	while not state.terminal():
		player = state.get_player_to_move()
		actions = state.get_legal_actions(player)
		if player >= 0:
			# The policy is queried (and an action sampled) before the check on
			# purpose: this keeps random-number consumption and any lazily built
			# policy entries identical to the original order of operations.
			probs = policy.get_action_probabilities(state.get_infostate_string(player), len(actions))
			action = np.random.choice(actions, p=probs)
			if (state.num_tricks_played() >= FLAGS.num_tricks_played
					and state.num_cards_played_current_trick() >= FLAGS.cards_in_current_trick):
				enumerator = StateEnumerator(state, policy)
				return enumerator.entropy(), enumerator.variance()
		else:
			# Non-player (presumably chance) nodes: pick a legal action uniformly.
			action = np.random.choice(actions)
		state.play(player, action)
	# Falling off the loop used to return None, which surfaced in the caller as
	# an opaque tuple-unpacking error; report the real cause instead.
	raise ValueError(
		f'game ended before reaching {FLAGS.num_tricks_played} tricks played and '
		f'{FLAGS.cards_in_current_trick} cards in the current trick '
		f'(num_tricks={FLAGS.num_tricks})')


def main(_):
	"""Run the experiment and print one tab-separated row per repeat.

	With --header_only, print just the TSV header and return. This allows rows
	from several invocations to be concatenated under a single header.
	"""
	# Header and row fields are listed side by side so they cannot drift apart.
	# cards_in_current_trick is appended last so existing column positions
	# (e.g. entropy at index 5) stay unchanged for positional readers; before,
	# it was omitted entirely and sweeps over it were indistinguishable.
	header = ('num_suits', 'num_ranks', 'num_tricks', 'num_tricks_played',
			  'policy_bias', 'entropy', 'variance', 'cards_in_current_trick')
	if FLAGS.header_only:
		print('\t'.join(header))
		return
	game = OhHellGame(num_suits=FLAGS.num_suits, num_ranks=FLAGS.num_ranks)
	for _ in range(FLAGS.num_repeats):
		state = game.new_initial_state()
		entropy, variance = entropy_and_variance(state)
		row = (FLAGS.num_suits, FLAGS.num_ranks, FLAGS.num_tricks, FLAGS.num_tricks_played,
			   FLAGS.policy_bias, entropy, variance, FLAGS.cards_in_current_trick)
		# f'{v}' matches the original f-string formatting of each field.
		print('\t'.join(f'{v}' for v in row))

# Script entry point: run main through absl's app.run (main reads FLAGS, so it
# relies on app.run for command-line flag handling).
if __name__ == "__main__":
	app.run(main)
