import os
from typing import Dict

import numpy as np
import pandas as pd
from ray import tune

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from .visualize_cartpole import smooth_dataframe


def load_data() -> Dict[str, tune.Analysis]:
    """Open the Tune experiment directories used for the comparison plot.

    Returns:
        Mapping from algorithm key (``'sac'``, ``'sac_star'``, ``'bbac'``,
        ``'bac'``) to the ``tune.Analysis`` built from that algorithm's
        experiment directory under the user's home directory.
    """
    experiment_directories = {
        'sac': (
            '~/ray_results/cetus/gym/MountainCar/Continuous-v0/'
            '2021-02-04T08-57-43-sac-entropy-sweep-1/'),
        # NOTE(review): identical to the 'sac' directory. Presumably both
        # variants live in the same sweep and are meant to be separated by
        # the trial filters -- confirm.
        'sac_star': (
            '~/ray_results/cetus/gym/MountainCar/Continuous-v0/'
            '2021-02-04T08-57-43-sac-entropy-sweep-1/'),
        'bbac': (
            '~/ray_results/cetus/gym/MountainCar/Continuous-v0/'
            '2021-02-02T15-23-35-bbac-visualization-1'),
        'bac': (
            '~/ray_results/gandalf/gym/MountainCar/Continuous-v0/'
            '2021-05-28T14-40-18-bbac-visualization-1'),
    }

    return {
        algorithm: tune.Analysis(os.path.expanduser(directory))
        for algorithm, directory in experiment_directories.items()
    }

def filter_sac_data(config):
    """Decide whether a SAC trial is left out of the plot.

    Args:
        config: Configuration of the trial. Currently not inspected.

    Returns:
        ``True`` if the trial must be skipped. Every trial is kept for now,
        so the result is always ``False``.
    """
    keep_trial = True
    return not keep_trial


def filter_sac_star_data(config):
    """Decide whether a SAC* trial is left out of the plot.

    Args:
        config: Configuration of the trial. Currently not inspected.

    Returns:
        ``True`` if the trial must be skipped. Every trial is kept for now,
        so the result is always ``False``.
    """
    keep_trial = True
    return not keep_trial


def filter_bbac_data(config):
    """Decide whether a BBAC trial is left out of the plot.

    Args:
        config: Configuration of the trial. Currently not inspected.

    Returns:
        ``True`` if the trial must be skipped. Every trial is kept for now,
        so the result is always ``False``.
    """
    keep_trial = True
    return not keep_trial


def filter_bac_data(config):
    """Decide whether a BAC trial is left out of the plot.

    Args:
        config: Configuration of the trial. Currently not inspected.

    Returns:
        ``True`` if the trial must be skipped. Every trial is kept for now,
        so the result is always ``False``.
    """
    keep_trial = True
    return not keep_trial


def _collect_algorithm_dataframe(analysis, should_filter, algorithm_label):
    """Stack the smoothed progress of every kept trial of one algorithm.

    Args:
        analysis: ``tune.Analysis`` holding the trials of one experiment.
        should_filter: Callable that receives a trial config and returns
            ``True`` when that trial must be left out.
        algorithm_label: Value written to the ``'algorithm'`` column; it ends
            up as the legend entry of the algorithm.

    Returns:
        A single dataframe with the rows of all kept trials (each passed
        through ``smooth_dataframe``) plus an ``'algorithm'`` column.

    Raises:
        ValueError: If the filter leaves no trial to plot.
    """
    # Trials are identified by their log directory; both the config mapping
    # and the per-trial dataframes are looked up with that key.
    trial_keys = analysis.dataframe().logdir
    trial_configs = analysis.get_all_configs()
    trial_dataframes = analysis.trial_dataframes

    kept_dataframes = [
        smooth_dataframe(trial_dataframes[trial_key])
        for trial_key in trial_keys
        if not should_filter(trial_configs[trial_key])
    ]
    if not kept_dataframes:
        # pd.concat would fail here too, but without saying which algorithm
        # ran out of trials.
        raise ValueError(
            f'No trials left to plot for {algorithm_label} after filtering.')

    algorithm_dataframe = pd.concat(kept_dataframes)
    assert 'algorithm' not in algorithm_dataframe.columns, (
        "trial dataframes already contain an 'algorithm' column")
    algorithm_dataframe['algorithm'] = algorithm_label
    return algorithm_dataframe


def visualize(
        data: Dict[str, tune.Analysis],
        output_path: str = '/tmp/bbac-results-mountain-car-v0.pdf',
        x_axis_unit: str = 'thousands',
) -> None:
    """Plot episode return against collected samples for all algorithms.

    Args:
        data: Mapping with the keys ``'sac'``, ``'sac_star'``, ``'bbac'`` and
            ``'bac'``, each holding the ``tune.Analysis`` of that algorithm.
        output_path: File the figure is written to.
        x_axis_unit: Unit of the x axis, ``'thousands'`` or ``'millions'``.

    Raises:
        NotImplementedError: If ``x_axis_unit`` is not a supported unit.
        KeyError: If ``data`` lacks one of the expected algorithm keys.
        ValueError: If every trial of an algorithm is filtered out.
    """
    # Divisor and axis-label text for every supported unit.
    x_axis_units = {'thousands': (1e3, '1e3'), 'millions': (1e6, '1e6')}
    if x_axis_unit not in x_axis_units:
        raise NotImplementedError(f"TODO: x_axis_unit={x_axis_unit}")
    x_divisor, x_unit_label = x_axis_units[x_axis_unit]

    x_key = 'sampler/total-samples'
    y_key = 'training/episode-reward-mean'
    hue_key = 'algorithm'
    visualization_dataframe_keys = [x_key, y_key, hue_key]

    # (data key, trial filter, legend label). The frames are concatenated in
    # this order, which may determine the colour each algorithm receives.
    algorithms = (
        ('bac', filter_bac_data, '$BAC$'),
        ('sac', filter_sac_data, '$SAC$'),
        ('sac_star', filter_sac_star_data, '$SAC^{*}$'),
        ('bbac', filter_bbac_data, '$BBAC$'),
    )
    # Row labels may repeat across trials and are not used by the plot, so
    # they are renumbered instead of being carried into the plotting call.
    full_visualization_dataframe = pd.concat(
        [
            _collect_algorithm_dataframe(
                data[data_key], should_filter, algorithm_label,
            )[visualization_dataframe_keys]
            for data_key, should_filter, algorithm_label in algorithms
        ],
        ignore_index=True,
    )
    full_visualization_dataframe[x_key] = (
        full_visualization_dataframe[x_key] / x_divisor)

    label_map = {
        x_key: f'samples [{x_unit_label}]',
        y_key: 'return',
    }

    figure_scale = 0.3
    figsize = figure_scale * np.array([7.2, 4.2])
    figure, axis = plt.subplots(1, 1, figsize=figsize)

    sns.lineplot(
        data=full_visualization_dataframe,
        x=x_key,
        y=y_key,
        hue=hue_key,
        ax=axis)

    # Rebuild the legend with its entries ordered alphabetically by label.
    handles, labels = axis.get_legend_handles_labels()
    handles, labels = list(zip(*sorted(
        zip(handles, labels), key=lambda entry: entry[1])))
    legend = axis.legend(
        handles=handles,
        labels=labels,
        ncol=2,
        columnspacing=0.3,
        handlelength=0.75,
        framealpha=0.8,
        edgecolor=None,
        loc=(0.3, 0.01),
        handletextpad=0.5,
        labelspacing=0.2,
        borderpad=0.1,
    )
    legend.get_frame().set_linewidth(0)
    axis.set_xlabel(label_map[x_key])
    axis.set_ylabel(label_map[y_key])
    axis.tick_params(axis='y', labelrotation=90)

    # Save through the figure handle so the file always holds this figure,
    # independent of which figure pyplot currently considers active.
    figure.savefig(output_path, bbox_inches='tight')


def main() -> None:
    """Load the experiment results and render the comparison figure."""
    # There is no transformation step yet: the loaded analyses are handed to
    # the plotting code unchanged.
    visualize(load_data())


# Script entry point.
# NOTE(review): this file uses a relative import (`.visualize_cartpole`), so
# it presumably has to be launched with package context, e.g.
# `python -m <package>.<module>`, rather than by file path -- confirm.
if __name__ == '__main__':
    main()
