import argparse
from pathlib import Path
from pprint import pprint

import os

import numpy as np
from ray import tune
import tree
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


# Legend labels of the compared algorithms. ``plot_experiment`` passes this
# tuple to seaborn as ``hue_order``, so it fixes the hue-level ordering.
LABELS = ("TDC",
          "Direct BBO w/o prior (TD(0))",
          "Direct BBO w/ prior",
          "Gradient BBO w/o prior",
          "Gradient BBO w/ prior")

# Label -> colour, paired positionally with the active seaborn palette.
# NOTE(review): not passed to the plots in this file; they rely on seaborn's
# default palette together with ``hue_order=LABELS``.
COLORS = dict(zip(LABELS, sns.color_palette()))

# NOTE(review): not referenced anywhere in this file.
GREEK_ALPHABET = ("alpha", "beta", "gamma", "epsilon")

# Root of the ray-tune results. Experiments live at
# BASE_DIR / <environment_id> / <experiment_id> (see EXPERIMENT_PATHS).
BASE_DIR = Path('~/Desktop/bbo/ablation').expanduser()

# Subplot title shown for each environment id.
FIGURE_TITLES = {
    'n_link_pendulum': '20-Link Pendulum',
    'puddle_world': 'Puddle World',
    'mountain_car': 'Mountain Car',
}

# Experiment directory names (one per algorithm variant) for each environment.
# NOTE(review): the n_link_pendulum TD(0) run is the only id without a
# '-final' suffix; confirm it is the intended run.
EXPERIMENT_IDS = {
    'n_link_pendulum': (
        '2020-06-04T12-01-13-gradient-bbo-with-regularization-final',
        '2020-06-04T14-03-34-gradient-bbo-without-regularization-final',
        '2020-06-04T19-44-49-direct-bbo-with-regularization-final',
        '2020-06-04T20-14-00-td0',
        '2020-05-29T11-22-24-tdc-final',
    ),
    'puddle_world': (
        '2020-05-03T14-18-32-td0-final',
        '2020-06-04T10-41-48-gradient-bbo-with-regularization-final',
        '2020-06-04T13-51-50-gradient-bbo-without-regularization-final',
        '2020-06-04T20-01-51-direct-bbo-with-regularization-final',
        '2020-05-24T10-39-37-tdc-final',
    ),
    'mountain_car': (
        '2020-05-03T22-23-50-td0-final',
        '2020-06-03T19-09-19-gradient-bbo-without-regularization-final',
        '2020-06-03T20-29-53-gradient-bbo-with-regularization-final',
        '2020-06-03T22-30-15-direct-bbo-with-regularization-final',
        '2020-05-28T11-21-58-tdc-final',
    ),
}


# Same shape as EXPERIMENT_IDS (environment id -> tuple), with every id
# resolved to its full directory path under BASE_DIR.
EXPERIMENT_PATHS = {
    environment_id: tuple(
        BASE_DIR / environment_id / experiment_id
        for experiment_id in experiment_ids
    )
    for environment_id, experiment_ids in EXPERIMENT_IDS.items()
}


def plot_experiment(visualization_dataframe,
                    metric_to_use,
                    axis,
                    legend=False,
                    x_units='thousands'):
    """Draw one learning-curve panel: ``metric_to_use`` vs. training steps.

    One line is drawn per value of the ``'label'`` column, ordered by
    ``LABELS``.

    Args:
        visualization_dataframe: Long-form frame with at least the columns
            ``'training_steps'``, ``'label'`` and ``metric_to_use``. The
            frame is not modified.
        metric_to_use: Name of the column plotted on the y-axis.
        axis: Matplotlib axes to draw on.
        legend: Forwarded to ``seaborn.lineplot(legend=...)``.
        x_units: ``'thousands'`` (default) divides the step counts by 1000
            and labels the axis accordingly; ``'steps'`` plots them unscaled.

    Raises:
        ValueError: If ``x_units`` is not one of the supported values.
    """
    if x_units == 'thousands':
        x_scale, x_label = 1000, 'Training Steps [$10^3$]'
    elif x_units == 'steps':
        x_scale, x_label = 1, 'Training Steps'
    else:
        raise ValueError(f"Unsupported x_units: {x_units!r}")

    # Rescale on a copy so the caller's frame keeps its original step counts.
    plot_dataframe = visualization_dataframe.assign(
        training_steps=visualization_dataframe['training_steps'] / x_scale)

    sns.lineplot(
        x='training_steps',
        y=metric_to_use,
        hue='label',
        hue_order=LABELS,
        data=plot_dataframe,
        legend=legend,
        ax=axis,
    )

    # Clamp both y-limits to be >= 0; the error metrics plotted here
    # (MSE/RMSE) cannot be negative, so negative space is wasted.
    axis.set_ylim(np.clip(axis.get_ylim(), 0.0, float('inf')))
    axis.set_xlabel(x_label)


def _classify_algorithm(algorithm_params):
    """Map one trial's ``algorithm_params`` config to ``(label, style, hue)``.

    Raises:
        ValueError: For an unknown ``class_name`` or an unsupported
            ``BBORandomizedPrior`` configuration (``omega_lr`` not <= 1, or a
            negative ``prior_loss_weight``).
    """
    class_name = algorithm_params['class_name']

    if class_name == 'TD0':
        return "Direct BBO w/o prior (TD(0))", 'w/o prior', 'Direct BBO'

    if class_name == 'TDC':
        return "TDC", "TDC", "TDC"

    if class_name != 'BBORandomizedPrior':
        raise ValueError(class_name)

    algorithm_config = algorithm_params['config']
    prior_loss_weight = algorithm_config['prior_loss_weight']
    omega_lr = algorithm_config['omega_lr']

    # omega_lr == 1 is labelled "Direct BBO", omega_lr < 1 "Gradient BBO";
    # larger values are not part of this ablation. An explicit check (not an
    # ``assert``) keeps the validation active under ``python -O``.
    if not omega_lr <= 1:
        pprint(algorithm_params)
        raise ValueError(f"Expected omega_lr <= 1, got {omega_lr!r}.")

    style = 'w/o prior' if prior_loss_weight == 0 else 'w prior'
    hue = 'Gradient BBO' if omega_lr < 1 else 'Direct BBO'

    if prior_loss_weight == 0 and omega_lr == 1:
        label = "Direct BBO w/o prior (TD(0))"
    elif 0 < prior_loss_weight and omega_lr == 1:
        label = "Direct BBO w/ prior"
    elif prior_loss_weight == 0 and omega_lr < 1:
        label = "Gradient BBO w/o prior"
    elif 0 < prior_loss_weight and omega_lr < 1:
        label = "Gradient BBO w/ prior"
    else:
        # e.g. a negative prior_loss_weight.
        pprint(algorithm_params)
        raise ValueError("Should not be here.")

    return label, style, hue


def create_visualization_dataframe(result):
    """Build one long-form plotting frame from a ray-tune analysis.

    Only trials that have both a results dataframe and a config are used.
    Each of those trials' rows is tagged with ``label``/``style``/``hue``
    columns derived from its ``algorithm_params`` config. The rows are then
    reduced to the plotting columns and concatenated in sorted trial order.

    Args:
        result: Object exposing ``fetch_trial_dataframes()`` and
            ``get_all_configs()``, both returning dicts keyed by trial (a
            ``tune.Analysis`` in this script).

    Returns:
        pandas.DataFrame with columns ``errors/MSE``, ``errors/RMSE``,
        ``label``, ``style``, ``hue`` and ``training_steps``.

    Raises:
        ValueError: If no trial has both a dataframe and a config, or if a
            trial's algorithm configuration is not recognised.
    """
    trial_dataframes = result.fetch_trial_dataframes()
    trial_configs = result.get_all_configs()

    keys = [
        'errors/MSE',
        'errors/RMSE',
        'label',
        'style',
        'hue',
        'training_steps',
    ]

    # Sorted so the row order is deterministic (set iteration order is not).
    common_trials = sorted(set(trial_dataframes) & set(trial_configs))
    if not common_trials:
        raise ValueError(
            "No trials have both a results dataframe and a config.")

    labelled_dataframes = []
    for trial in common_trials:
        label, style, hue = _classify_algorithm(
            trial_configs[trial]['algorithm_params'])
        # ``assign`` returns a new frame, leaving the fetched one untouched.
        labelled_dataframes.append(
            trial_dataframes[trial].assign(
                label=label, style=style, hue=hue)[keys])

    return pd.concat(labelled_dataframes)


def visualize_experiments_for_paper():
    """Plot the ablation learning curves for every environment and save them.

    Loads every experiment in ``EXPERIMENT_PATHS`` with ``tune.Analysis`` and
    draws one subplot per environment (``metric_to_use`` vs. training steps).
    A single shared legend is placed above the subplots, and the figure is
    written to ``BASE_DIR`` as both PDF and PNG.

    Raises:
        ValueError: If the chosen metric has no y-axis label, if any
            experiment's metric column contains NaNs, or if no legend entries
            matching ``LABELS`` are found.
    """
    results = tree.map_structure(tune.Analysis, EXPERIMENT_PATHS)
    dataframes = tree.map_structure(create_visualization_dataframe, results)

    metric_to_use = 'errors/RMSE'
    # Raw string: ``\s`` is an invalid escape sequence in a plain literal.
    y_labels = {
        'errors/RMSE': r'$\sqrt{MSE}$',
        'errors/MSE': '$MSE$',
    }
    if metric_to_use not in y_labels:
        raise ValueError(metric_to_use)

    def validate_dataframe(dataframe):
        if dataframe[metric_to_use].hasnans:
            raise ValueError("Found nans from the visualization dataframes.")

    tree.map_structure(validate_dataframe, dataframes)

    num_subplots = len(dataframes)
    default_width = plt.rcParams['figure.figsize'][0]
    figure_scale = 0.5
    figsize = np.array((0.9 * num_subplots, 0.4)) * (
        default_width * figure_scale)
    figure, axes = plt.subplots(1, num_subplots, figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) when num_subplots == 1.
    axes = np.atleast_1d(axes)

    save_dir = BASE_DIR
    os.makedirs(save_dir, exist_ok=True)

    for i, ((environment_id, environment_dataframes), axis) in enumerate(
            zip(dataframes.items(), axes)):
        # legend='brief' makes seaborn register labelled legend handles; the
        # per-axis legend box itself is removed and a shared one added below.
        plot_experiment(
            pd.concat(environment_dataframes),
            metric_to_use,
            axis,
            legend='brief')
        axis.get_legend().remove()
        axis.set_title(FIGURE_TITLES[environment_id])
        # Only the left-most subplot carries a y-label.
        axis.set(ylabel=y_labels[metric_to_use] if i == 0 else None)

    # Shared legend built from the last subplot's handles, keeping only
    # entries whose text is one of LABELS, in reverse order.
    handles, labels = axes.flatten()[-1].get_legend_handles_labels()
    legend_entries = [
        (handle, label)
        for handle, label in zip(handles, labels)
        if label in LABELS
    ][::-1]
    if not legend_entries:
        raise ValueError("No legend entries matching LABELS were found.")
    handles, labels = zip(*legend_entries)

    legend = figure.legend(
        handles=handles,
        labels=labels,
        ncol=3,
        handlelength=1.25,
        handletextpad=0.25,
        columnspacing=1.25,

        loc='lower center',
        bbox_to_anchor=(0.47, 1.05),
        bbox_transform=figure.transFigure,

        fontsize='medium'
    )
    legend.set_in_layout(True)

    # The legend sits outside the axes, so list it explicitly for the
    # tight bounding box.
    for extension in ('pdf', 'png'):
        figure.savefig(
            save_dir / f'non-linear-policy-evaluation-ablation-result.{extension}',
            bbox_extra_artists=(legend,),
            bbox_inches='tight')

    plt.close(figure)


def get_argument_parser():
    """Return the command-line parser for this script (no options defined yet)."""
    return argparse.ArgumentParser()


if __name__ == '__main__':
    # The parser defines no options yet. Parsing still provides ``--help`` and
    # rejects unrecognised arguments before any plotting work starts.
    get_argument_parser().parse_args()
    visualize_experiments_for_paper()
