# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run off-policy policy evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

import os
import pickle 
import numpy as np

import gridworld.environments as gridworld_envs
import gridworld.policies as gridworld_policies
import transition_data as transition_data
from run_utils import * 


# Command-line flags for the experiment; every flag is read once in main().
FLAGS = flags.FLAGS

flags.DEFINE_integer('seed', 0, 'Initial NumPy random seed.')
flags.DEFINE_integer('num_seeds', 100, 'How many seeds to run.')
flags.DEFINE_integer('num_trajectories', 200,
                     'Number of trajectories to collect.')
flags.DEFINE_integer('max_trajectory_length', 100,
                     'Cutoff trajectory at this step.')
# Help text fixed: it used to repeat the max_trajectory_length description.
flags.DEFINE_integer('grid_length', 10,
                     'Grid length passed to the GridWalk environment.')

flags.DEFINE_float('gamma', 0.99, 'Discount factor.')
flags.DEFINE_bool('tabular_obs', True, 'Use tabular observations?')
flags.DEFINE_bool('tabular_solver', True, 'Use tabular solver?')
# Help text fixed: it used to repeat the tabular_solver description.
flags.DEFINE_bool('use_aggregation', True,
                  'Use aggregated features instead of identity features?')
flags.DEFINE_string('env_name', 'grid', 'Environment to evaluate on.')
flags.DEFINE_string('save_dir', None, 'Directory to save results to.')

flags.DEFINE_bool('deterministic_env', False, 'assume deterministic env.')
flags.DEFINE_bool('state_action', False, 'learns SA')

flags.DEFINE_string('name', '', 'name of run')
flags.DEFINE_bool('separate_mu', True, 'Separate mu.')
flags.DEFINE_integer('start', 50,
                     'to sample start states from')
flags.DEFINE_float('model_quality', -1,
                   'model quality in [0, 1], unless -1 for no model')
flags.DEFINE_float('policy_p', 0.1, 'behavioral policy parameter')
flags.DEFINE_float('policy_q', 0.4, 'behavioral policy parameter')
# Help text fixed: it used to say 'Discount factor.'. main() only handles the
# values 'full' and 'nu'.
flags.DEFINE_string('model_type', None,
                    "Model mask type, 'full' or 'nu'; only used when "
                    'model_quality > -1.')



      

def _evaluate_q(q, mu_data, mu_pi, target_policy, gamma, target_rewards,
                q_pi, nu_list):
  """Collects the error metrics recorded for one Q-function estimate.

  Args:
    q: Flattened Q-value estimate to score.
    mu_data: Sampled initial states handed to `estimate_value`.
    mu_pi: Initial distribution used for the population value estimate.
    target_policy: Policy being evaluated.
    gamma: Discount factor.
    target_rewards: Reference value the OPE estimates are compared against.
    q_pi: Reference Q-values used for the L2 errors.
    nu_list: Distributions under which the L2 error of `q` is measured.

  Returns:
    A list: the OPE error of the sample-based estimate, the OPE error of the
    population estimate, then one `compute_l2_error` value per entry of
    `nu_list`, in order.
  """
  q_ope = estimate_value(mu_data, target_policy, gamma, q)
  pop_q_ope = (1 - gamma) * np.dot(mu_pi, q)
  result = [
      compute_ope_error(q_ope, target_rewards),
      compute_ope_error(pop_q_ope, target_rewards),
  ]
  result.extend(compute_l2_error(q, q_pi, dist) for dist in nu_list)
  return result


def main(argv):
  """Runs the off-policy evaluation experiment and optionally pickles results.

  Builds a GridWalk environment with a target and a behavior policy, computes
  reference solutions with the run_utils helpers, and then, for each seed,
  collects behavior data, assembles the empirical linear system (lhs, rhs)
  and solves it with `cp_optimize` once per weighting distribution. Error
  metrics are accumulated per distribution label; when --save_dir is set the
  metrics and the fitted parameters are written as two pickle files.

  Args:
    argv: Positional command-line arguments from absl; unused.

  Raises:
    ValueError: If the tabular solver is requested without tabular
      observations, or if a model is requested (model_quality > -1) with a
      model_type other than 'full' or 'nu'.
  """
  del argv
  start_seed = FLAGS.seed
  num_seeds = FLAGS.num_seeds
  num_trajectories = FLAGS.num_trajectories
  max_trajectory_length = FLAGS.max_trajectory_length
  gamma = FLAGS.gamma
  tabular_obs = FLAGS.tabular_obs
  tabular_solver = FLAGS.tabular_solver
  if tabular_solver and not tabular_obs:
    raise ValueError('Tabular solver can only be used with tabular obs.')
  state_action = FLAGS.state_action
  env_name = FLAGS.env_name
  save_dir = FLAGS.save_dir
  grid_length = FLAGS.grid_length
  use_aggregation = FLAGS.use_aggregation
  name = FLAGS.name
  start = FLAGS.start
  policy_p = FLAGS.policy_p  # probability of moving y first
  policy_q = FLAGS.policy_q  # probability of moving x first
  model_quality = FLAGS.model_quality
  model_type = FLAGS.model_type

  use_model = model_quality > -1
  # Fail early with a clear message: previously an unsupported model_type
  # (including the default None) left `mask`/`full_mask` unbound and crashed
  # later with a NameError.
  if use_model and model_type not in ('full', 'nu'):
    raise ValueError(
        "model_type must be 'full' or 'nu' when model_quality > -1, got %r."
        % (model_type,))

  # Create the output directory up front so that a missing or unwritable
  # directory is discovered before the computation rather than after it.
  # An empty string is skipped because it means "current directory".
  if save_dir:
    os.makedirs(save_dir, exist_ok=True)

  hparam_format = (
      '{SOLVER}_{ENV}_p{PI_P}_q{PI_Q}_N{NUM_TRAJ}_H{TRAJ_LEN}_gam{GAM}'
      '_start{START}_model_{MODEL_TYPE}_{MODEL_QUALITY}_seed{SEED}_{NAME}')
  hparam_str = hparam_format.format(
      ENV=env_name + tabular_obs * '-tab',
      PI_P=policy_p,
      PI_Q=policy_q,
      NUM_TRAJ=num_trajectories,
      TRAJ_LEN=max_trajectory_length,
      GAM=gamma,
      SOLVER='CF',
      SA=state_action,  # Not referenced by the template; ignored by format().
      START=start,
      NAME=name,
      MODEL_TYPE=model_type,
      MODEL_QUALITY=model_quality,
      SEED=start_seed,
  )

  env = gridworld_envs.GridWalk(grid_length, tabular_obs, start=start)
  target_policy = gridworld_policies.get_optimal_policy(env, tabular_obs)
  behavior_policy = gridworld_policies.get_other_policy(
      env, policy_p, policy_q, tabular_obs)

  # NOTE(review): several results below are never used. The calls are kept
  # because the helpers are defined elsewhere and their side effects, if any,
  # are not visible from this file.
  mu_pi, _, _ = env._get_matrices(target_policy)
  env._get_matrices(behavior_policy)

  d_D, _, _ = finite_get_true_solns(
      behavior_policy, env, 1.0, max_trajectory_length)
  d_pi, q_pi, target_rewards = infinite_get_true_solns(
      target_policy, env, gamma)
  w_pi = zeros_divide(d_pi, d_D)

  mu_s = get_d_s(mu_pi, grid_length)
  get_d_s(d_D, grid_length)
  get_d_s(d_pi, grid_length)

  # Q-features: one column per distinct (rounded) true Q-value when
  # aggregating, otherwise the identity.
  indices = np.unique(np.around(q_pi, 3))
  q_phi = generate_mapping(q_pi, indices)
  q_phi = q_phi if use_aggregation else np.eye(len(w_pi))

  # Entries whose density ratio exceeds 50.
  # NOTE(review): if no entry qualifies, np.sum(nu) is 0 and `nu` (and
  # `q_pi_mean` below) become NaN.
  high_ratio = w_pi > 50
  nu = high_ratio * d_pi
  nu = nu / np.sum(nu)
  uniform = np.ones_like(nu) / len(nu)

  if use_model:
    if model_type == 'full':
      mask = np.ones_like(indices)
      full_mask = np.ones_like(w_pi)
    else:  # model_type == 'nu', guaranteed by the validation above.
      mask = ((high_ratio @ q_phi) > 0).astype(float)
      full_mask = high_ratio.astype(float)
    q_pi_mean = np.mean(q_pi[high_ratio])
    # Interpolate between the true aggregated values and their mean over the
    # high-ratio entries.
    model = (model_quality * indices + (1 - model_quality) * q_pi_mean) * mask
  else:
    model = None

  nu_list = [d_D, mu_pi, uniform, nu]
  # The first two labels belong to the empirical distributions that are
  # prepended to nu_list inside the seed loop; 'zero' is appended there too.
  nu_labels = ['emp_D', 'emp_mu', 'd_D', 'mu_pi', 'uniform', 'nu']
  for epsilon in np.linspace(0, 1, 3):
    policy = gridworld_policies.get_other_policy(
        env, policy_p, policy_q, tabular_obs, epsilon=epsilon)
    d_policy, _, _ = infinite_get_true_solns(policy, env, gamma)
    nu_list.append(d_policy)
    nu_labels.append(str(epsilon))
  nu_labels.append('zero')

  # Test-function features: one column per get_wf result.
  g_list = []
  for dist in nu_list:
    g_list.append(
        get_wf(q_pi, d_D, target_policy, env, gamma, dist, model=0.))

  for dist in nu_list[:5]:
    if use_model:
      g_list.append(
          get_wf(full_mask, d_D, target_policy, env, gamma, dist, model=0.))
      g_list.append(
          get_wf(full_mask * q_pi, d_D, target_policy, env, gamma, dist,
                 model=0.))
    else:
      g_list.append(
          get_wf(q_pi, np.ones_like(d_D), target_policy, env, gamma, dist,
                 model=0.))

  g_phi = ((1 - gamma) * np.array(g_list).transpose()
           if use_aggregation else np.eye(len(w_pi)))

  d = q_phi.shape[1]
  k = g_phi.shape[1]
  n = num_trajectories * max_trajectory_length

  results = {label: [] for label in nu_labels}
  q_dict = {label: [] for label in nu_labels}
  if use_model:
    results['model'] = []
    q_dict['model'] = [model]

  # Retained from the original although the returned system is unused here.
  true_linear_system(env, target_policy, gamma, d_D, q_phi, g_phi)

  first_seed = start_seed * num_seeds
  for seed in range(first_seed, first_seed + num_seeds):
    np.random.seed(seed)

    (dataset, _, _) = transition_data.collect_data(
        env,
        behavior_policy,
        num_trajectories,
        max_trajectory_length,
        gamma=gamma)
    mu_data = np.random.choice(
        len(mu_s), size=n, replace=True, p=mu_s)

    # Empirical initial-state feature expectation under the target policy.
    emp_mu = np.zeros([d])
    for mu_state in mu_data:
      mu_probs = target_policy.get_probabilities(mu_state)
      for mu_action, mu_prob in enumerate(mu_probs):
        mu_idx = get_index(mu_state, mu_action)
        emp_mu += mu_prob * q_phi[mu_idx]
    emp_mu = emp_mu / n

    # Accumulate the empirical linear system from the sampled transitions.
    lhs = np.zeros([k, d])
    rhs = np.zeros([k])
    emp_dist = np.zeros([d])

    for transition in dataset.iterate_once():
      idx = get_index(transition.state, transition.action)
      g_feat = g_phi[idx]
      q_feat = q_phi[idx]

      emp_dist += q_feat

      # Expected next-step features under the target policy.
      next_probs = target_policy.get_probabilities(transition.next_state)
      next_q_feat = np.zeros_like(q_feat)
      for next_action, next_prob in enumerate(next_probs):
        next_idx = get_index(transition.next_state, next_action)
        next_q_feat += q_phi[next_idx] * next_prob

      rhs += transition.reward * g_feat
      lhs += np.outer(g_feat, q_feat) - gamma * np.outer(g_feat, next_q_feat)

    rhs = rhs / n
    lhs = lhs / n
    emp_dist = emp_dist / n

    print('n= %s, seed=%s' % (num_trajectories, seed))
    weightings = [emp_dist, emp_mu] + list(nu_list) + [np.zeros_like(mu_pi)]
    for dist, label in zip(weightings, nu_labels):
      # Distributions longer than the feature dimension are projected onto
      # the features first.
      if len(dist) > d:
        q_param = cp_optimize(lhs, rhs, np.diag(dist @ q_phi), model)
      else:
        q_param = cp_optimize(lhs, rhs, np.diag(dist), model)
      q = (q_param @ q_phi.transpose()).flatten()
      results[label].append(
          _evaluate_q(q, mu_data, mu_pi, target_policy, gamma,
                      target_rewards, q_pi, nu_list))
      q_dict[label].append(q_param)

    if use_model:
      # Score the model itself as a baseline.
      model_q = (model @ q_phi.transpose()).flatten()
      results['model'].append(
          _evaluate_q(model_q, mu_data, mu_pi, target_policy, gamma,
                      target_rewards, q_pi, nu_list))

  if save_dir is not None:
    filename = os.path.join(save_dir, '%s.pickle' % hparam_str)
    with open(filename, 'wb') as f:
      pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)

    filename = os.path.join(save_dir, '%s_qparams.pickle' % hparam_str)
    with open(filename, 'wb') as f:
      pickle.dump(q_dict, f, protocol=pickle.HIGHEST_PROTOCOL)

  print('Done!')


# Script entry point: hand control to absl, which calls main.
if __name__ == '__main__':
  app.run(main)
