import copy
import os
import sys

import tree

from bbac import algorithms
from softlearning import policies
from softlearning import value_functions
from softlearning import preprocessors as preprocessors_lib
from softlearning import replay_pools
from softlearning import samplers
from softlearning.utils.serialization import custom_object_scope

from examples.instrument import run_example_local
from examples.development import main as development_main

from bbac.samplers import BBACSampler
from bbac.value_functions import (
    random_prior_ensemble_feedforward_Q_function,
    random_prior_feedforward_Q_function)


class ExperimentRunner(development_main.ExperimentRunner):
    """Experiment runner for Bayesian Bellman Actor-Critic (BBAC).

    Extends softlearning's development runner. Besides the exploitation
    policy it builds one exploration policy per Q-function, a second family
    of "exploitation" Q-functions with their own targets, and a sampler
    (with ``BBACSampler`` registered as a custom object). Checkpoints cover
    both Q-function families and all policies.
    """

    def _build(self):
        """Instantiate environments, networks, pool, sampler and algorithm.

        Works on a deep copy of ``self._variant``. It fills in the
        environment-dependent entries of each sub-config, builds the
        corresponding objects and stores them on ``self``. The algorithm
        class is looked up by name in ``bbac.algorithms``.

        Raises:
            AssertionError: if the variant is not a
                ``BayesianBellmanActorCritic`` run whose ``Q_params`` and
                ``Q_target_params`` use
                ``random_prior_ensemble_feedforward_Q_function``, or if a
                Q-function family and its targets differ in length.
        """
        from softlearning.environments.utils import get_environment_from_params
        # Imported only for its side effects (presumably registering the
        # custom cartpole task with dm_control); the name itself is unused.
        from bbac.environments.dm_control.suite import custom_cartpole  # noqa: F401

        variant = copy.deepcopy(self._variant)
        environment_params = variant['environment_params']
        training_environment = self.training_environment = (
            get_environment_from_params(environment_params['training']))
        evaluation_environment = self.evaluation_environment = (
            get_environment_from_params(environment_params['evaluation'])
            if 'evaluation' in environment_params
            else training_environment)

        Q_input_shapes = (
            training_environment.observation_shape,
            training_environment.action_shape)
        policy_input_shapes = training_environment.observation_shape
        Q_preprocessors, policy_preprocessors = self._build_preprocessors(
            training_environment.observation_shape,
            Q_input_shapes,
            policy_input_shapes)

        # All four Q-function configs (online/target x exploration/
        # exploitation) receive identical inputs and preprocessors.
        for Q_params_key in ('Q_params',
                             'Q_target_params',
                             'exploitation_Q_params',
                             'exploitation_Q_target_params'):
            variant[Q_params_key]['config'].update({
                'input_shapes': Q_input_shapes,
                'preprocessors': Q_preprocessors,
            })

        algorithm_class_name = variant['algorithm_params']['class_name']
        assert algorithm_class_name == 'BayesianBellmanActorCritic', (
            algorithm_class_name)
        for Q_params_key in ('Q_params', 'Q_target_params'):
            Q_class_name = variant[Q_params_key]['class_name']
            assert Q_class_name == (
                'random_prior_ensemble_feedforward_Q_function'), Q_class_name

        # Make the bbac Q-function factories resolvable by name while the
        # configs are deserialized.
        with custom_object_scope({
                'random_prior_ensemble_feedforward_Q_function':
                random_prior_ensemble_feedforward_Q_function,
                'random_prior_feedforward_Q_function':
                random_prior_feedforward_Q_function,
        }):
            Qs = self.Qs = tree.flatten(
                value_functions.get(variant['Q_params']))
            Q_targets = self.Q_targets = tree.flatten(
                value_functions.get(variant['Q_target_params']))
            exploitation_Qs = self.exploitation_Qs = tree.flatten(
                value_functions.get(variant['exploitation_Q_params']))
            exploitation_Q_targets = self.exploitation_Q_targets = (
                tree.flatten(value_functions.get(
                    variant['exploitation_Q_target_params'])))

        # Start every target network from its online network's weights.
        self._copy_weights(Qs, Q_targets)
        self._copy_weights(exploitation_Qs, exploitation_Q_targets)

        policy_config_overrides = {
            'action_range': (training_environment.action_space.low,
                             training_environment.action_space.high),
            'input_shapes': policy_input_shapes,
            'output_shape': training_environment.action_shape,
            'preprocessors': policy_preprocessors,
        }
        variant['policy_params']['config'].update(policy_config_overrides)
        policy = self.policy = policies.get(variant['policy_params'])

        # One exploration policy is built per exploration Q-function.
        variant['exploration_policy_params']['config'].update(
            policy_config_overrides)
        exploration_policies = self.exploration_policies = [
            policies.get(variant['exploration_policy_params'])
            for _ in range(len(Qs))
        ]

        variant['replay_pool_params']['config'].update({
            'environment': training_environment,
        })
        replay_pool = self.replay_pool = replay_pools.get(
            variant['replay_pool_params'])

        # The sampler receives the exploration policies and the exploitation
        # policy as separate entries rather than a single 'policy'.
        variant['sampler_params']['config'].update({
            'environment': training_environment,
            'policies': exploration_policies,
            'exploitation_policy': policy,
            'pool': replay_pool,
        })
        with custom_object_scope({
                'BBACSampler': BBACSampler,
        }):
            sampler = self.sampler = samplers.get(variant['sampler_params'])

        variant['algorithm_params']['config'].update({
            'training_environment': training_environment,
            'evaluation_environment': evaluation_environment,
            'policy': policy,
            'exploration_policies': exploration_policies,
            'Qs': Qs,
            'Q_targets': Q_targets,
            'exploitation_Qs': exploitation_Qs,
            'exploitation_Q_targets': exploitation_Q_targets,
            'pool': replay_pool,
            'sampler': sampler
        })
        algorithm_class = getattr(algorithms, algorithm_class_name)
        self.algorithm = algorithm_class(
            **variant['algorithm_params']['config'])

        self._built = True

    @staticmethod
    def _build_preprocessors(observation_shape,
                             Q_input_shapes,
                             policy_input_shapes):
        """Build pixel preprocessors for the Q-functions and policies.

        Args:
            observation_shape: the environment's observation shape structure.
                It is iterated to obtain observation names (presumably a dict
                keyed by name).
            Q_input_shapes: input-shape structure of the Q-functions.
            policy_input_shapes: input-shape structure of the policies.

        Returns:
            ``(None, None)`` if no observation name contains ``'pixels'``.
            Otherwise, a pair of structures mirroring ``Q_input_shapes`` and
            ``policy_input_shapes``. Each leaf whose last path element
            contains ``'pixels'`` maps to a convnet preprocessor; every other
            leaf maps to ``None``. One preprocessor instance is shared by all
            Q-function inputs, and a second one by all policy inputs.
        """
        has_pixel_observations = any(
            'pixels' in observation_name
            for observation_name in observation_shape)
        if not has_pixel_observations:
            return None, None

        pixel_preprocessor_params = {
            'class_name': 'convnet_preprocessor',
            'config': {
                'conv_filters': (64, ) * 4,
                'conv_kernel_sizes': (3, ) * 4,
                'conv_strides': (2, ) * 4,
                'normalization_type': 'layer',
                'downsampling_type': 'conv',
            },
        }
        pixel_Q_preprocessor = preprocessors_lib.deserialize(
            pixel_preprocessor_params)
        pixel_policy_preprocessor = preprocessors_lib.deserialize(
            pixel_preprocessor_params)

        def pixels_only(preprocessor):
            # Leaf mapper: use ``preprocessor`` only where the key mentions
            # pixels.
            return lambda path, _: (
                preprocessor if 'pixels' in str(path[-1]) else None)

        Q_preprocessors = tree.map_structure_with_path(
            pixels_only(pixel_Q_preprocessor), Q_input_shapes)
        policy_preprocessors = tree.map_structure_with_path(
            pixels_only(pixel_policy_preprocessor), policy_input_shapes)
        return Q_preprocessors, policy_preprocessors

    @staticmethod
    def _copy_weights(sources, targets):
        """Assign each source network's variables to its paired target.

        Args:
            sources: sequence of networks whose weights are copied.
            targets: sequence of networks paired positionally with
                ``sources``.

        Raises:
            AssertionError: if the two sequences differ in length. ``zip``
                would otherwise silently leave the surplus networks
                unsynchronised.
        """
        assert len(sources) == len(targets), (len(sources), len(targets))
        for source, target in zip(sources, targets):
            for source_weight, target_weight in zip(
                    source.variables, target.variables):
                target_weight.assign(source_weight)

    def step(self, *args, **kwargs):
        """Delegate to the parent ``step`` and save the pool when done.

        When the returned diagnostics report ``done``, the replay pool is
        saved to ``self.logdir``.
        """
        diagnostics = super().step(*args, **kwargs)

        if diagnostics.get('done', False):
            self._save_replay_pool(self.logdir)

        return diagnostics

    def _save_value_functions(self, checkpoint_dir):
        """Save exploration and exploitation Q-function weights (TF format).

        Weights go under ``<checkpoint_dir>/exploration-Q`` and
        ``<checkpoint_dir>/exploitation-Q``. Each network gets one file
        prefix named after its index in the flattened list (``0``, ``1``,
        ...). Target networks are not saved.
        """
        for subdirectory, Qs in (('exploration-Q', self.Qs),
                                 ('exploitation-Q', self.exploitation_Qs)):
            save_path = os.path.join(checkpoint_dir, subdirectory)
            # Create the directory explicitly, as ``_save_policy`` does,
            # rather than relying on the weight writer to do it.
            os.makedirs(save_path, exist_ok=True)
            tree.map_structure_with_path(
                lambda path, Q: Q.save_weights(
                    os.path.join(save_path, '-'.join(str(x) for x in path)),
                    save_format='tf'),
                Qs)

    def _save_policy(self, checkpoint_dir):
        """Save the exploitation policy and every exploration policy.

        The exploitation policy is written to
        ``<checkpoint_dir>/exploitation-policy/0``. Exploration policies are
        written to ``<checkpoint_dir>/exploration-policy/<index>``.
        """
        exploitation_save_path = os.path.join(
            checkpoint_dir, 'exploitation-policy')
        os.makedirs(exploitation_save_path, exist_ok=True)
        self.policy.save(os.path.join(exploitation_save_path, '0'))

        exploration_save_path = os.path.join(
            checkpoint_dir, 'exploration-policy')
        os.makedirs(exploration_save_path, exist_ok=True)
        tree.map_structure_with_path(
            lambda path, policy: policy.save(os.path.join(
                exploration_save_path, '-'.join(str(x) for x in path))),
            self.exploration_policies)

    def save_checkpoint(self, checkpoint_dir):
        """Implements the checkpoint save logic.

        Saves the Q-functions and policies under ``checkpoint_dir``. The
        replay pool, sampler and algorithm state are not written here; the
        replay pool is saved by ``step`` once training reports ``done``.

        Returns:
            ``checkpoint_dir`` with a trailing path separator.
        """
        self._save_value_functions(checkpoint_dir)
        self._save_policy(checkpoint_dir)

        return os.path.join(checkpoint_dir, '')


def main(argv=None):
    """Run this module's ExperimentRunner locally on ray.

    Args:
        argv: command-line arguments (without the program name), forwarded
            unchanged to ``run_example_local``.

    To run on cloud (e.g. gce/ec2), pass this module, ``bbac.runners.bbac``,
    to the softlearning setup scripts:
    'softlearning launch_example_{gce,ec2} bbac.runners.bbac <options>'.

    Run 'softlearning launch_example_{gce,ec2} --help' for further
    instructions.
    """
    run_example_local('bbac.runners.bbac', argv)


if __name__ == '__main__':
    # Forward the CLI arguments, minus the program name, to the runner.
    main(sys.argv[1:])
