from rsa.algos import MCTD3
import rsa.utils as utils
import rsa.utils.pytorch_utils as ptu
from rsa.utils.arg_parser import parse_args
from rsa.utils.logx import EpochLogger
import rsa.utils.spb_utils as spbu

import numpy as np
from tqdm import trange
import os
import json

if __name__ == '__main__':
    # ---- Setup: parse args, seed, create log dirs, record hyper-parameters ----
    params = parse_args(td3_args=True)

    utils.seed(params['seed'])
    logdir = params['logdir']
    # os.makedirs without exist_ok raises if the directory already exists, so a
    # run never silently writes into an existing log directory (presumably intended).
    os.makedirs(logdir)
    os.makedirs(os.path.join(logdir, 'misc'))
    ptu.setup(params['device'])
    with open(os.path.join(logdir, 'hparams.json'), 'w') as f:
        json.dump(params, f)

    env, test_env = utils.make_env(params)
    is_pointbot_env = params['env'] in utils.pointbot_envs

    logger = EpochLogger(output_dir=logdir, exp_name=params['exper_name'])
    loss_plotter = utils.LossPlotter(os.path.join(logdir, 'loss_plots'))

    td3 = MCTD3(params)

    # ---- Offline data: D4RL dataset, or (optionally regenerated) expert data ----
    if params['env'] in utils.d4rl_envs:
        replay_buffer = utils.load_d4rl_replay_buffer(env, params, add_drtg=True)
    else:
        if params['gen_data']:
            # Roll out an expert policy to (re)create the offline dataset on disk
            # before loading it below.
            expert_policy = utils.make_expert_policy(params, env)
            utils.generate_offline_data(env, expert_policy, params)
        replay_buffer = utils.load_replay_buffer(params, add_drtg=True)

    if params['checkpoint'] is None:
        # No checkpoint supplied: pretrain the agent on the offline replay buffer.
        print('Pretraining Policy')
        plot_dir = os.path.join(logdir, 'pretrain_plots')
        os.makedirs(plot_dir)
        # Periodic diagnostic plots (every 500 updates) are switched off; set this
        # flag to True to re-enable them.
        plot_diagnostics = False
        for step in trange(params['init_iters']):
            update_info = td3.update(replay_buffer)
            loss_plotter.add_data(**update_info)
            if not (plot_diagnostics and step > 0 and step % 500 == 0):
                continue
            buffer_obs = np.array([tr['obs'] for tr in replay_buffer.all_transitions()])
            spbu.plot_Q(td3, env,
                        points=buffer_obs * (180, 150),
                        file=os.path.join(plot_dir, 'q_%d.pdf' % step),
                        skip=2)
            if params['plot_drtg_maxes']:
                spbu.plot_maxes(td3, env,
                                file=os.path.join(plot_dir, 'q_maxes_%d.pdf' % step))
                td3.drtg_buffer = set()
                td3.bellman_buffer = set()
            loss_plotter.plot()
        if params['init_iters'] > 0:
            # Persist the pretrained model and flush the loss curves.
            td3.save(os.path.join(logdir, 'pretrain'))
            loss_plotter.plot()
    else:
        # Resume from a saved model instead of pretraining.
        td3.load(params['checkpoint'])

    # ---- Online phase: counters and environment-specific flags ----
    # i: global env-step counter; n_episodes: finished training episodes;
    # epoch: number of evaluation/logging rounds performed so far.
    i, n_episodes, epoch = 0, 0, 0
    metrics = dict(Timesteps=0)
    # Robosuite tasks: used below to close envs explicitly after each episode.
    robosuite_envs = ('Lift', 'Door', 'NutAssembly', 'TwoArmPegInHole')
    robosuite = params['env'] in robosuite_envs
    total_timesteps = params['total_timesteps']

    # ---- Online phase: interleave env interaction, agent updates and evaluation ----
    while i < total_timesteps:
        # Collect one training trajectory.
        obs, done, t = env.reset(), False, 0
        ep_buf, rets = [], []  # per-step transition dicts and raw (unshifted) rewards
        while not done and t < params['horizon']:
            ################################################################################
            # Every params['eval_freq'] timesteps, run the evaluation loop and output logs #
            ################################################################################
            if i % params['eval_freq'] == 0:

                print('Testing Agent')
                for _ in range(params['num_eval_episodes']):
                    # Evaluation uses its own obs/done names: reusing `obs`/`done` would
                    # overwrite the in-progress training episode's observation with the
                    # test env's final observation.
                    test_obs, test_done, ep_ret, ep_len = test_env.reset(), False, 0, 0
                    while not test_done:
                        # Policy action without the exploration noise used in training.
                        test_act = td3.select_action(test_obs)
                        test_obs, test_rew, test_done, _ = test_env.step(test_act)
                        ep_ret += test_rew
                        ep_len += 1
                    if robosuite:
                        test_env.close()
                    logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)

                # Log info about epoch
                logger.log_tabular('Epoch', epoch)
                logger.log_tabular('TotalEnvInteracts', i)
                logger.log_tabular('TestEpRet')
                logger.log_tabular('TestEpLen', average_only=True)
                if epoch == 0:
                    # Nothing has been stored for these keys yet; log placeholders.
                    logger.log_tabular('AverageTrainEpRet', 0)
                    logger.log_tabular('StdTrainEpRet', 0)
                    logger.log_tabular('TrainEpLen', 0)
                    logger.log_tabular('QRisk1', 0)
                    logger.log_tabular('QRisk2', 0)
                    logger.log_tabular('Q1', 0)
                    logger.log_tabular('Q2', 0)
                else:
                    logger.log_tabular('TrainEpRet')
                    logger.log_tabular('TrainEpLen', average_only=True)
                    logger.log_tabular('QRisk1', average_only=True)
                    logger.log_tabular('QRisk2', average_only=True)
                    logger.log_tabular('Q1', average_only=True)
                    logger.log_tabular('Q2', average_only=True)
                for metric, value in metrics.items():
                    logger.log_tabular(metric, value)
                logger.dump_tabular()

                epoch += 1
                loss_plotter.plot()
                td3.save(os.path.join(logdir, 'models'))

                if is_pointbot_env:
                    spbu.plot_Q(td3, env,
                                os.path.join(logdir, 'misc', 'q_%d.pdf' % n_episodes),
                                skip=2)
                if params['plot_drtg_maxes']:
                    spbu.plot_maxes(td3, env,
                                    os.path.join(logdir, 'misc', 'q_maxes_%d.pdf' % n_episodes))
                    td3.drtg_buffer = set()
                    td3.bellman_buffer = set()

            ########################
            # Begin policy updates #
            ########################

            if i < params['start_timesteps']:
                # Warm-up period: sample actions from the env's action space.
                act = env.action_space.sample()
            else:
                # Policy action plus Gaussian exploration noise, clipped to action bounds.
                act = (td3.select_action(obs) +
                       np.random.normal(0, params['max_action'] * params['expl_noise'],
                                        size=params['d_act']))\
                    .clip(-params['max_action'], params['max_action'])

            next_obs, rew, done, info = env.step(act)
            # `t` is incremented below, so the final step allowed by the horizon
            # is the one taken while t == horizon - 1.
            hit_horizon = t + 1 == params['horizon']
            ep_buf.append({
                'obs': obs,
                'next_obs': next_obs,
                'act': act,
                'rew': utils.shift_reward(rew, params),
                'done': done,
                'expert': 0,
                'goal': info['goal'] if 'goal' in info else 0,
                # Keep mask = 1 on the horizon step (time-limit cut-off); otherwise
                # mask out the continuation when the env reports done.
                'mask': info['mask'] if 'mask' in info
                else (1 if hit_horizon else float(not done))
            })
            obs = next_obs

            i += 1
            t += 1
            rets.append(rew)
            metrics['Timesteps'] += 1

            # grad steps
            if i >= params['start_timesteps']:
                for _ in range(params['update_n_steps']):
                    if len(replay_buffer) == 0:
                        break
                    update_info = td3.update(replay_buffer)
                    logger.store(**update_info)
                    loss_plotter.add_data(**update_info)

        # Walk the episode backwards computing the discounted reward-to-go ('drtg')
        # for each transition, tag it with the episode's final 'goal' flag as 'succ',
        # and push it into the replay buffer.
        x, succ = 0, 0
        for j, transition in enumerate(reversed(ep_buf)):
            # TODO: find a good estimate of the unobserved tail reward for general
            #   environments. For goal-conditioned tasks the remaining rewards are always
            #   -1 or 0, but that does not hold in general. Candidates: minimum, mean or
            #   median reward (e.g. np.median(rets)). The code currently uses the last
            #   step's shifted reward.
            if j == 0:
                succ = succ or transition['goal']
                if not transition['mask']:
                    x = transition['rew']
                else:
                    # Continuation not masked: treat the tail as an infinite discounted
                    # sum of a constant reward estimate.
                    reward_estimate = ep_buf[-1]['rew']
                    if params['discount'] < 1:
                        x = reward_estimate / (1 - params['discount'])
                    elif reward_estimate == 0:
                        # 0 * inf is NaN; an undiscounted infinite sum of zeros is 0.
                        x = 0.0
                    else:
                        x = reward_estimate * float('inf')
            else:
                x = transition['rew'] + transition['mask'] * params['discount'] * x
            transition['drtg'] = x
            transition['succ'] = succ
            del transition['goal']
            replay_buffer.store_transition(transition)

        if robosuite:
            env.close()

        logger.store(TrainEpRet=sum(rets), TrainEpLen=len(rets))
        n_episodes += 1
