r"""Define a tree env like so.

            1
           / \
          2   3
         / \ / \
        4  5 6  7

where leaves are terminal and the binary tree is full.

Node IDs start at 1 so that the children of node x are simply 2*x and 2*x+1.
Starting from 0 would make the child arithmetic more awkward.

A tree is represented by a numpy array, where the index is the state ID, and the
value in the array is the reward (which is always 0 or 1).

Initial state distribution = uniform over any non-terminal state.
"""

# pylint: disable=redefined-outer-name
# pylint: disable=invalid-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats


# Number of actions available in each non-terminal state. next_state() moves
# from state s to child 2*s + a, so action 0 is the left child and action 1
# the right child.
NUM_ACTIONS = 2
# Controls how stochastic the tree is.
# With this probability, the action taken is flipped.
# 0.5 for full randomness.
# NOTE: in the paper, epsilon = 2 * STOCHASTICITY, because the paper defines
# it as probability of random action. Here, it's the probability of flipping
# the action. Defining it this way makes the coding easier.
# The __main__ block rebinds this module-level value for each experiment.
STOCHASTICITY = 0.2


def next_state(s, a, tree):
  """Samples the child state reached by taking action `a` in state `s`.

  With probability STOCHASTICITY the action is flipped (0 <-> 1) before it is
  applied.

  Args:
    s: ID of the current state; its children are 2*s and 2*s + 1.
    a: Intended action, 0 or 1.
    tree: Reward array of the tree. Unused here; kept so that existing callers
      which pass it keep working.

  Returns:
    ID of the next state: 2*s + a', where a' is the possibly flipped action.
  """
  del tree  # Unused: the transition depends only on s and a.
  if np.random.rand() < STOCHASTICITY:
    # Flip action.
    a = 1 - a
  return 2 * s + a


def terminal(s, tree):
  """Returns True iff state `s` is a leaf of `tree`.

  A state is a leaf exactly when the ID of its left child, 2*s, is not a
  valid index into the tree array.
  """
  return len(tree) <= 2 * s


def eval_policy(q_function, s, tree):
  """Rolls out the argmax policy of `q_function` once, starting from `s`.

  Args:
    q_function: Array indexed by state ID; q_function[state] holds the
      Q-values of the actions in that state.
    s: ID of the start state.
    tree: Reward array of the tree, indexed by state ID.

  Returns:
    The reward stored at the leaf where the rollout ends.
  """
  state = s
  while True:
    if terminal(state, tree):
      return tree[state]
    greedy_action = np.argmax(q_function[state])
    state = next_state(state, greedy_action, tree)


def make_tree(levels, num_bad=1):
  """Builds a full binary tree whose leaves carry 0/1 rewards.

  The tree is stored as an array of length 2**levels indexed by state ID
  (index 0 is unused, the root is 1). For levels=3 the IDs are laid out as

           1
          / \\
         2   3
        / \\ / \\
       4  5 6  7

  Args:
    levels: Number of levels of the tree, counting the root.
    num_bad: Number of leaves that get reward 0. They are drawn uniformly at
      random without replacement, so the bad leaves are spread over the tree
      rather than piled into one subtree. All other leaves get reward 1.

  Returns:
    Numpy array of rewards; non-leaf entries are 0.
  """
  first_leaf = 2 ** (levels - 1)
  num_nodes = 2 ** levels
  leaf_ids = range(first_leaf, num_nodes)
  zero_reward_leaves = set(
      np.random.choice(leaf_ids, size=num_bad, replace=False))
  rewards = np.zeros(num_nodes)
  for leaf in leaf_ids:
    rewards[leaf] = 0.0 if leaf in zero_reward_leaves else 1.0
  return rewards


def get_return(q_function, tree):
  """Estimates the return of the argmax policy of a Q-function.

  Runs exactly one rollout from every non-terminal start state and averages
  the resulting leaf rewards. A single rollout per state is a noisy estimate
  in stochastic environments; use get_exact_return for the true value.

  Args:
    q_function: Array indexed by state ID holding per-action Q-values.
    tree: Reward array of the tree, indexed by state ID.

  Returns:
    Average leaf reward over the rollouts, as a float.
  """
  num_starts = 0
  reward_sum = 0
  for start in range(1, len(tree)):
    if terminal(start, tree):
      # Leaves occupy the tail of the ID range, so no start states remain.
      break
    reward_sum += eval_policy(q_function, start, tree)
    num_starts += 1
  return float(reward_sum) / num_starts


def get_exact_return(q_function, tree):
  """Computes the exact expected return of the argmax policy via DP.

  The start state is uniform over the non-terminal states. In each state the
  greedy action is executed with probability 1 - STOCHASTICITY and the other
  action with probability STOCHASTICITY.

  Args:
    q_function: Array indexed by state ID holding per-action Q-values.
    tree: Reward array of the tree, indexed by state ID.

  Returns:
    Expected return averaged over all non-terminal start states.
  """
  num_states = len(tree)
  flip_prob = STOCHASTICITY
  keep_prob = 1 - STOCHASTICITY

  # value[s] = expected return when starting from state s.
  value = [0] * num_states
  # Base case: a leaf is worth its own reward.
  for leaf in range(1, num_states):
    if terminal(leaf, tree):
      value[leaf] = tree[leaf]

  # Children have larger IDs than their parent, so sweeping from the highest
  # ID down guarantees both children are solved before the parent.
  for s in reversed(range(1, num_states)):
    if not terminal(s, tree):
      greedy = np.argmax(q_function[s])
      chosen_child = 2 * s + greedy
      other_child = 2 * s + (1 - greedy)
      value[s] = keep_prob * value[chosen_child] + flip_prob * value[other_child]

  # Average over all non-terminal states.
  start_values = [
      value[s] for s in range(1, num_states) if not terminal(s, tree)]
  return sum(start_values) / len(start_values)


def test(q_function, tree):
  """Prints the exact return next to 10-, 100- and 1000-rollout estimates.

  Debugging aid to compare get_exact_return against averages of repeated
  get_return calls.
  """
  exact_return = get_exact_return(q_function, tree)
  estimates = []
  for _ in range(1000):
    estimates.append(get_return(q_function, tree))
  print(exact_return,
        sum(estimates[:10]) / 10.0,
        sum(estimates[:100]) / 100.0,
        sum(estimates) / 1000.0)


def get_dataset(tree, NUM_EPS=1000):
  """Generate dataset of NUM_EPS episodes from uniform behavior policy.

  Args:
    tree: Reward array of the tree, indexed by state ID.
    NUM_EPS: Number of episodes to generate.

  Returns:
    A tuple (sequences, positive_sequences, rewards):
      sequences: one list of (state, action) pairs per episode.
      positive_sequences: the episodes that ended in a leaf with reward 1.0.
      rewards: the leaf reward of each episode, aligned with `sequences`.
  """
  sequences = []
  positive_sequences = []
  rewards = []
  # Cycle through the initial states in order, instead of sampling them, to
  # reduce variance in the generated episodes.
  nonterminal = [s for s in range(1, len(tree)) if not terminal(s, tree)]

  # `range` rather than `xrange`: the latter does not exist on Python 3.
  for i in range(NUM_EPS):
    s = nonterminal[i % len(nonterminal)]
    seq = []
    while not terminal(s, tree):
      a = np.random.randint(NUM_ACTIONS)  # 0 or 1
      seq.append((s, a))
      s = next_state(s, a, tree)
    sequences.append(seq)
    rewards.append(tree[s])
    if tree[s] == 1.0:
      positive_sequences.append(seq)
  return sequences, positive_sequences, rewards


def opc(positives, all_values, prior=1.0):
  """Computes OPC score given list of positive Q(s,a) and all Q(s,a).

  The score is the minimum over thresholds b of

    E_all[Q > b] - prior * E_pos[Q > b]

  It is not normalized to lie between 0 and 1.

  Args:
    positives: Q-values of the transitions from positive episodes. These
      values are expected to appear in `all_values` as well.
    all_values: Q-values of all transitions.
    prior: Weight applied to the positive term.

  Returns:
    The minimal score found while sweeping the threshold over all distinct
    Q-values (including a threshold below every value).
  """
  num_pos = len(positives)
  num_all = len(all_values)

  # Group the per-point contributions by Q-value so that repeated values are
  # handled in one step of the sweep. This matters for the tree env, which
  # only has finitely many (s, a) pairs. A positive point contributes
  # -prior / num_pos, and every point contributes +1 / num_all.
  contributions = collections.defaultdict(list)
  for q in positives:
    contributions[q].append(-prior / num_pos)
  for q in all_values:
    contributions[q].append(1.0 / num_all)

  # With b below every point, all indicators are 1, so the score starts at
  # 1 - prior. Raising b past a Q-value removes that value's contribution.
  running = 1.0 - prior
  lowest = running
  for q in sorted(contributions):
    running -= sum(contributions[q])
    lowest = min(running, lowest)
  return lowest


def opc_score(dataset, q_function, prior=1.0):
  """Computes OPC score of Q-function over dataset.

  Args:
    dataset: Tuple (all episodes, positive episodes, rewards); each episode
      is a sequence of (state, action) pairs. The rewards are not used.
    q_function: Array indexed as q_function[state, action].
    prior: Weight of the positive term, forwarded to opc().

  Returns:
    The OPC score over the Q-values of every transition in the dataset.
  """
  episodes, positive_episodes, _ = dataset
  q_all = [q_function[s, a] for episode in episodes for s, a in episode]
  q_positive = [
      q_function[s, a] for episode in positive_episodes for s, a in episode]
  return opc(q_positive, q_all, prior)


def sample_qfunction(levels):
  """Generates random Q-function from U[0,1].

  Args:
    levels: Number of levels of the tree the Q-function is for.

  Returns:
    Array of shape (2**levels, 2): one row per state ID, one column per
    action.
  """
  num_rows = 2 ** levels
  table_shape = (num_rows, 2)
  return np.random.uniform(low=0.0, high=1.0, size=table_shape)


def soft_opc(dataset, q_function, prior=1.0):
  """Computes SoftOPC score.

  The score is prior * (mean over positive episodes of the episode-average
  Q-value) minus (mean over all episodes of the episode-average Q-value).

  Args:
    dataset: Tuple (all episodes, positive episodes, rewards); each episode
      is a sequence of (state, action) pairs. The rewards are not used.
    q_function: Array indexed as q_function[state, action].
    prior: Weight applied to the positive-episode term.

  Returns:
    The SoftOPC score.
  """
  episodes, positive_episodes, _ = dataset

  def episode_q_means(episode_list):
    """Returns the average Q(s, a) of each episode in `episode_list`."""
    means = []
    for episode in episode_list:
      running = 0
      for s, a in episode:
        running += q_function[s, a]
      means.append(running / float(len(episode)))
    return means

  all_means = episode_q_means(episodes)
  positive_means = episode_q_means(positive_episodes)

  positive_term = np.mean(positive_means)
  overall_term = np.mean(all_means)
  return prior * positive_term - overall_term


# This implementation is just for undiscounted, 0/1 reward case.
# Bellman, disc sum adv, MCC can be implemented in other cases with other code.


def bellman(dataset, q_function):
  """Average TD Error.

  Only valid for the undiscounted case where the sole reward (0 or 1) arrives
  at the end of the episode.

  Args:
    dataset: Tuple (all episodes, positive episodes, rewards); each episode
      is a sequence of (state, action) pairs and rewards[i] is the final
      reward of episode i. The positive episodes are not used.
    q_function: Array indexed as q_function[state, action].

  Returns:
    Mean over episodes of the per-episode mean squared TD error.
  """
  episodes, _, rewards = dataset
  per_episode_error = []
  for episode, final_reward in zip(episodes, rewards):
    # Q(s0,a0), Q(s1,a1), Q(s2,a2), ...
    qs = [q_function[s, a] for s, a in episode]
    # Intermediate steps have reward 0, so the target is the next Q-value.
    squared = [(current - following) ** 2
               for current, following in zip(qs, qs[1:])]
    # The last step is regressed onto the final reward (0 or 1).
    squared.append((qs[-1] - final_reward) ** 2)
    per_episode_error.append(np.mean(squared))
  return np.mean(per_episode_error)


def advantages(seq, q_function):
  """Helper function for computing advantage.

  Args:
    seq: Sequence of (state, action) pairs.
    q_function: Array indexed as q_function[state, action].

  Returns:
    Numpy array with one entry per pair: Q(s, a) - max_a' Q(s, a').
  """
  taken_q = []
  state_value = []
  for s, a in seq:
    taken_q.append(q_function[s, a])
    # V(s) is the value of the best action in s.
    state_value.append(q_function[s].max())
  return np.array(taken_q) - np.array(state_value)


def disc_sum_adv(dataset, q_function):
  """Discounted sum of advantages, discount=1.

  Args:
    dataset: Tuple (all episodes, positive episodes, rewards); only the
      episodes, sequences of (state, action) pairs, are used.
    q_function: Array indexed as q_function[state, action].

  Returns:
    Mean of the advantage sums taken from every time step of every episode
    to the end of that episode.
  """
  episodes, _, _ = dataset
  suffix_sums = []
  for episode in episodes:
    adv = advantages(episode, q_function)
    # Summing from every start time is equivalent to initializing in a
    # random start state and gives a better estimate.
    suffix_sums.extend(adv[start:].sum() for start in range(len(adv)))
  return np.mean(suffix_sums)


def mcc(dataset, q_function):
  """Monte Carlo corrected error.

  Args:
    dataset: Tuple (all episodes, positive episodes, rewards); each episode
      is a sequence of (state, action) pairs and rewards[i] is the final
      reward of episode i. The positive episodes are not used.
    q_function: Array indexed as q_function[state, action].

  Returns:
    Mean squared difference between Q(s_t, a_t) and the corrected target
    (final reward minus the sum of advantages after step t), taken over
    every step of every episode.
  """
  episodes, _, rewards = dataset
  squared_errors = []
  for episode, final_reward in zip(episodes, rewards):
    adv = advantages(episode, q_function)
    q_taken = np.array([q_function[s, a] for s, a in episode])
    for t, q_t in enumerate(q_taken):
      # Target for Q_t: (0/1 reward) - sum of advantages from t+1 onwards.
      corrected_target = final_reward - adv[t + 1:].sum()
      squared_errors.append((corrected_target - q_t) ** 2)
  return np.mean(squared_errors)


def correlations(returns, metrics):
  """Correlates metric values with true returns.

  Args:
    returns: Sequence of returns, one per Q-function.
    metrics: Sequence of metric values aligned with `returns`.

  Returns:
    Tuple (R^2, rho): the squared Pearson correlation coefficient (R^2 is
    reported rather than r) and the Spearman rank correlation.
  """
  pearson_r = scipy.stats.pearsonr(returns, metrics)[0]
  r_squared = pearson_r ** 2
  spearman_result = scipy.stats.spearmanr(returns, metrics)
  return r_squared, spearman_result.correlation


def ranks(returns, metrics, NUM_FUNCS):
  """Reports fraction of pairs with incorrect rank, unused.

  A pair (i, j) is only considered when the two returns differ. It counts as
  ranked correctly when the metric orders the pair strictly the same way as
  the return does. A tie in the metric therefore counts as an error, since
  the metric fails to separate two Q-functions with different returns.

  Args:
    returns: Sequence of returns, one per Q-function.
    metrics: Sequence of metric values aligned with `returns`.
    NUM_FUNCS: Number of leading entries of `returns`/`metrics` to compare.

  Returns:
    Fraction of considered pairs that are ranked incorrectly, or 0.0 when no
    pair has distinct returns.
  """
  tot = 0
  errors = 0
  for i in range(NUM_FUNCS - 1):
    for j in range(i + 1, NUM_FUNCS):
      # if equal, don't care
      if returns[i] == returns[j]:
        continue
      tot += 1
      # Check both directions; testing only metrics[i] < metrics[j] would miss
      # every misranked pair where metrics[i] > metrics[j].
      if returns[i] < returns[j]:
        correct = metrics[i] < metrics[j]
      else:
        correct = metrics[i] > metrics[j]
      if not correct:
        errors += 1
  if tot == 0:
    # No pair with distinct returns: nothing can be misranked.
    return 0.0
  return float(errors) / tot


# Experiment script: samples random Q-functions on one fixed tree and, for
# several flip probabilities, prints how well each off-policy metric correlates
# with the exact return of the Q-function's argmax policy.
if __name__ == '__main__':
  # Ensure repeatable results.
  np.random.seed(123)
  LEVELS = 6
  #          1
  #         / \
  #        2   3
  #       / \ / \
  #      4  5 6  7
  #     /\ /\ /\ /\
  #    89101112131415
  #
  ## 1 success leaf.
  NUM_BAD = 2 ** (LEVELS - 1) - 1
  ## 1 failure leaf.
  # NUM_BAD = 1

  NUM_FUNCS = 1000
  tree = make_tree(levels=LEVELS, num_bad=NUM_BAD)
  ## Default plot.
  # `range` rather than `xrange`: the latter does not exist on Python 3.
  q_functions = [sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]

  ## Q-functions of different magnitude. Use with 1 success leaf.
  # q_functions = []
  # for i in range(1, NUM_FUNCS+1):
  #   q_functions.append(i * sample_qfunction(LEVELS))

  ## Big Q-functions. Use with 1 success leaf.
  # q_functions = [1000 * sample_qfunction(LEVELS) for _ in range(NUM_FUNCS)]

  for attempt in [0, 0.1, 0.2, 0.3, 0.4]:
    # Rebinds the module-level flip probability read by next_state and
    # get_exact_return.
    STOCHASTICITY = attempt
    dataset = get_dataset(tree, NUM_EPS=1000)
    returns = [get_exact_return(q_func, tree) for q_func in q_functions]
    print('Return of 1st Q-func for debugging', returns[0])

    # Baselines
    bellman_score = [bellman(dataset, q_func) for q_func in q_functions]
    disc_sum_adv_score = [disc_sum_adv(dataset, q_func) for q_func in q_functions]
    mcc_score = [mcc(dataset, q_func) for q_func in q_functions]
    print('Flip probability', STOCHASTICITY)
    print('Bad states', NUM_BAD)
    print('Bellman', correlations(returns, bellman_score))
    print('disc_sum_adv', correlations(returns, disc_sum_adv_score))
    print('mcc', correlations(returns, mcc_score))

    # New ones
    #priors = np.arange(1+20) / 20.0
    priors = [1.0]
    soft_opc_rank_scores = []
    opc_rank_scores = []
    for prior in priors:
      soft_opcs = [soft_opc(dataset, q_func, prior) for q_func in q_functions]
      # min opc -> max -opc, easier for ranking
      opcs = [-opc_score(dataset, q_func, prior) for q_func in q_functions]
      # Pearson, Spearman
      soft_opc_r2, soft_opc_corr = correlations(returns, soft_opcs)
      opc_r2, opc_corr = correlations(returns, opcs)
      print('Prior %f, SoftOPC correlations' % prior, soft_opc_r2, soft_opc_corr)
      print('Prior %f, OPC correlations' % prior, opc_r2, opc_corr)
      soft_opc_rank_scores.append(soft_opc_corr)
      opc_rank_scores.append(opc_corr)
      #plt.show()
  # The string literal below is disabled plotting code; as a bare string
  # expression it is never executed.
  """
  plt.plot(priors, soft_opc_rank_scores)
  plt.plot(priors, opc_rank_scores)
  plt.rc('text', usetex=True)
  # If U[0,k], best possible total max is 1000 * 1001 / 2 = about 500 * 1000.
  # If U[0,1000], best possible total max is 1000 * 1000.
  # We expect max_{s,a} Q(s,a) > 500 so this threshold is enough to figure out
  # correct plot title.
  if sum(q_func.max() for q_func in q_functions) > (1000 * 1001 / 2):
    plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,1000]' % (
        LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
  elif max(q_func.max() for q_func in q_functions) > 1:
    plt.title('Binary Tree, %d Levels, %d Success State\nQ(s,a) ~ U[0,k]' % (
        LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
  elif NUM_BAD == 1:
    plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail State' % (STOCHASTICITY, LEVELS, NUM_BAD))
  elif NUM_BAD == 2 ** (LEVELS - 1) - 1:
    plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Success State' % (STOCHASTICITY, LEVELS, 2 ** (LEVELS - 1) - NUM_BAD))
  else:
    plt.title('Binary Tree, Flip Prob %f, %d Levels, %d Fail States' % (STOCHASTICITY, LEVELS, NUM_BAD))
  plt.xlabel('Prior $p(y=1)$')
  plt.ylabel('Spearman Correlation To Episode Return')
  plt.legend(['SoftOPC', 'OPC'])
  bs = correlations(returns, bellman_score)[1]
  plt.axhline(y=bs, linestyle='--', color='gray')
  plt.text(x=0.75, y=bs + 0.007, s='TD Error')
  ds = correlations(returns, disc_sum_adv_score)[1]
  plt.axhline(y=ds, linestyle='--', color='gray')
  plt.text(x=0.75, y=ds + 0.007, s='Sum Advantages')
  ms = correlations(returns, mcc_score)[1]
  plt.axhline(y=ms, linestyle='--', color='gray')
  plt.text(x=0.75, y=ms + 0.007, s='MCC Error')
  plt.show()
  """
