import argparse
import glob
import os
import pickle
from typing import Dict

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import pandas as pd
from ray import tune
import tree

import seaborn as sns
import matplotlib as mpl
# mpl.use('macosx')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import FuncFormatter

from softlearning.environments.utils import get_environment_from_params
from softlearning.environments.gym.wrappers.rescale_observation import (
    rescale_values)
from softlearning import policies
from softlearning import value_functions
from softlearning import replay_pools
from softlearning.samplers import rollouts
from softlearning.utils.tensorflow import set_gpu_memory_growth
from softlearning.utils.video import save_video
from softlearning.utils.serialization import custom_object_scope

from bbac.runners.bbac.main import ExperimentRunner
from bbac.value_functions import (
    random_prior_ensemble_feedforward_Q_function,
    random_prior_feedforward_Q_function,
    bayesian_feedforward_Q_function,
    flipout_feedforward_Q_function)


def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        The ``argparse.Namespace`` produced by the parser; it carries one
        attribute, ``experiment_path`` (a string).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        'experiment_path', type=str, help='Path to the experiment.')
    return cli.parse_args()


def plot_values_and_uncertainties(*,
                                  environment,
                                  figure,
                                  axis,
                                  Qs,
                                  policies):
    """Draw one Q-value surface per (Q, policy) pair on a 3D axis.

    A regular ``nx`` x ``ny`` grid is laid over a two-dimensional encoding of
    the state. Each policy picks actions for every grid state, the paired
    Q-function is evaluated on those state-action pairs, and the result is
    drawn with ``axis.plot_surface``. All surfaces share one color scale
    (the global min/max over every Q).

    Args:
        environment: Object exposing ``observation_space['observations']``
            (with ``low``/``high`` arrays) and ``env``. Only environments
            whose ``str(environment.env)`` contains
            ``'MountainCarContinuous'`` or ``'Pendulum'`` are supported.
        figure: Unused by this function.
        axis: Axis providing ``plot_surface`` (i.e. a 3D axis).
        Qs: Sequence of Q-functions exposing ``values(observations, actions)``.
        policies: Sequence of policies exposing ``actions(observations)``;
            must have the same length as ``Qs``.

    Raises:
        NotImplementedError: If the environment is not one of the two
            supported ones.
        AssertionError: If ``Qs`` and ``policies`` differ in length, or if
            the observation bounds differ in size.
    """
    observation_space = environment.observation_space['observations']
    low, high = observation_space.low, observation_space.high
    assert low.size == high.size, (low, high)

    # A single `np.linspace` call supplies the samples for both grid axes, so
    # both axes get `nx` points and the reshapes below need `nx == ny`.
    nx = ny = 16
    environment_name = str(environment.env)
    if 'MountainCarContinuous' in environment_name:
        xy = np.stack(np.meshgrid(
            *np.split(np.linspace(low, high, nx), low.size, axis=-1),
            indexing='ij',
        ), axis=-1).reshape(-1, low.size)
        # The observation is plotted as-is: no separate encoding is needed.
        encoded_xy = xy
        xlabel = 'position'
        ylabel = 'velocity'

        # NOTE(review): presumably converts the goal position from the raw
        # observation range to the wrapped one; `rescale_values` is defined
        # outside this file -- confirm.
        goal_x_position = rescale_values(
            environment.env.goal_position,
            environment.env.unwrapped.observation_space.low[0],
            environment.env.unwrapped.observation_space.high[0],
            environment.env.observation_space.low[0],
            environment.env.observation_space.high[0])

        xlim = low[0], goal_x_position
        ylim = low[1], high[1]
    elif 'Pendulum' in environment_name:
        # `num` must be given explicitly: `np.linspace` defaults to 50
        # samples, which produced a 50 x 50 grid that could not be reshaped
        # to (nx, ny) further down.
        encoded_xy = np.stack(np.meshgrid(
            *np.split(
                np.linspace([-np.pi, low[2]], [np.pi, high[2]], nx),
                2,
                axis=-1),
            indexing='ij',
        ), axis=-1).reshape(-1, 2)
        # The models are fed (cos(angle), sin(angle), velocity) ...
        xy = np.concatenate((
            np.cos(encoded_xy[:, [0]]),
            np.sin(encoded_xy[:, [0]]),
            encoded_xy[:, [1]],
        ), axis=-1)
        # ... while the plot uses the angle in units of pi, i.e. [-1, 1].
        encoded_xy[:, 0] /= np.pi
        xlabel = 'angle'
        ylabel = 'velocity'
        xlim = -1.0, 1.0
        ylim = low[2], high[2]
    else:
        raise NotImplementedError(environment_name)

    np.testing.assert_equal(len(Qs), len(policies))

    Qs_values = np.stack([
        Q.values(xy, policy.actions(tf.constant(xy)))
        for Q, policy in zip(Qs, policies)
    ])

    # Shared color limits so that the surfaces are comparable to each other.
    min_Q_value = Qs_values.min()
    max_Q_value = Qs_values.max()

    for Q_values in Qs_values:
        axis.plot_surface(
            encoded_xy[..., 0].reshape(nx, ny),
            encoded_xy[..., 1].reshape(nx, ny),
            Q_values.reshape(nx, ny),
            cmap='PuBuGn',
            vmin=min_Q_value,
            vmax=max_Q_value,
            alpha=0.9,
            linewidth=0.1,
            edgecolors='black')

    axis.tick_params(axis='x', which='major', pad=-6)
    axis.tick_params(axis='y', which='major', pad=-6)
    axis.tick_params(axis='z', which='major', rotation=45, pad=-2)

    # Fully transparent panes (alpha 0.0) behind the surfaces.
    axis.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    axis.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    axis.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    axis.set_xlabel(xlabel, labelpad=-10)
    axis.set_ylabel(ylabel, labelpad=-10)

    # Relies on a private Matplotlib attribute of the z axis.
    axis.zaxis._axinfo['label']['space_factor'] = 10

    axis.set_xlim(*xlim)
    axis.set_ylim(*ylim)


def plot_state_support(*,
                       environment,
                       figure,
                       pool,
                       timestep,
                       histogram_axis,
                       colorbar_axis):
    """Plot a 2D histogram of the states stored in the replay pool.

    The first ``timestep`` observations of the pool are binned into an
    80 x 80 histogram drawn on ``histogram_axis``; a vertical colorbar is
    attached to the right of ``colorbar_axis``.

    Args:
        environment: Object exposing ``observation_space['observations']``
            (with ``low``/``high`` arrays) and ``env``. Only environments
            whose ``str(environment.env)`` contains
            ``'MountainCarContinuous'`` or ``'Pendulum'`` are supported.
        figure: Figure that receives the colorbar axis.
        pool: Replay pool whose ``data['observations']`` mapping has exactly
            one key, ``'observations'``.
        timestep: Number of leading pool entries to include.
        histogram_axis: Axis on which the histogram is drawn.
        colorbar_axis: Axis next to which the colorbar axis is created.

    Raises:
        NotImplementedError: If the environment is not one of the two
            supported ones.
        AssertionError: If the pool holds observation keys other than
            ``'observations'``.
    """
    from matplotlib.ticker import MaxNLocator
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    observation_space = environment.observation_space['observations']
    low, high = observation_space.low, observation_space.high
    assert set(pool.data['observations'].keys()) == {'observations'}
    observations = pool.data['observations']['observations'][:timestep]

    environment_name = str(environment.env)
    if 'MountainCarContinuous' in environment_name:
        encoded_observations = observations
        xlabel = 'position'
        ylabel = 'velocity'

        # NOTE(review): presumably converts the goal position from the raw
        # observation range to the wrapped one; `rescale_values` is defined
        # outside this file -- confirm.
        goal_x_position = rescale_values(
            environment.env.goal_position,
            environment.env.unwrapped.observation_space.low[0],
            environment.env.unwrapped.observation_space.high[0],
            environment.env.observation_space.low[0],
            environment.env.observation_space.high[0])

        xlim = low[0], goal_x_position
        ylim = low[1], high[1]
    elif 'Pendulum' in environment_name:
        # Assumes columns 0 and 1 hold cos and sin of the angle -- TODO
        # confirm. The recovered angle is expressed in units of pi.
        encoded_observations = np.concatenate((
            np.arctan2(
                observations[:, 1],
                observations[:, 0],
            )[..., None] / np.pi,
            observations[:, 2:],
        ), axis=-1)
        xlabel = 'angle'
        ylabel = 'velocity'
        xlim = -1.0, 1.0
        ylim = low[2], high[2]
    else:
        raise NotImplementedError(environment_name)

    # Color limits live on the norm only: Matplotlib documents that passing
    # `vmin`/`vmax` together with a norm instance is an error.
    # NOTE(review): the previous call passed `vmin=0` in addition to
    # `PowerNorm(0.5, vmin=-5, vmax=9)`. Older Matplotlib releases appear to
    # have applied the keyword on top of the norm, so 0 is kept as the lower
    # limit here -- confirm against the intended figure.
    *_, image = histogram_axis.hist2d(
        encoded_observations[..., 0],
        encoded_observations[..., 1],
        bins=80,
        cmin=1,  # bins with fewer than one sample are left blank
        cmap='PuBuGn',
        norm=mpl.colors.PowerNorm(0.5, vmin=0, vmax=9),
    )

    divider = make_axes_locatable(colorbar_axis)
    # Width of the colorbar axis as a percentage (about 16.7%).
    # NOTE(review): origin of the 0.8333... constant is not visible here.
    cax_size = (1.0 - 0.8333333134651191) * 100.0
    cax = divider.new_horizontal(size=f"{cax_size}%", pad=0, pack_start=False)
    figure.add_axes(cax)

    colorbar = figure.colorbar(
        image,
        cax=cax,
        orientation='vertical',
        extend='max',
        extendrect=True,
        ticks=MaxNLocator(integer=True)
    )

    # Replace the locator given above with a coarser integer one.
    colorbar.set_ticks(MaxNLocator(integer=True, nbins=4))

    histogram_axis.set_xlabel(xlabel)
    histogram_axis.set_ylabel(ylabel)
    histogram_axis.set_xlim(*xlim)
    histogram_axis.set_ylim(*ylim)


def load_variant_progress_metadata(checkpoint_path):
    """Load the variant, progress table and Tune metadata for a checkpoint.

    The trial directory is taken to be the parent of ``checkpoint_path``
    (trailing ``/`` characters are ignored).

    Args:
        checkpoint_path: Path of a checkpoint directory inside a trial
            directory.

    Returns:
        A ``(variant, progress, metadata)`` tuple: the unpickled
        ``params.pkl`` of the trial, ``progress.csv`` of the trial as a
        pandas DataFrame, and the unpickled ``.tune_metadata`` of the
        checkpoint, or ``None`` when that file does not exist.
    """
    # NOTE: pickle files are loaded as-is; only use on trusted experiments.
    checkpoint_dir = checkpoint_path.rstrip('/')
    trial_dir = os.path.dirname(checkpoint_dir)

    with open(os.path.join(trial_dir, 'params.pkl'), 'rb') as variant_file:
        variant = pickle.load(variant_file)

    metadata = None
    tune_metadata_file = os.path.join(checkpoint_dir, ".tune_metadata")
    if os.path.exists(tune_metadata_file):
        with open(tune_metadata_file, "rb") as metadata_file:
            metadata = pickle.load(metadata_file)

    progress = pd.read_csv(os.path.join(trial_dir, 'progress.csv'))

    return variant, progress, metadata


def load_environment(variant):
    """Build the training environment described by ``variant``.

    Args:
        variant: Mapping with an ``'environment_params'`` entry that contains
            a ``'training'`` entry.

    Returns:
        The environment created by ``get_environment_from_params`` from the
        training parameters.

    Raises:
        KeyError: If ``'environment_params'`` or its ``'training'`` entry is
            missing.
    """
    # NOTE(review): this used to be a conditional on
    # `'evaluation' in variant['environment_params']` whose two arms both
    # selected 'training'. The dead branch is removed without changing
    # behavior; if the evaluation environment was intended when present,
    # select 'evaluation' here instead -- confirm with the author.
    environment_params = variant['environment_params']['training']
    return get_environment_from_params(environment_params)


def load_policy(checkpoint_dir, variant, environment):
    """Instantiate the policy described by ``variant`` and restore weights.

    The policy config from the variant is extended with the action range and
    the input/output shapes taken from ``environment``. The variant itself is
    left untouched: only a shallow copy of ``variant['policy_params']`` gets
    the new config.

    Args:
        checkpoint_dir: Checkpoint directory passed to
            ``ExperimentRunner._policy_save_path`` to locate the weights.
        variant: Mapping with a ``'policy_params'`` entry holding a
            ``'config'`` mapping.
        environment: Environment exposing ``action_space``,
            ``observation_shape`` and ``action_shape``.

    Returns:
        The policy with its weights restored.
    """
    policy_params = variant['policy_params'].copy()

    merged_config = dict(policy_params['config'])
    merged_config.update(
        action_range=(environment.action_space.low,
                      environment.action_space.high),
        input_shapes=environment.observation_shape,
        output_shape=environment.action_shape,
    )
    policy_params['config'] = merged_config

    restored_policy = policies.get(policy_params)

    weights_path = ExperimentRunner._policy_save_path(checkpoint_dir)
    restore_status = restored_policy.load_weights(weights_path)
    # Fail loudly if some saved values were not matched to the policy.
    restore_status.assert_consumed().run_restore_ops()

    return restored_policy


def load_bbac_models(variant, logdir):
    """Construct the environment, replay pool, policies and Qs of a BBAC run.

    The configs inside ``variant`` are updated in place with the shapes and
    action range taken from the environment. Weights are NOT loaded here;
    only the replay pool contents are restored from ``logdir``.

    Args:
        variant: Experiment variant with ``'Q_params'``,
            ``'exploitation_Q_params'``, ``'policy_params'``,
            ``'exploration_policy_params'`` and ``'replay_pool_params'``
            entries, each holding a ``'config'`` mapping.
        logdir: Directory containing ``replay_pool.pkl``.

    Returns:
        A tuple ``(environment, replay_pool, exploitation_policy,
        exploration_policies, exploration_Qs, exploitation_Qs)``. One
        exploration policy is created per exploration Q.
    """
    environment = load_environment(variant)
    Q_input_shapes = (
        environment.observation_shape,
        environment.action_shape)
    policy_input_shapes = environment.observation_shape

    for Q_params_key in ('Q_params', 'exploitation_Q_params'):
        variant[Q_params_key]['config'].update({
            'input_shapes': Q_input_shapes,
        })

    # Names under which the bbac Q-function constructors are made available
    # while the value functions are being built.
    bbac_Q_functions = {
        'random_prior_ensemble_feedforward_Q_function':
        random_prior_ensemble_feedforward_Q_function,
        'random_prior_feedforward_Q_function':
        random_prior_feedforward_Q_function,
        'bayesian_feedforward_Q_function':
        bayesian_feedforward_Q_function,
        'flipout_feedforward_Q_function':
        flipout_feedforward_Q_function,
    }
    with custom_object_scope(bbac_Q_functions):
        exploration_Qs = tree.flatten(
            value_functions.get(variant['Q_params']))
        exploitation_Qs = tree.flatten(
            value_functions.get(variant['exploitation_Q_params']))

    for policy_params_key in ('policy_params', 'exploration_policy_params'):
        variant[policy_params_key]['config'].update({
            'action_range': (environment.action_space.low,
                             environment.action_space.high),
            'input_shapes': policy_input_shapes,
            'output_shape': environment.action_shape,
        })

    exploitation_policy = policies.get(variant['policy_params'])
    exploration_policies = [
        policies.get(variant['exploration_policy_params'])
        for _ in exploration_Qs
    ]

    variant['replay_pool_params']['config'].update({
        'environment': environment,
    })
    replay_pool = replay_pools.get(variant['replay_pool_params'])
    replay_pool.load_experience(os.path.join(logdir, 'replay_pool.pkl'))

    return (
        environment,
        replay_pool,
        exploitation_policy,
        exploration_policies,
        exploration_Qs,
        exploitation_Qs,
    )


def load_sac_models(variant, logdir):
    """Construct the environment, replay pool, policy and Qs of a SAC run.

    The configs inside ``variant`` are updated in place with the shapes and
    action range taken from the environment. Weights are NOT loaded here;
    only the replay pool contents are restored from ``logdir``.

    Args:
        variant: Experiment variant with ``'Q_params'``, ``'policy_params'``
            and ``'replay_pool_params'`` entries, each holding a ``'config'``
            mapping.
        logdir: Directory containing ``replay_pool.pkl``.

    Returns:
        A tuple ``(environment, replay_pool, policy, Qs)`` where ``Qs`` is
        the flattened list of Q-functions.
    """
    environment = load_environment(variant)
    Q_input_shapes = (
        environment.observation_shape,
        environment.action_shape)
    policy_input_shapes = environment.observation_shape

    Q_config = variant['Q_params']['config']
    Q_config.update({'input_shapes': Q_input_shapes})
    Qs = tree.flatten(value_functions.get(variant['Q_params']))

    policy_config = variant['policy_params']['config']
    policy_config.update({
        'action_range': (environment.action_space.low,
                         environment.action_space.high),
        'input_shapes': policy_input_shapes,
        'output_shape': environment.action_shape,
    })
    policy = policies.get(variant['policy_params'])

    pool_config = variant['replay_pool_params']['config']
    pool_config.update({'environment': environment})
    replay_pool = replay_pools.get(variant['replay_pool_params'])
    replay_pool.load_experience(os.path.join(logdir, 'replay_pool.pkl'))

    return (
        environment,
        replay_pool,
        policy,
        Qs,
    )


def initialize_figure_and_axes(num_checkpoints):
    """Create the figure with one column of axes per checkpoint.

    The figure is split by a 20-row grid: the top half of every column holds
    a 3D (orthographic) axis for the value surfaces, the bottom half a 2D
    axis with equal aspect for the state-support histogram.

    Args:
        num_checkpoints: Number of columns to create.

    Returns:
        A tuple ``(figure, value_axes, state_support_axes, colorbar_axes)``;
        ``colorbar_axes`` is the very same list as ``state_support_axes``.
    """
    # The column count is a common multiple of `num_checkpoints` and 3, so
    # it divides evenly among the checkpoints.
    gridspec_ncols = np.lcm(num_checkpoints, 3)
    gridspec_nrows = 20

    figure_scale = 1
    figure = plt.figure(
        figsize=figure_scale * plt.figaspect(2.0/num_checkpoints))
    gridspec = figure.add_gridspec(
        ncols=gridspec_ncols, nrows=gridspec_nrows)

    axis_width = gridspec_ncols // num_checkpoints
    colorbar_axis_height = 0
    value_axis_height = (gridspec_nrows - colorbar_axis_height) // 2
    state_support_axis_height = value_axis_height

    column_spans = [
        slice(column * axis_width, (column + 1) * axis_width)
        for column in range(num_checkpoints)
    ]
    value_rows = slice(0, value_axis_height)
    state_support_rows = slice(
        value_axis_height, value_axis_height + state_support_axis_height)

    # All value axes are added to the figure before any state-support axis.
    value_axes = []
    for column_span in column_spans:
        value_axis = figure.add_subplot(
            gridspec[value_rows, column_span],
            projection='3d',
            proj_type='ortho',
            azim=-45,
            elev=25,
        )
        value_axes.append(value_axis)

    for value_axis in value_axes:
        value_axis.dist = 9

    state_support_axes = []
    for column_span in column_spans:
        state_support_axes.append(
            figure.add_subplot(gridspec[state_support_rows, column_span]))

    for state_support_axis in state_support_axes:
        state_support_axis.set_aspect('equal', 'box')

    colorbar_axes = state_support_axes

    return figure, value_axes, state_support_axes, colorbar_axes


def visualize_bbac(trial_variant,
                   trial_folder,
                   trial_checkpoint_folders,
                   figure,
                   value_axes,
                   state_support_axes,
                   colorbar_axes):
    """Plot values and state support for each checkpoint of a BBAC trial.

    The models are built once; for every checkpoint their weights are
    reloaded and column ``i`` of the axes is filled: the exploration Qs
    (evaluated under the exploration policies) on ``value_axes[i]`` and the
    replay pool histogram on ``state_support_axes[i]``.

    Args:
        trial_variant: Variant of the trial, forwarded to
            ``load_bbac_models``.
        trial_folder: Trial directory, forwarded to ``load_bbac_models``.
        trial_checkpoint_folders: Checkpoint directories whose last path
            component looks like ``<prefix>_<index>``.
        figure: Figure the axes belong to.
        value_axes: One 3D axis per checkpoint.
        state_support_axes: One 2D axis per checkpoint.
        colorbar_axes: One axis per checkpoint to attach the colorbar to.

    Raises:
        FileNotFoundError: If a checkpoint lacks one of the
            ``exploitation-Q``, ``exploration-Q``, ``exploitation-policy`` or
            ``exploration-policy`` entries.
    """
    (environment,
     replay_pool,
     exploitation_policy,
     exploration_policies,
     exploration_Qs,
     exploitation_Qs) = load_bbac_models(trial_variant, trial_folder)

    def load_policy_weights(policy, save_path):
        """Restore policy weights, failing if saved values are left over."""
        status = policy.load_weights(save_path)
        status.assert_consumed().run_restore_ops()

    def weights_name(path):
        """Turn a tree path such as ``(0, 1)`` into the name ``'0-1'``."""
        return '-'.join([str(x) for x in path])

    for i, trial_checkpoint_folder in enumerate(trial_checkpoint_folders):
        checkpoint_index = int(
            trial_checkpoint_folder.split('/')[-1].split('_')[1])
        # The pool prefix length is derived as 1000 entries per checkpoint
        # index.
        # NOTE(review): presumably 1000 environment steps per epoch --
        # confirm against the experiment configuration.
        timestep = checkpoint_index * 1000

        exploitation_Q_path = os.path.join(
            trial_checkpoint_folder, 'exploitation-Q')
        exploration_Q_path = os.path.join(
            trial_checkpoint_folder, 'exploration-Q')
        exploitation_policy_path = os.path.join(
            trial_checkpoint_folder, 'exploitation-policy')
        exploration_policy_path = os.path.join(
            trial_checkpoint_folder, 'exploration-policy')

        # Validate everything up front with real exceptions: `assert` would
        # be stripped when Python runs with -O.
        for required_path in (exploitation_Q_path,
                              exploration_Q_path,
                              exploitation_policy_path,
                              exploration_policy_path):
            if not os.path.exists(required_path):
                raise FileNotFoundError(
                    f"Missing checkpoint entry: {required_path}")

        load_policy_weights(
            exploitation_policy, os.path.join(exploitation_policy_path, '0'))

        tree.map_structure_with_path(
            lambda path, policy: load_policy_weights(
                policy,
                os.path.join(exploration_policy_path, weights_name(path))),
            exploration_policies)

        tree.map_structure_with_path(
            lambda path, Q: Q.load_weights(
                os.path.join(exploitation_Q_path, weights_name(path))),
            exploitation_Qs)

        tree.map_structure_with_path(
            lambda path, Q: Q.load_weights(
                os.path.join(exploration_Q_path, weights_name(path))),
            exploration_Qs)

        plot_values_and_uncertainties(
            environment=environment,
            figure=figure,
            axis=value_axes[i],
            Qs=exploration_Qs,
            policies=exploration_policies)

        plot_state_support(
            environment=environment,
            pool=replay_pool,
            timestep=timestep,
            figure=figure,
            histogram_axis=state_support_axes[i],
            colorbar_axis=colorbar_axes[i])


def visualize_sac(trial_variant,
                  trial_folder,
                  trial_checkpoint_folders,
                  figure,
                  value_axes,
                  state_support_axes,
                  colorbar_axes):
    """Plot values and state support for each checkpoint of a SAC trial.

    The models are built once; for every checkpoint the policy and Q weights
    are reloaded and column ``column`` of the axes is filled. The single
    policy is paired with every Q-function for the value surfaces.

    Args:
        trial_variant: Variant of the trial, forwarded to
            ``load_sac_models``.
        trial_folder: Trial directory, forwarded to ``load_sac_models``.
        trial_checkpoint_folders: Checkpoint directories whose last path
            component looks like ``<prefix>_<index>``.
        figure: Figure the axes belong to.
        value_axes: One 3D axis per checkpoint.
        state_support_axes: One 2D axis per checkpoint.
        colorbar_axes: One axis per checkpoint to attach the colorbar to.
    """
    environment, replay_pool, policy, Qs = load_sac_models(
        trial_variant, trial_folder)

    def restore_policy(target_policy, save_path):
        """Restore policy weights, failing if saved values are left over."""
        target_policy.load_weights(
            save_path).assert_consumed().run_restore_ops()

    for column, checkpoint_folder in enumerate(trial_checkpoint_folders):
        folder_name = checkpoint_folder.split('/')[-1]
        # The pool prefix length is 1000 entries per checkpoint index.
        timestep = int(folder_name.split('_')[1]) * 1000

        restore_policy(policy, os.path.join(checkpoint_folder, 'policy'))

        def restore_Q(path, Q):
            """Load the weights stored under ``Q-<path joined by '-'>``."""
            weights_name = 'Q-' + '-'.join([str(x) for x in path])
            return Q.load_weights(
                os.path.join(checkpoint_folder, weights_name))

        tree.map_structure_with_path(restore_Q, Qs)

        plot_values_and_uncertainties(
            environment=environment,
            figure=figure,
            axis=value_axes[column],
            Qs=Qs,
            policies=[policy] * len(Qs))

        plot_state_support(
            environment=environment,
            pool=replay_pool,
            timestep=timestep,
            figure=figure,
            histogram_axis=state_support_axes[column],
            colorbar_axis=colorbar_axes[column])


def main() -> None:
    """Render the value / state-support figure for one trial.

    The first trial found under the experiment path is used. Of its
    checkpoints, only those with index 1, 3, 9, 27 or 81 are plotted, one
    figure column each, in increasing index order. The figure is written to
    ``/tmp/what.pdf``.

    Raises:
        ValueError: If the experiment contains no trial, if the trial has
            none of the selected checkpoints, or if the algorithm class of
            the trial is neither ``'BayesianBellmanActorCritic'`` nor
            ``'SAC'``.
    """
    set_gpu_memory_growth(True)
    args = parse_args()

    experiment_path = args.experiment_path
    analysis = tune.Analysis(experiment_path)

    trial_configs = analysis.get_all_configs()
    trial_folders = list(trial_configs.keys())
    if not trial_folders:
        raise ValueError(
            f"No trials found under experiment path: {experiment_path}")

    trial_folder = trial_folders[0]
    trial_variant = trial_configs[trial_folder]

    def checkpoint_index(checkpoint_folder):
        """Return the index N of a ``.../<prefix>_N`` checkpoint folder."""
        return int(checkpoint_folder.split('/')[-1].split('_')[1])

    selected_checkpoint_indices = (1, 3, 9, 27, 81)
    all_checkpoint_folders = sorted(
        (
            os.path.dirname(marker_path) for marker_path in
            glob.glob(os.path.join(trial_folder, 'checkpoint_*/.is_checkpoint'))
        ),
        key=checkpoint_index)
    trial_checkpoint_folders = [
        checkpoint_folder for checkpoint_folder in all_checkpoint_folders
        if checkpoint_index(checkpoint_folder) in selected_checkpoint_indices
    ]
    if not trial_checkpoint_folders:
        # Without this guard the figure setup divides by zero.
        raise ValueError(
            f"No checkpoint with index in {selected_checkpoint_indices} "
            f"found in trial folder: {trial_folder}")

    figure, value_axes, state_support_axes, colorbar_axes = (
        initialize_figure_and_axes(
            num_checkpoints=len(trial_checkpoint_folders)))

    algorithm_class = trial_variant['algorithm_params']['class_name']
    if algorithm_class == 'BayesianBellmanActorCritic':
        visualize = visualize_bbac
    elif algorithm_class == 'SAC':
        visualize = visualize_sac
    else:
        raise ValueError(algorithm_class)

    visualize(trial_variant,
              trial_folder,
              trial_checkpoint_folders,
              figure,
              value_axes,
              state_support_axes,
              colorbar_axes)

    # Only the left-most histogram keeps its y-axis label.
    for axis in state_support_axes[1:]:
        axis.yaxis.label.set_visible(False)

    timesteps = [
        1000 * checkpoint_index(checkpoint_folder)
        for checkpoint_folder in trial_checkpoint_folders
    ]
    for timestep, axis in zip(timesteps, value_axes):
        axis.set_title(f'$T={timestep:d}$')

    # Raw string: the LaTeX backslashes are not valid Python escape
    # sequences, so a plain literal triggers an invalid-escape warning.
    # The text passed to Matplotlib is unchanged.
    figure.text(
        -0.4,
        0.5,
        r"$\mathbf{\mathbb{E}}_{a\sim\pi}[Q(\cdot,a)]$",
        verticalalignment='center',
        rotation='vertical',
        transform=value_axes[0].transAxes,
    )

    plt.tight_layout()
    plt.savefig('/tmp/what.pdf')
    figure.clf()


# Run the visualization only when this file is executed as a script.
if __name__ == '__main__':
    main()
