"""Experiment that shows arbitrary off-policy behavior of TD."""

import argparse
from distutils.util import strtobool
import os
from pathlib import Path
import pickle

import numpy as np
import ray
from ray import tune
import tensorflow as tf
import tensorflow_probability as tfp


from policy_evaluation.utils import PROJECT_ROOT, get_git_rev
from policy_evaluation import algorithms
from policy_evaluation import tasks
from policy_evaluation import value_functions

from .experiment_runner import (
    ExperimentRunner, datetime_stamp, set_gpu_memory_growth)

# Executed at import time, before any model is built in this module.
# Presumably switches TensorFlow to on-demand GPU memory allocation --
# TODO confirm against `experiment_runner.set_gpu_memory_growth`.
set_gpu_memory_growth(True)

# Short alias for the TensorFlow Probability distributions namespace.
tfd = tfp.distributions


# Discount factor; passed as `gamma` to every algorithm configuration.
DISCOUNT = 0.98
# Path of this module, and the cache directory derived from it:
# `<module directory>/data/<module name without extension>`.
CURRENT_FILE_PATH = Path(__file__)
CACHE_DIR = CURRENT_FILE_PATH.parent.joinpath('data', CURRENT_FILE_PATH.stem)


# Static description of the experiment. `train` merges these entries with
# 'total_samples' and 'epoch_length', which it takes from its `num_steps`
# and `epoch_length` arguments, so those two keys are not set here.
experiment_params = {
    'n_episodes': 1,  # legacy name: n_eps
    'episodic': False,
    'name': "n_link_pendulum",
    'title': "N-Link Pendulum",
    'criterion': "RMSE",
}

# Per-run settings. `train` updates this dict in place with 'run_eagerly'
# (from its `debug` flag) and 'seed' (a grid search over the drawn seeds)
# before handing it to Tune.
run_params = {
    'num_samples': 5,  # legacy name: n_indep
    'verbose': 100,
}


# Environment specification forwarded unchanged in the Tune config.
environment_params = {
    'class_name': 'NLinkPendulum',
    'config': {},
}


def _build_mlp_value_function(value_function_params):
    """Build a fresh ``StateValueFunction`` backed by a dense network.

    Every call creates new layers, so value functions returned by separate
    calls never share weights.

    Args:
        value_function_params: Mapping with 'hidden_layer_sizes' (iterable
            of hidden layer widths) and 'activation' (activation used by
            every hidden layer).

    Returns:
        A ``value_functions.StateValueFunction`` wrapping a new
        ``tf.keras.Sequential`` made of one ``Dense`` layer per hidden size
        followed by a single linear output unit.
    """
    hidden_activation = value_function_params['activation']
    layers = [
        tf.keras.layers.Dense(layer_size, activation=hidden_activation)
        for layer_size in value_function_params['hidden_layer_sizes']
    ]
    layers.append(tf.keras.layers.Dense(1, activation='linear'))
    return value_functions.StateValueFunction(
        model=tf.keras.Sequential(layers))


class NLinkPendulumExperimentRunner(ExperimentRunner):
    """Experiment runner for value prediction on the N-link pendulum task."""

    def _setup(self, variant):
        """Build the algorithm and the prediction task for one trial.

        Fetches the trial's dataset and the reference value function from
        the Ray object store, creates the value function network(s) the
        selected algorithm needs, instantiates the algorithm and wires
        everything into a ``ValuePredictionTask``.

        Args:
            variant: Resolved trial configuration. Reads 'run_params',
                'dataset_object_ids', 'value_function_object_id',
                'algorithm_params', 'value_function_params' and
                'task_params'. The 'config' dicts of 'algorithm_params'
                and 'task_params' are updated in place.

        Raises:
            ValueError: If the algorithm class name is not one of
                'BBORandomizedPrior', 'TD0', 'TDC' or 'GTD2'.
        """
        super()._setup(variant)

        seed = variant['run_params']['seed']
        # Object ids are keyed by the string form of the seed.
        self.dataset = ray.get(variant['dataset_object_ids'][str(seed)])
        self.true_value_states, self.true_values = ray.get(
            variant['value_function_object_id'])

        algorithm_params = variant['algorithm_params']
        value_function_params = variant['value_function_params']
        class_name = algorithm_params['class_name']

        if class_name == 'BBORandomizedPrior':
            # Two independently initialised networks with the same layout.
            V_omega = _build_mlp_value_function(value_function_params)
            V_phi = _build_mlp_value_function(value_function_params)
            algorithm_params['config'].update({
                'V_omega': V_omega,
                'V_phi': V_phi,
            })
        elif class_name == 'TD0':
            V = _build_mlp_value_function(value_function_params)
            algorithm_params['config'].update({'V': V})
        elif class_name in ('TDC', 'GTD2'):
            V_theta = _build_mlp_value_function(value_function_params)
            # Call the model once on a single state so that its variables
            # are created before the algorithm is constructed.
            V_theta.values(self.dataset['samples']['state_0'][:1, ...])
            algorithm_params['config'].update({'V_theta': V_theta})
        else:
            raise ValueError(class_name)

        self.algorithm = getattr(algorithms, class_name)(
            **algorithm_params['config'])

        task_params = variant['task_params']
        assert (task_params['class_name']
                == 'ValuePredictionTask'), (
                    task_params['class_name'])
        task_params['config'].update({
            'algorithm': self.algorithm,
        })

        self.task = tasks.ValuePredictionTask(
            **task_params['config'])
        self.task.true_value_states = self.true_value_states
        self.task.true_values = self.true_values
        self.task.dataset = self.dataset

        # Value estimates on `true_value_states`, one entry per `_train` call.
        self._all_values = []

    def _train(self, *args, **kwargs):
        """Run one base-class training step and record value estimates.

        Returns:
            The result dict produced by the parent class's ``_train``.
        """
        result = super()._train(*args, **kwargs)
        # NOTE(review): assumes every supported algorithm exposes its value
        # function as `.V`, although `_setup` passes it under different
        # config keys ('V', 'V_theta', 'V_omega') -- confirm in `algorithms`.
        values = self.algorithm.V.values(self.true_value_states).numpy()
        self._all_values.append(values)
        if result[ray.tune.result.DONE]:
            self._plot_values()
        return result

    def _plot_values(self):
        """Placeholder hook called after the final step; does nothing."""
        return


# Algorithm presets, keyed by the name given to `train(algorithm=...)` /
# `--algorithm`. For the selected preset, the runner's `_setup` looks up
# 'class_name' on `policy_evaluation.algorithms` and passes 'config' as
# keyword arguments, after adding the value function object(s) to it.
algorithm_params = {
    'bbo-rp': {
        'class_name': 'BBORandomizedPrior',
        'config': {
            'gamma': DISCOUNT,
            'num_phi_steps': 20,

            'phi_lr': 1e-3,
            'omega_lr': 1e-2,

            'prior_loc': 0.0,
            'prior_scale': 0.3,
            'prior_loss_weight': 0.5,
        },
    },
    'td0': {
        'class_name': 'TD0',
        'config': {
            'gamma': DISCOUNT,
            'alpha': 1e-5,
        },
    },
    'tdc': {
        'class_name': 'TDC',
        'config': {
            'gamma': DISCOUNT,
            'alpha': 1e-4,
            'beta': 3e-3,
        },
    },
    'gtd2': {
        'class_name': 'GTD2',
        'config': {
            'gamma': DISCOUNT,
            # Both step sizes are grid-searched over the same six values,
            # i.e. 36 combinations per seed.
            'alpha': tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]),
            'beta': tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0]),
        },
    },
}


def _activation_for_algorithm(spec):
    """Pick the hidden-layer activation from the trial's algorithm class.

    Args:
        spec: Spec mapping given by ``tune.sample_from``. The trial
            configuration is read from ``spec['config']`` when that key
            exists, otherwise from ``spec`` itself.

    Returns:
        'tanh' for the 'TDC' and 'GTD2' algorithm classes, 'relu' for
        every other class name.
    """
    trial_config = spec.get('config', spec)
    class_name = trial_config['algorithm_params']['class_name']
    activation_overrides = {
        'TDC': 'tanh',
        'GTD2': 'tanh',
    }
    return activation_overrides.get(class_name, 'relu')


# Layout of the value function networks: one hidden layer of 256 units whose
# activation is resolved per trial from the selected algorithm.
value_function_params = {
    'hidden_layer_sizes': (256, ),
    'activation': tune.sample_from(_activation_for_algorithm),
}


def train(num_samples,
          num_steps,
          epoch_length,
          debug,
          use_wandb,
          algorithm='bbo-rp',
          experiment_name=None):
    """Run the N-link pendulum value-prediction experiment with Ray Tune.

    Draws `num_samples` distinct seeds from ``range(25)``, loads the cached
    dataset of each seed and the cached reference value function, places
    them in the Ray object store and launches
    ``NLinkPendulumExperimentRunner`` trials through ``tune.run``, with a
    grid search over the seeds (and over any grid-searched hyperparameters
    of the selected algorithm).

    Args:
        num_samples: Number of seeds to run; must be between 1 and 25.
        num_steps: Stored as 'total_samples' in the experiment parameters.
        epoch_length: Stored as 'epoch_length' in the experiment parameters.
        debug: When true, Ray runs in local mode, 'run_eagerly' is enabled
            and results go into a 'debug' sub-directory.
        use_wandb: Accepted for interface compatibility; currently unused.
        algorithm: Key of the preset to use from ``algorithm_params``.
        experiment_name: Optional suffix appended to the timestamp that
            names the experiment.

    Raises:
        ValueError: If `num_samples` is out of range, or if a cached
            dataset or the cached value function file is missing.
        KeyError: If `algorithm` is not a key of ``algorithm_params``.
    """
    # Validate the cheap arguments before starting Ray so that bad input
    # fails immediately.
    if num_samples < 1 or 25 < num_samples:
        raise ValueError("num_samples must be between 1 and 25.")
    selected_algorithm_params = algorithm_params[algorithm]

    ray.init(resources={}, local_mode=debug, include_webui=False)

    seeds = np.sort(np.random.choice(25, num_samples, replace=False)).tolist()

    # NOTE: this updates the module-level `run_params` dict in place.
    run_params.update({
        'run_eagerly': debug,
        'seed': tune.grid_search(seeds),
    })

    def load_dataset(seed):
        """Load the pre-generated dataset pickle for `seed`."""
        cache_path = CACHE_DIR / 'datasets' / f'dataset-{seed}.pkl'
        if not cache_path.exists():
            raise ValueError(
                "Can't find dataset! Generate the dataset with"
                " `tdlearn/generate_20_link_pendulum_data.py` and move it to"
                f" {CACHE_DIR / 'datasets'}")

        # NOTE: unpickling can execute arbitrary code; only load files
        # generated locally by the script named above.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    datasets = {
        seed: load_dataset(seed)
        for seed in seeds
    }

    # Keys are strings because the runner looks datasets up by `str(seed)`.
    dataset_object_ids = {
        str(seed): ray.put(dataset, weakref=False)
        for seed, dataset in datasets.items()
    }

    def load_value_function():
        """Load the pre-computed ``(states, values)`` reference pair."""
        cache_path = CACHE_DIR / 'value_functions' / 'value_function.pkl'
        if not cache_path.exists():
            raise ValueError(
                "Can't find value function! Generate the value function with"
                " `tdlearn/generate_20_link_pendulum_data.py` and move it to"
                f" {CACHE_DIR / 'value_functions'}")

        # NOTE: same pickle caveat as for the datasets.
        with open(cache_path, 'rb') as f:
            states, values = pickle.load(f)

        return states, values

    value_function = load_value_function()
    value_function_object_id = ray.put(value_function, weakref=False)

    experiment_config = {
        'dataset_object_ids': dataset_object_ids,
        'value_function_object_id': value_function_object_id,
        'algorithm_params': selected_algorithm_params,
        'value_function_params': value_function_params,
        'experiment_params': {
            'total_samples': num_steps,
            'epoch_length': epoch_length,
            **experiment_params,
        },
        'run_params': run_params,
        'environment_params': environment_params,
        'task_params': {
            'class_name': 'ValuePredictionTask',
            'config': {
                'criteria': ['RMSE', 'MSE', ],
                'batch_size': tune.grid_search([512])
            },
        },
        'git_rev': get_git_rev(PROJECT_ROOT),
    }

    if experiment_name is not None:
        experiment_name = '-'.join((datetime_stamp(), experiment_name))
    else:
        experiment_name = datetime_stamp()

    # An empty path component is skipped by `os.path.join`, so non-debug
    # runs land directly in `<PROJECT_ROOT>/data/n_link_pendulum`.
    local_dir = os.path.join(
        PROJECT_ROOT, 'data', ('debug' if debug else ''), 'n_link_pendulum')

    tune.run(
        NLinkPendulumExperimentRunner,
        name=experiment_name,
        config=experiment_config,
        resources_per_trial={
            'cpu': 3,
            'gpu': 0.0,
        },
        local_dir=local_dir,
        num_samples=1,  # Use seeds to generate multiple samples
        checkpoint_freq=0,
        checkpoint_at_end=False,
        max_failures=0,
        restore=None,
        with_server=False,
        scheduler=None,
        loggers=(ray.tune.logger.CSVLogger, ray.tune.logger.JsonLogger),
        reuse_actors=True)


if __name__ == '__main__':
    def _str_to_bool(text):
        """Convert a command-line truth string to a bool via `strtobool`."""
        return bool(strtobool(text))

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mode',
        type=str,
        choices=('train', 'visualize'),
        default='train')

    parser.add_argument('--num-samples', type=int, default=25)
    parser.add_argument('--num-steps', type=int, default=5000)
    parser.add_argument('--epoch-length', type=int, default=25)
    parser.add_argument('--experiment-name', type=str, default=None)
    # Parsed but not used until the 'visualize' mode is implemented.
    parser.add_argument('--experiment-path', type=str, default=None)
    parser.add_argument('--algorithm', type=str, default='bbo-rp')
    # Both flags may be given bare (`--debug`) or with a value
    # (`--debug false`).
    parser.add_argument(
        '--use-wandb',
        type=_str_to_bool,
        nargs='?',
        const=True,
        default=False)
    parser.add_argument(
        '--debug',
        type=_str_to_bool,
        nargs='?',
        const=True,
        default=False,
        help="Whether or not to execute sequentially to allow breakpoints.")

    args = parser.parse_args()
    if args.mode == 'train':
        train(num_samples=args.num_samples,
              num_steps=args.num_steps,
              epoch_length=args.epoch_length,
              debug=args.debug,
              use_wandb=args.use_wandb,
              algorithm=args.algorithm,
              experiment_name=args.experiment_name)
    elif args.mode == 'visualize':
        # Visualization is not implemented in this module.
        raise NotImplementedError(args.mode)
