from collections import OrderedDict
from functools import partial
import os


import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tree
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

from softlearning.algorithms.sac import SAC, compute_Q_targets


class BayesianBellmanActorCritic(SAC):
    """Bayesian Bellman Actor Critic algorithm.

    References
    ----------
    [1]
    """

    def __init__(
            self,

            exploration_num_actor_updates=1,
            exploration_num_critic_updates=1,
            exploration_num_target_q_samples=1,
            exploration_q_target_reduce_type='min',
            exploration_q_target_type='sac',
            exploration_policy_update_type='sac',
            exploration_policy_num_q_samples=1,
            exploration_policy_q_reduce_type='mean',
            exploration_policy_q_type='Q',
            exploration_policy_q_ensemble_subset_size=None,

            exploitation_num_actor_updates=1,
            exploitation_num_critic_updates=1,
            exploitation_num_target_q_samples=1,
            exploitation_q_target_ensemble_subset_size=None,
            exploitation_q_target_reduce_type='mean',
            exploitation_q_target_type='sac',
            exploitation_policy_update_type='sac',
            exploitation_policy_num_q_samples=1,
            exploitation_policy_q_reduce_type='mean',
            exploitation_policy_q_type='exploitation_Q',
            exploitation_policy_q_ensemble_subset_size=None,

            exploration_policies=None,
            exploitation_Qs=None,
            exploitation_Q_targets=None,
            **kwargs,
    ):
        """Store the exploration/exploitation configuration and optimizers.

        Every keyword argument is stored on the instance under the same name
        with a leading underscore. The ``*_type`` and ``*_reduce_type``
        strings are validated by the update methods that read them, not here.

        Args:
            exploration_num_actor_updates: Exploration policy updates per
                ``_do_training_repeats`` call.
            exploration_num_critic_updates: Exploration critic updates (each
                followed by a soft target update) per
                ``_do_training_repeats`` call.
            exploration_num_target_q_samples: Evaluations of each target
                critic that are reduced into one next-state value.
            exploration_q_target_reduce_type: 'min' or 'mean'.
            exploration_q_target_type: 'sac' (entropy scale ``alpha``) or
                'virel' (entropy scale 0).
            exploration_policy_update_type: 'sac' or 'vbac'.
            exploration_policy_num_q_samples: Critic evaluations per
                exploration policy loss.
            exploration_policy_q_reduce_type: Stored as
                ``self._exploration_policy_q_reduce_type``.
            exploration_policy_q_type: 'Q' or 'Q_target'.
            exploration_policy_q_ensemble_subset_size: Must currently be
                ``None``.
            exploitation_num_actor_updates: Exploitation policy updates per
                ``_do_training_repeats`` call.
            exploitation_num_critic_updates: Exploitation critic updates per
                ``_do_training_repeats`` call.
            exploitation_num_target_q_samples: Evaluations of each target
                critic that are reduced into one next-state value.
            exploitation_q_target_ensemble_subset_size: Target critics used
                per exploitation target; ``None`` (or 0) means all of
                ``exploitation_Qs``.
            exploitation_q_target_reduce_type: 'min' or 'mean'.
            exploitation_q_target_type: 'sac' or 'virel'.
            exploitation_policy_update_type: 'sac' or 'vbac'.
            exploitation_policy_num_q_samples: Must be 1.
            exploitation_policy_q_reduce_type: 'min' or 'mean'.
            exploitation_policy_q_type: 'exploitation_Q',
                'exploitation_Q_target', 'exploration_Q' or
                'exploration_Q_target'.
            exploitation_policy_q_ensemble_subset_size: Critics used per
                exploitation policy update; ``None`` means all.
            exploration_policies: One exploration policy per exploration
                critic. Required.
            exploitation_Qs: Exploitation critic ensemble; ``None`` is
                treated as an empty ensemble.
            exploitation_Q_targets: Target critics for ``exploitation_Qs``;
                ``None`` is treated as an empty ensemble.
            **kwargs: Forwarded to ``SAC.__init__``.

        Raises:
            TypeError: If ``exploration_policies`` is not given.
        """
        # Fail with a clear message instead of ``len(None)`` further down.
        if exploration_policies is None:
            raise TypeError('exploration_policies is required.')
        if exploitation_Qs is None:
            exploitation_Qs = ()
        if exploitation_Q_targets is None:
            exploitation_Q_targets = ()

        super(BayesianBellmanActorCritic, self).__init__(**kwargs)

        self._exploration_num_actor_updates = exploration_num_actor_updates
        self._exploration_num_critic_updates = exploration_num_critic_updates
        self._exploration_num_target_q_samples = exploration_num_target_q_samples
        self._exploration_q_target_reduce_type = exploration_q_target_reduce_type
        self._exploration_q_target_type = exploration_q_target_type
        self._exploration_policy_update_type = exploration_policy_update_type
        self._exploration_policy_num_q_samples = exploration_policy_num_q_samples
        self._exploration_policy_q_reduce_type = exploration_policy_q_reduce_type
        self._exploration_policy_q_type = exploration_policy_q_type
        assert isinstance(exploration_policy_q_ensemble_subset_size,
                          (type(None), int)), (
            exploration_policy_q_ensemble_subset_size)
        self._exploration_policy_q_ensemble_subset_size = (
            exploration_policy_q_ensemble_subset_size)

        self._exploitation_num_actor_updates = exploitation_num_actor_updates
        self._exploitation_num_critic_updates = exploitation_num_critic_updates
        self._exploitation_num_target_q_samples = exploitation_num_target_q_samples
        assert isinstance(exploitation_q_target_ensemble_subset_size,
                          (type(None), int)), (
            exploitation_q_target_ensemble_subset_size)
        # ``None`` (or 0) selects the whole exploitation ensemble.
        self._exploitation_q_target_ensemble_subset_size = (
            exploitation_q_target_ensemble_subset_size or len(exploitation_Qs))
        self._exploitation_q_target_reduce_type = exploitation_q_target_reduce_type
        self._exploitation_q_target_type = exploitation_q_target_type
        self._exploitation_policy_update_type = exploitation_policy_update_type
        self._exploitation_policy_num_q_samples = exploitation_policy_num_q_samples
        self._exploitation_policy_q_reduce_type = exploitation_policy_q_reduce_type
        self._exploitation_policy_q_type = exploitation_policy_q_type
        assert isinstance(exploitation_policy_q_ensemble_subset_size,
                          (type(None), int)), (
            exploitation_policy_q_ensemble_subset_size)
        self._exploitation_policy_q_ensemble_subset_size = (
            exploitation_policy_q_ensemble_subset_size)

        self._exploration_policies = exploration_policies
        self._exploitation_Qs = exploitation_Qs
        self._exploitation_Q_targets = exploitation_Q_targets

        # One optimizer per exploitation critic, sharing SAC's Q learning rate.
        self._exploitation_Q_optimizers = tuple(
            tf.optimizers.Adam(
                learning_rate=self._Q_lr,
                name=f'exploitation_Q_{i}_optimizer'
            ) for i, _ in enumerate(self._exploitation_Qs))

        # One optimizer per exploration policy, sharing SAC's policy rate.
        self._exploration_policy_optimizers = tuple(
            tf.optimizers.Adam(
                learning_rate=self._policy_lr,
                name=f"exploration_policy_{i}_optimizer")
            for i, _ in enumerate(self._exploration_policies)
        )


    @tf.function(experimental_relax_shapes=True)
    def _compute_Q_targets_exploration(self, batch, policy, Q_target):
        """Build Bellman targets for a single exploration critic.

        Next actions are sampled from ``policy``. ``Q_target`` is evaluated
        ``self._exploration_num_target_q_samples`` times at those actions,
        and the results are reduced with
        ``self._exploration_q_target_reduce_type`` ('min' or 'mean').

        Args:
            batch: Mapping with 'next_observations', 'rewards' and
                'terminals' entries.
            policy: Exploration policy; must not be ``self._policy``.
            Q_target: Target critic paired with ``policy``.

        Returns:
            The Q targets, wrapped in ``tf.stop_gradient``.

        Raises:
            ValueError: On an unknown target type or reduce type.
        """
        next_observations = batch['next_observations']
        rewards = batch['rewards']
        terminals = batch['terminals']

        target_type = self._exploration_q_target_type
        if target_type == 'sac':
            entropy_scale = self._alpha._value()
        elif target_type == 'virel':
            # The 'virel' targets carry no entropy term.
            entropy_scale = 0.0
        else:
            raise ValueError(target_type)

        next_actions, next_log_pis = policy.actions_and_log_probs(
            next_observations)
        assert policy is not self._policy

        reduce_fns = {'min': tf.reduce_min, 'mean': tf.reduce_mean}
        reduce_type = self._exploration_q_target_reduce_type
        if reduce_type not in reduce_fns:
            raise ValueError(reduce_type)
        q_reduce_fn = reduce_fns[reduce_type]

        num_samples = self._exploration_num_target_q_samples
        assert num_samples, num_samples
        sampled_next_Q_values = [
            Q_target.values(next_observations, next_actions)
            for _ in range(num_samples)
        ]
        next_Q_values = q_reduce_fn(sampled_next_Q_values, axis=0)

        Q_targets = compute_Q_targets(
            next_Q_values,
            next_log_pis,
            rewards,
            terminals,
            self._discount,
            entropy_scale,
            self._reward_scale)
        return tf.stop_gradient(Q_targets)

    @tf.function(experimental_relax_shapes=True)
    def _compute_Q_targets_exploitation(self, batch):
        """Build the Bellman targets shared by all exploitation critics.

        Next actions are sampled from the exploitation policy
        ``self._policy``. A subset of ``self._exploitation_Q_targets`` is
        chosen: every target when the configured subset size equals the
        ensemble size, otherwise indices drawn uniformly with replacement.
        Each chosen target is evaluated
        ``self._exploitation_num_target_q_samples`` times and reduced. The
        per-target results are then reduced again with the same reduction.

        Args:
            batch: Mapping with 'next_observations', 'rewards' and
                'terminals' entries.

        Returns:
            The Q targets, wrapped in ``tf.stop_gradient``.

        Raises:
            ValueError: On an unknown target/reduce type, or a subset size
                larger than the target ensemble.
        """
        next_observations = batch['next_observations']
        rewards = batch['rewards']
        terminals = batch['terminals']

        target_type = self._exploitation_q_target_type
        if target_type == 'sac':
            entropy_scale = self._alpha._value()
        elif target_type == 'virel':
            entropy_scale = 0.0
        else:
            raise ValueError(target_type)

        assert not any(
            candidate is self._policy
            for candidate in self._exploration_policies)
        next_actions, next_log_pis = self._policy.actions_and_log_probs(
            next_observations)

        reduce_fns = {'min': tf.reduce_min, 'mean': tf.reduce_mean}
        reduce_type = self._exploitation_q_target_reduce_type
        if reduce_type not in reduce_fns:
            raise ValueError(reduce_type)
        q_reduce_fn = reduce_fns[reduce_type]
        num_samples = self._exploitation_num_target_q_samples

        def sampled_target_value(Q_target):
            # Reduce several evaluations of one (possibly stochastic) target.
            return q_reduce_fn([
                Q_target.values(next_observations, next_actions)
                for _ in range(num_samples)
            ], axis=0)

        target_ensemble = self._exploitation_Q_targets
        ensemble_size = len(target_ensemble)
        branches = {
            index: partial(sampled_target_value, Q_target)
            for index, Q_target in enumerate(target_ensemble)
        }

        subset_size = self._exploitation_q_target_ensemble_subset_size
        if subset_size > ensemble_size:
            raise ValueError(subset_size, ensemble_size)
        if subset_size == ensemble_size:
            Q_indices = tf.range(0, subset_size)
        else:
            Q_indices = tf.random.uniform(
                [subset_size],
                minval=0,
                maxval=ensemble_size,
                dtype=tf.int32)

        next_Qs_values = tf.vectorized_map(
            lambda index: tf.switch_case(index, branches),
            Q_indices,
            fallback_to_while_loop=True)
        next_Q_values = q_reduce_fn(next_Qs_values, axis=0)

        Q_targets = compute_Q_targets(
            next_Q_values,
            next_log_pis,
            rewards,
            terminals,
            self._discount,
            entropy_scale,
            self._reward_scale)
        return tf.stop_gradient(Q_targets)

    @tf.function(experimental_relax_shapes=True)
    def _update_critic_exploration(self, batch):
        """Take one gradient step on every exploration critic.

        Critic ``k`` in ``self._Qs`` is regressed onto targets built from
        ``self._Q_targets[k]`` and ``self._exploration_policies[k]``. It is
        stepped by ``self._Q_optimizers[k]``. The loss is half the MSE plus
        the critic model's own losses (``Q.model.losses``, if any), averaged
        with ``tf.nn.compute_average_loss``.

        Args:
            batch: Mapping with 'observations', 'actions', 'rewards' and the
                entries read by ``_compute_Q_targets_exploration``.

        Returns:
            ``(Qs_values, Qs_losses, prior_loss)``: a list of per-critic
            predictions, a list of per-critic per-example losses, and the
            model loss of the last critic (0.0 if it has none).
        """
        observations = batch['observations']
        actions = batch['actions']
        rewards = batch['rewards']

        assert len(self._Qs) == len(self._exploration_policies), (
            self._Qs, self._exploration_policies)
        assert len(self._Qs) == len(self._Q_targets), (
            self._Qs, self._Q_targets)
        assert len(self._Qs) == len(self._Q_optimizers), (
            self._Qs, self._Q_optimizers)

        Qs_values = []
        Qs_losses = []
        # Keeps the return well-defined if there are no critics.
        prior_loss = 0.0
        for Q, Q_target, policy, optimizer in zip(self._Qs,
                                                  self._Q_targets,
                                                  self._exploration_policies,
                                                  self._Q_optimizers):
            Q_targets = self._compute_Q_targets_exploration(
                batch, policy, Q_target)
            tf.debugging.assert_shapes((
                (Q_targets, ('B', 1)), (rewards, ('B', 1))))

            with tf.GradientTape() as tape:
                Q_values = Q.values(observations, actions)
                # ``tf.add_n`` raises on an empty list, so critics without
                # model losses contribute 0.0 (same as the exploitation path).
                model_losses = Q.model.losses
                prior_loss = tf.add_n(model_losses) if model_losses else 0.0
                Q_losses = 0.5 * (
                    tf.losses.MSE(y_true=Q_targets, y_pred=Q_values))
                Q_loss = tf.nn.compute_average_loss(Q_losses + prior_loss)

            gradients = tape.gradient(Q_loss, Q.trainable_variables)
            optimizer.apply_gradients(zip(gradients, Q.trainable_variables))
            Qs_losses.append(Q_losses)
            Qs_values.append(Q_values)

        return Qs_values, Qs_losses, prior_loss

    @tf.function(experimental_relax_shapes=True)
    def _update_critic_exploitation(self, batch):
        """Take one gradient step on every exploitation critic.

        All critics in ``self._exploitation_Qs`` regress onto the same
        targets from ``_compute_Q_targets_exploitation``. Critic ``k`` is
        stepped by ``self._exploitation_Q_optimizers[k]``. The loss is half
        the MSE plus the critic model's own losses when it has any, averaged
        with ``tf.nn.compute_average_loss``.

        Args:
            batch: Mapping with 'observations', 'actions', 'rewards' and the
                entries read by ``_compute_Q_targets_exploitation``.

        Returns:
            ``(Qs_values, Qs_losses, prior_loss)``: per-critic predictions,
            per-critic per-example losses, and the last critic's model loss.
        """
        Q_targets = self._compute_Q_targets_exploitation(batch)

        observations = batch['observations']
        actions = batch['actions']
        rewards = batch['rewards']
        tf.debugging.assert_shapes((
            (Q_targets, ('B', 1)), (rewards, ('B', 1))))

        critics = self._exploitation_Qs
        critic_optimizers = self._exploitation_Q_optimizers
        assert len(critics) == len(critic_optimizers), (
            critics, critic_optimizers)

        Qs_values, Qs_losses = [], []
        for critic, critic_optimizer in zip(critics, critic_optimizers):
            with tf.GradientTape() as tape:
                Q_values = critic.values(observations, actions)
                if critic.model.losses:
                    prior_loss = tf.add_n(critic.model.losses)
                else:
                    prior_loss = 0.0
                Q_losses = 0.5 * tf.losses.MSE(
                    y_true=Q_targets, y_pred=Q_values)
                Q_loss = tf.nn.compute_average_loss(Q_losses + prior_loss)

            critic_variables = critic.trainable_variables
            critic_gradients = tape.gradient(Q_loss, critic_variables)
            critic_optimizer.apply_gradients(
                zip(critic_gradients, critic_variables))
            Qs_losses.append(Q_losses)
            Qs_values.append(Q_values)

        return Qs_values, Qs_losses, prior_loss

    @tf.function(experimental_relax_shapes=True)
    def _update_actor_exploration(self, batch):
        """Take one gradient step on every exploration policy.

        Exploration policy ``k`` is trained against critic ``k`` of the
        ensemble selected by ``self._exploration_policy_q_type``: 'Q' uses
        ``self._Qs`` and 'Q_target' uses ``self._Q_targets``. The critic is
        evaluated ``self._exploration_policy_num_q_samples`` times, and the
        evaluations are reduced with ``self._exploration_policy_q_reduce_type``
        ('min' or 'mean'). The loss is ``alpha * log_pi - Q`` for 'sac'
        updates or ``-Q`` for 'vbac' updates.

        Args:
            batch: Mapping with an 'observations' entry.

        Returns:
            Per-example policy losses of the last exploration policy updated.

        Raises:
            ValueError: On an unknown Q type, reduce type or update type.
        """
        observations = batch['observations']

        assert len(self._Qs) == len(self._exploration_policies), (
            self._Qs, self._exploration_policies)
        assert len(self._exploration_policy_optimizers) == len(self._exploration_policies), (
            self._exploration_policy_optimizers, self._exploration_policies)
        assert len(self._Qs) == len(self._Q_targets), (
            self._Qs, self._Q_targets)

        if self._exploration_policy_q_type == 'Q':
            Qs = self._Qs
        elif self._exploration_policy_q_type == 'Q_target':
            Qs = self._Q_targets
        else:
            raise ValueError(self._exploration_policy_q_type)

        if self._exploration_policy_q_reduce_type == 'min':
            q_reduce_fn = tf.reduce_min
        elif self._exploration_policy_q_reduce_type == 'mean':
            q_reduce_fn = tf.reduce_mean
        else:
            raise ValueError(self._exploration_policy_q_reduce_type)

        # TODO: Currently, each policy has its own Q function and
        # no ensembling is used.
        assert self._exploration_policy_q_ensemble_subset_size is None, (
            self._exploration_policy_q_ensemble_subset_size)

        # Pair each policy with the critic selected above (not always
        # ``self._Qs``), so that 'Q_target' takes effect.
        for policy, optimizer, Q in zip(self._exploration_policies,
                                        self._exploration_policy_optimizers,
                                        Qs):
            with tf.GradientTape() as tape:
                actions, log_pis = policy.actions_and_log_probs(observations)

                Q_values = q_reduce_fn([
                    Q.values(observations, actions)
                    for _ in range(self._exploration_policy_num_q_samples)
                ], axis=0)

                if self._exploration_policy_update_type == 'vbac':
                    policy_losses = - Q_values
                elif self._exploration_policy_update_type == 'sac':
                    policy_losses = self._alpha * log_pis - Q_values
                else:
                    raise ValueError(self._exploration_policy_update_type)

                policy_loss = tf.nn.compute_average_loss(policy_losses)

            policy_gradients = tape.gradient(
                policy_loss, policy.trainable_variables)

            optimizer.apply_gradients(zip(
                policy_gradients, policy.trainable_variables))

            tf.debugging.assert_shapes((
                (actions, ('B', 'nA')),
                (log_pis, ('B', 1)),
                (policy_losses, ('B', 1)),
            ))

        return policy_losses

    @tf.function(experimental_relax_shapes=True)
    def _update_actor_exploitation(self, batch):
        """Take one gradient step on the exploitation policy ``self._policy``.

        The critic ensemble is selected by ``self._exploitation_policy_q_type``
        ('exploitation_Q', 'exploitation_Q_target', 'exploration_Q' or
        'exploration_Q_target'). A subset of it is evaluated at the policy's
        actions: all critics when the subset size is ``None`` or equals the
        ensemble size, otherwise indices drawn uniformly with replacement.
        The evaluations are reduced with
        ``self._exploitation_policy_q_reduce_type``. The loss is
        ``alpha * log_pi - Q`` for 'sac' updates or ``-Q`` for 'vbac' updates.

        Args:
            batch: Mapping with an 'observations' entry.

        Returns:
            Per-example policy losses.

        Raises:
            ValueError: On an unknown reduce/Q/update type, or a subset size
                larger than the ensemble.
        """
        observations = batch['observations']

        if self._exploitation_policy_q_reduce_type == 'min':
            q_reduce_fn = tf.reduce_min
        elif self._exploitation_policy_q_reduce_type == 'mean':
            q_reduce_fn = tf.reduce_mean
        else:
            raise ValueError(self._exploitation_policy_q_reduce_type)

        if self._exploitation_policy_q_type == 'exploitation_Q':
            Qs = self._exploitation_Qs
        elif self._exploitation_policy_q_type == 'exploitation_Q_target':
            Qs = self._exploitation_Q_targets
        elif self._exploitation_policy_q_type == 'exploration_Q':
            Qs = self._Qs
        elif self._exploitation_policy_q_type == 'exploration_Q_target':
            Qs = self._Q_targets
        else:
            raise ValueError(self._exploitation_policy_q_type)

        if (self._exploitation_policy_q_ensemble_subset_size is None
            or self._exploitation_policy_q_ensemble_subset_size == len(Qs)):
            Q_indices = tf.range(0, len(Qs))
        elif self._exploitation_policy_q_ensemble_subset_size < len(Qs):
            Q_indices = tf.random.uniform(
                [self._exploitation_policy_q_ensemble_subset_size],
                minval=0,
                maxval=len(Qs),
                dtype=tf.int32)
        else:
            raise ValueError(self._exploitation_policy_q_ensemble_subset_size, len(Qs))

        assert self._exploitation_policy_num_q_samples == 1, (
            self._exploitation_policy_num_q_samples)

        with tf.GradientTape() as tape:
            actions, log_pis = self._policy.actions_and_log_probs(observations)

            # ``partial`` binds each critic when the branch is created. A bare
            # ``lambda: Q.values(...)`` in this comprehension would late-bind
            # ``Q``, so every branch would evaluate the last critic.
            choices = {
                i: partial(Q.values, observations, actions)
                for i, Q in enumerate(Qs)
            }

            Qs_values = tf.vectorized_map(
                lambda i: tf.switch_case(i, choices),
                Q_indices,
                fallback_to_while_loop=True)

            Q_values = q_reduce_fn(Qs_values, axis=0)

            if self._exploitation_policy_update_type == 'vbac':
                policy_losses = - Q_values
            elif self._exploitation_policy_update_type == 'sac':
                policy_losses = self._alpha * log_pis - Q_values
            else:
                raise ValueError(self._exploitation_policy_update_type)

            policy_loss = tf.nn.compute_average_loss(policy_losses)

        tf.debugging.assert_shapes((
            (actions, ('B', 'nA')),
            (log_pis, ('B', 1)),
            (policy_losses, ('B', 1)),
        ))

        policy_gradients = tape.gradient(
            policy_loss, self._policy.trainable_variables)

        self._policy_optimizer.apply_gradients(zip(
            policy_gradients, self._policy.trainable_variables))

        return policy_losses

    @tf.function(experimental_relax_shapes=True)
    def _update_target(self, tau):
        """Soft-update each exploration target critic toward its online twin.

        Every variable of ``self._Q_targets[k]`` is replaced by
        ``tau * online + (1 - tau) * target``, using the matching variable
        of ``self._Qs[k]``.

        Args:
            tau: Interpolation weight given to the online critic.
        """
        online_critics, target_critics = self._Qs, self._Q_targets
        assert len(online_critics) == len(target_critics), (
            len(online_critics), len(target_critics))
        for online_critic, target_critic in zip(online_critics,
                                                target_critics):
            variable_pairs = zip(online_critic.variables,
                                 target_critic.variables)
            for online_variable, target_variable in variable_pairs:
                target_variable.assign(
                    tau * online_variable + (1.0 - tau) * target_variable)

    @tf.function(experimental_relax_shapes=True)
    def _update_target_exploitation(self, tau):
        """Soft-update each exploitation target critic toward its online twin.

        Every variable of ``self._exploitation_Q_targets[k]`` is replaced by
        ``tau * online + (1 - tau) * target``, using the matching variable
        of ``self._exploitation_Qs[k]``.

        Args:
            tau: Interpolation weight given to the online critic.
        """
        online_critics = self._exploitation_Qs
        target_critics = self._exploitation_Q_targets
        assert len(online_critics) == len(target_critics), (
            len(online_critics), len(target_critics))
        for online_critic, target_critic in zip(online_critics,
                                                target_critics):
            variable_pairs = zip(online_critic.variables,
                                 target_critic.variables)
            for online_variable, target_variable in variable_pairs:
                target_variable.assign(
                    tau * online_variable + (1.0 - tau) * target_variable)

    @tf.function(experimental_relax_shapes=True)
    def _do_updates(self, batch):
        """Unsupported for this algorithm.

        This class performs its updates in ``_do_training_repeats`` instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @tf.function(experimental_relax_shapes=True)
    def _do_training(self, iteration, batch):
        """Unsupported for this algorithm.

        This class performs its updates in ``_do_training_repeats`` instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def _do_training_repeats(self, timestep):
        """Run one round of exploration and exploitation updates.

        In order, this performs:

        1. ``self._exploration_num_critic_updates`` exploration critic
           updates, each followed by a soft update of the exploration
           targets;
        2. ``self._exploitation_num_critic_updates`` exploitation critic
           updates, each followed by a soft update of the exploitation
           targets;
        3. ``self._exploration_num_actor_updates`` exploration policy updates;
        4. ``self._exploitation_num_actor_updates`` exploitation policy
           updates;
        5. an alpha update, unless the exploitation policy update type is
           ``'vbac'``.

        Each update draws its own batch from ``self._training_batch()``.

        Args:
            timestep: Unused.

        Returns:
            Nested ``OrderedDict`` of diagnostics. A critic group whose update
            count is zero reports zeros instead of raising.
        """
        tau = tf.constant(self._tau)

        def summarize_critic_updates(update_results):
            # Each result is a ``(Qs_values, Qs_losses, prior_loss)`` triple.
            # ``zip(*[])`` yields nothing to unpack, so zero updates are
            # reported as zeros.
            if update_results:
                Qs_values, Qs_losses, prior_losses = zip(*update_results)
            else:
                Qs_values = Qs_losses = prior_losses = 0.0
            return OrderedDict((
                ('Q_value-mean', tf.reduce_mean(Qs_values)),
                ('Q_loss-mean', tf.reduce_mean(Qs_losses)),
                ('prior_loss', tf.reduce_mean(prior_losses)),
            ))

        exploration_critic_results = []
        for _ in range(self._exploration_num_critic_updates):
            exploration_critic_results.append(
                self._update_critic_exploration(self._training_batch()))
            self._update_target(tau=tau)
        exploration_Q_diagnostics = summarize_critic_updates(
            exploration_critic_results)

        exploitation_critic_results = []
        for _ in range(self._exploitation_num_critic_updates):
            exploitation_critic_results.append(
                self._update_critic_exploitation(self._training_batch()))
            self._update_target_exploitation(tau=tau)
        exploitation_Q_diagnostics = summarize_critic_updates(
            exploitation_critic_results)

        exploration_policy_losses = [
            self._update_actor_exploration(self._training_batch())
            for _ in range(self._exploration_num_actor_updates)
        ]
        exploration_policy_diagnostics = OrderedDict((
            ('loss-mean', tf.reduce_mean(exploration_policy_losses)),
        ))

        exploitation_policy_losses = [
            self._update_actor_exploitation(self._training_batch())
            for _ in range(self._exploitation_num_actor_updates)
        ]
        exploitation_policy_diagnostics = OrderedDict((
            ('loss-mean', tf.reduce_mean(exploitation_policy_losses)),
        ))

        if self._exploitation_policy_update_type != 'vbac':
            alpha_losses = self._update_alpha(self._training_batch())
        else:
            alpha_losses = 0.0

        diagnostics = OrderedDict((
            ('exploration-Q', exploration_Q_diagnostics),
            ('exploitation-Q', exploitation_Q_diagnostics),
            ('exploration-policy', exploration_policy_diagnostics),
            ('exploitation-policy', exploitation_policy_diagnostics),
            ('alpha', tf.convert_to_tensor(self._alpha)),
            ('alpha_loss-mean', tf.reduce_mean(alpha_losses)),
        ))

        return diagnostics

    def get_diagnostics(self,
                        iteration,
                        batch,
                        training_paths,
                        evaluation_paths):
        """Return the base-class diagnostics, plotting extra figures sometimes.

        When the training environment's ``env`` string mentions
        'MountainCarContinuous' or 'Pendulum', this also calls
        ``_plot_mountain_car_diagnostics``. The call happens on every
        iteration below 6000 and on every 1000th iteration after that.
        """
        diagnostics = super().get_diagnostics(
            iteration, batch, training_paths, evaluation_paths)

        environment_name = str(
            getattr(self._training_environment, 'env', None))
        supported_environment = any(
            name in environment_name
            for name in ('MountainCarContinuous', 'Pendulum'))
        plot_this_iteration = iteration < 6000 or iteration % 1000 == 0

        if supported_environment and plot_this_iteration:
            self._plot_mountain_car_diagnostics(
                iteration, batch, training_paths, evaluation_paths)

        return diagnostics

    def _plot_mountain_car_diagnostics(self,
                                       iteration,
                                       *args,
                                       **kwargs):
        """Draw critic/policy diagnostics into one figure and save it as PNG.

        The figure uses a 22-row grid. Its column count is the LCM of the
        number of plotted actions and 3:

        * rows 0-3: exploration critics' value surfaces, one 3D axis per
          action in ``value_and_uncertainty_actions``;
        * rows 4-7: the same for the exploitation critics;
        * rows 8-14: exploration ensemble diagnostics (uncertainty and value
          panels, each above a one-row colorbar axis);
        * rows 15-21: exploitation ensemble diagnostics, plus the
          state-support histogram in the third column.

        The figure is written to ``<cwd>/figures/<iteration:05>.png``. It is
        then closed, so pyplot does not accumulate one figure per call.

        Args:
            iteration: Training iteration; used for the file name and
                forwarded to every plotting helper.
            *args, **kwargs: Forwarded unchanged to every plotting helper.
        """
        value_and_uncertainty_actions = [-1.0, -0.5, 0.0, 0.5, 1.0]
        num_actions = len(value_and_uncertainty_actions)

        gridspec_ncols = np.lcm(num_actions, 3)
        gridspec_nrows = 22

        figure = plt.figure(
            figsize=3 * plt.figaspect(1.25), constrained_layout=True)
        try:
            gridspec = figure.add_gridspec(
                ncols=gridspec_ncols, nrows=gridspec_nrows)

            value_axis_width = gridspec_ncols // num_actions
            ensemble_axis_width = gridspec_ncols // 3

            def add_value_axes(row_start, row_stop):
                # One 3D axis per action, side by side in the given rows.
                return [
                    figure.add_subplot(
                        gridspec[row_start:row_stop,
                                 k * value_axis_width:(k + 1) * value_axis_width],
                        projection='3d')
                    for k in range(num_actions)
                ]

            def add_panel_axes(row, column):
                # A 6-row equal-aspect axis with a 1-row axis directly below.
                columns = slice(column * ensemble_axis_width,
                                (column + 1) * ensemble_axis_width)
                panel_axis = figure.add_subplot(
                    gridspec[row:row + 6, columns], aspect='equal')
                colorbar_axis = figure.add_subplot(
                    gridspec[row + 6:row + 7, columns])
                return panel_axis, colorbar_axis

            for row_start, Qs in ((0, self._Qs), (4, self._exploitation_Qs)):
                value_axes = add_value_axes(row_start, row_start + 4)
                self._plot_values_and_uncertainties(
                    iteration,
                    *args,
                    **kwargs,
                    figure=figure,
                    actions=value_and_uncertainty_actions,
                    axes=value_axes,
                    Qs=Qs)

            # Fall back to the exploration critics when there are no
            # exploitation critics.
            exploitation_Qs = self._exploitation_Qs or self._Qs
            ensemble_rows = (
                (8, self._Qs, self._exploration_policies),
                (8 + 7, exploitation_Qs, [self._policy] * len(exploitation_Qs)),
            )
            for row, Qs, policies in ensemble_rows:
                uncertainty_axis, uncertainty_colorbar_axis = add_panel_axes(
                    row, 0)
                value_axis, value_colorbar_axis = add_panel_axes(row, 1)
                self._plot_ensemble_diagnostics(
                    iteration,
                    *args,
                    figure=figure,
                    uncertainty_contour_axis=uncertainty_axis,
                    uncertainty_colorbar_axis=uncertainty_colorbar_axis,
                    value_contour_axis=value_axis,
                    value_colorbar_axis=value_colorbar_axis,
                    **kwargs,
                    Qs=Qs,
                    policies=policies)

            histogram_axis, histogram_colorbar_axis = add_panel_axes(8 + 7, 2)
            self._plot_state_support(
                iteration,
                *args,
                figure=figure,
                histogram_axis=histogram_axis,
                colorbar_axis=histogram_colorbar_axis,
                **kwargs)

            figure_dir = os.path.join(os.getcwd(), 'figures')
            os.makedirs(figure_dir, exist_ok=True)
            figure_path = os.path.join(figure_dir, f'{iteration:05}.png')
            # Save this figure explicitly rather than whichever is current.
            figure.savefig(figure_path)
        finally:
            # ``clf()`` alone keeps the figure registered with pyplot; close it
            # so repeated calls do not leak figures, even if plotting fails.
            plt.close(figure)

    def _plot_values_and_uncertainties(self,
                                       iteration,
                                       batch,
                                       training_paths,
                                       evaluation_paths,
                                       *,
                                       actions,
                                       figure,
                                       axes,
                                       Qs,
                                       resolution=50):
        """Draw each critic's Q-value surface over an observation grid.

        For every ``(axis, action)`` pair, each critic in ``Qs`` is evaluated
        at every grid point with that constant action. The result is drawn as
        a 3D surface on ``axis``. Only value surfaces are drawn; no separate
        uncertainty statistic is computed here.

        Supported environments, detected from ``str(env)``:

        * MountainCarContinuous: a grid spanning the observation-space
          bounds (plotted as position x velocity).
        * Pendulum: angle in [-pi, pi] x velocity from observation index 2.
          The critic receives ``(cos, sin, velocity)``; the angle axis is
          plotted in units of pi.

        Args:
            iteration, batch, training_paths, evaluation_paths: Unused here.
            actions: Scalar actions, paired with ``axes``.
            figure: Unused here.
            axes: 3D axes to draw into.
            Qs: Critics exposing ``values(observations, actions)``.
            resolution: Grid points per plotted dimension (default 50).

        Raises:
            NotImplementedError: For any other environment.
        """
        observation_space = self._training_environment.observation_space[
            'observations']
        low, high = observation_space.low, observation_space.high
        assert low.size == high.size, (low, high)

        environment_name = str(self._training_environment.env)
        if 'MountainCarContinuous' in environment_name:
            xy = np.stack(np.meshgrid(
                *np.split(np.linspace(low, high, resolution),
                          low.size, axis=-1),
                indexing='ij',
            ), axis=-1).reshape(-1, low.size)
            encoded_xy = xy
            xlabel = 'position'
            ylabel = 'velocity'
            xlim = low[0], high[0]
            ylim = low[1], high[1]
        elif 'Pendulum' in environment_name:
            encoded_xy = np.stack(np.meshgrid(
                *np.split(
                    np.linspace([-np.pi, low[2]], [np.pi, high[2]], resolution),
                    2, axis=-1),
                indexing='ij',
            ), axis=-1).reshape(-1, 2)
            # The critic sees the angle as (cos, sin); the plot uses the angle.
            xy = np.concatenate((
                np.cos(encoded_xy[:, [0]]),
                np.sin(encoded_xy[:, [0]]),
                encoded_xy[:, [1]],
            ), axis=-1)
            encoded_xy[:, 0] /= np.pi
            xlabel = 'angle'
            ylabel = 'velocity'
            xlim = -1.0, 1.0
            ylim = low[2], high[2]
        else:
            raise NotImplementedError(environment_name)

        grid_shape = (resolution, resolution)
        for axis, action in zip(axes, actions):
            full_actions = np.full((*xy.shape[:-1], 1), action)

            for Q in Qs:
                values = Q.values(xy, full_actions).numpy()
                axis.plot_surface(
                    encoded_xy[..., 0].reshape(grid_shape),
                    encoded_xy[..., 1].reshape(grid_shape),
                    values.reshape(grid_shape),
                    cmap='PuBuGn')

            axis.set_xlabel(xlabel)
            axis.set_ylabel(ylabel)
            axis.set_zlabel('Q')
            axis.set_xlim(*xlim)
            axis.set_ylim(*ylim)
            axis.set_title(str(action))

    def _plot_ensemble_diagnostics(self,
                                   iteration,
                                   batch,
                                   training_paths,
                                   evaluation_paths,
                                   *,
                                   figure,
                                   uncertainty_contour_axis,
                                   uncertainty_colorbar_axis,
                                   value_contour_axis,
                                   value_colorbar_axis,
                                   Qs,
                                   policies):
        """Plot the ensemble's value spread and mean over a 2-D state grid.

        Each ``(Q, policy)`` pair in ``zip(Qs, policies)`` is evaluated on a
        regular grid of observations, at the action that ``policy.actions``
        returns for that grid. Two filled-contour panels are drawn:

        * ``uncertainty_contour_axis``: standard deviation across the ensemble.
        * ``value_contour_axis``: mean across the ensemble.

        Each panel gets a horizontal colorbar in its matching
        ``*_colorbar_axis``, created via ``figure.colorbar``.

        Args:
            iteration, batch, training_paths, evaluation_paths: Not used by
                this method.
            figure: Figure that owns the colorbar axes.
            uncertainty_contour_axis, value_contour_axis: Axes for the
                contour plots.
            uncertainty_colorbar_axis, value_colorbar_axis: Axes that receive
                the colorbars.
            Qs: Iterable of Q functions called as
                ``Q.values(observations, actions)``.
            policies: Iterable of policies called as
                ``policy.actions(observations)``, paired with ``Qs``.

        Raises:
            NotImplementedError: If the training environment is neither
                ``MountainCarContinuous`` nor ``Pendulum``.
        """
        # Grid resolution per state dimension. This equals np.linspace's
        # default ``num``. It is made explicit so the grid construction and
        # the reshapes below can never drift apart.
        num_grid_points = 50
        grid_shape = (num_grid_points, num_grid_points)

        observation_space = self._training_environment.observation_space[
            'observations']
        low, high = observation_space.low, observation_space.high
        assert low.size == high.size, (low, high)

        environment_name = str(self._training_environment.env)
        if 'MountainCarContinuous' in environment_name:
            # The grid spans the observation box directly.
            xy = np.stack(np.meshgrid(
                *np.split(np.linspace(low, high, num=num_grid_points),
                          low.size, axis=-1),
                indexing='ij',
            ), axis=-1).reshape(-1, low.size)
            encoded_xy = xy
            xlabel = 'position'
            ylabel = 'velocity'
            xlim = low[0], high[0]
            ylim = low[1], high[1]
        elif 'Pendulum' in environment_name:
            # The grid is built over (angle in [-pi, pi], observation[2]).
            # The networks are fed the
            # (cos(angle), sin(angle), observation[2]) encoding.
            encoded_xy = np.stack(np.meshgrid(
                *np.split(np.linspace([-np.pi, low[2]], [np.pi, high[2]],
                                      num=num_grid_points),
                          2, axis=-1),
                indexing='ij',
            ), axis=-1).reshape(-1, 2)
            xy = np.concatenate((
                np.cos(encoded_xy[:, [0]]),
                np.sin(encoded_xy[:, [0]]),
                encoded_xy[:, [1]],
            ), axis=-1)
            # Plot the angle in units of pi, i.e. on [-1, 1].
            encoded_xy[:, 0] /= np.pi
            xlabel = 'angle'
            ylabel = 'velocity'
            xlim = -1.0, 1.0
            ylim = low[2], high[2]
        else:
            raise NotImplementedError(environment_name)

        # Stack once so both reductions below share the same tensor. The
        # leading axis indexes the ensemble member.
        Q_values = tf.stack([
            Q.values(xy, policy.actions(tf.constant(xy)))
            for Q, policy in zip(Qs, policies)
        ])

        uncertainties = tf.math.reduce_std(Q_values, axis=0).numpy()
        values = tf.math.reduce_mean(Q_values, axis=0).numpy()

        grid_x = encoded_xy[..., 0].reshape(grid_shape)
        grid_y = encoded_xy[..., 1].reshape(grid_shape)

        def draw_contour_panel(contour_axis, colorbar_axis, z, title):
            # Filled contour with thin black isolines, a horizontal colorbar,
            # and the environment-specific labels and limits.
            contourf = contour_axis.contourf(
                grid_x,
                grid_y,
                z.reshape(grid_shape),
                levels=25,
                cmap='PuBuGn')
            contour_axis.contour(contourf, colors='black', linewidths=0.5)
            figure.colorbar(
                contourf,
                cax=colorbar_axis,
                orientation='horizontal')
            contour_axis.set_title(title)
            contour_axis.set_xlabel(xlabel)
            contour_axis.set_ylabel(ylabel)
            contour_axis.set_xlim(*xlim)
            contour_axis.set_ylim(*ylim)

        draw_contour_panel(
            uncertainty_contour_axis,
            uncertainty_colorbar_axis,
            uncertainties,
            'uncertainty')
        draw_contour_panel(
            value_contour_axis,
            value_colorbar_axis,
            values,
            'values')

    def _plot_state_support(self,
                            iteration,
                            batch,
                            training_paths,
                            evaluation_paths,
                            *,
                            figure,
                            histogram_axis,
                            colorbar_axis):
        """Plot a 2-D histogram of the observations stored in the replay pool.

        Reads the first ``self.pool.size`` rows of
        ``self.pool.data['observations']['observations']``. It maps them to a
        2-D (x, y) encoding that depends on the environment, and draws a
        100x100-bin ``hist2d`` on ``histogram_axis``, with a horizontal
        colorbar in ``colorbar_axis``. For ``MountainCarContinuous``, a dotted
        vertical line is added at the (rescaled) goal position.

        Args:
            iteration, batch, training_paths, evaluation_paths: Not used by
                this method.
            figure: Figure used to create the colorbar.
            histogram_axis: Axis that receives the 2-D histogram.
            colorbar_axis: Axis that receives the colorbar.

        Raises:
            NotImplementedError: If the training environment is neither
                ``MountainCarContinuous`` nor ``Pendulum``.
        """
        observation_space = self._training_environment.observation_space[
            'observations']
        low, high = observation_space.low, observation_space.high
        # Only a pool holding a single 'observations' field is handled here.
        assert set(self.pool.data['observations'].keys()) == {'observations'}
        observations = self.pool.data[
            'observations']['observations'][:self.pool.size]

        environment = self._training_environment.env
        environment_name = str(environment)
        if 'MountainCarContinuous' in environment_name:
            encoded_observations = observations
            xlabel = 'position'
            ylabel = 'velocity'
            xlim = low[0], high[0]
            ylim = low[1], high[1]
        elif 'Pendulum' in environment_name:
            # Recover the angle from columns (0, 1), treated as (cos, sin).
            # Scale it by 1/pi so it lies in [-1, 1], and keep column 2
            # onward as-is.
            encoded_observations = np.concatenate((
                np.arctan2(
                    observations[:, 1],
                    observations[:, 0],
                )[..., None] / np.pi,
                observations[:, 2:],
            ), axis=-1)
            xlabel = 'angle'
            ylabel = 'velocity'
            xlim = -1.0, 1.0
            ylim = low[2], high[2]
        else:
            raise NotImplementedError(environment_name)

        # PowerNorm with gamma < 1 compresses the count range, so rarely
        # visited bins stay visible next to heavily visited ones.
        *_, image = histogram_axis.hist2d(
            encoded_observations[..., 0],
            encoded_observations[..., 1],
            bins=100,
            cmap='PuBuGn',
            norm=mpl.colors.PowerNorm(0.1))

        if 'MountainCarContinuous' in environment_name:
            from softlearning.environments.gym.wrappers.rescale_observation import rescale_values
            # Presumably maps the goal position from the unwrapped env's
            # observation range into the wrapped (rescaled) range used for
            # plotting. Verify against ``rescale_values``.
            goal_x_position = rescale_values(
                environment.goal_position,
                environment.unwrapped.observation_space.low[0],
                environment.unwrapped.observation_space.high[0],
                environment.observation_space.low[0],
                environment.observation_space.high[0])
            histogram_axis.axvline(x=goal_x_position, linestyle=':')

        figure.colorbar(
            image,
            cax=colorbar_axis,
            orientation='horizontal',
        )

        histogram_axis.set_title("State support")
        histogram_axis.set_xlabel(xlabel)
        histogram_axis.set_ylabel(ylabel)
        histogram_axis.set_xlim(*xlim)
        histogram_axis.set_ylim(*ylim)

