#!/usr/bin/python3

import os

import matplotlib.pyplot as plt
import pandas
import seaborn

# Original dataset; its per-column variances are the normalizers used by
# normalize() and normalize_dist().
ORIGINAL_FILENAME = "../synthetic_data.csv"
# Label column of the original dataset; dropped before the variance-based normalization.
CLASS_LABEL = "xPy_Label"

# Error CSVs read by the top-level script code below.
PREDICTION_ERROR_FILENAME = "csv_files/prediction_error.csv"
DISCRIMINATOR_ERROR_FILENAME = "csv_files/discriminator_error.csv"
# Path prefix for the per-feature reconstruction-error files, which are named
# <stub><feature>.csv (see get_reconstruction_filename).
RECONSTRUCTION_ERROR_FILENAME_STUB = "csv_files/reconstruction_error_"

# Feature names; each one selects a reconstruction-error file to read.
FEATURES = ["x", "x2", "xSquared", "y", "y2", "ySquared", "z", "z2", "zSquared"]
# Display labels applied positionally as x tick labels in the plots ("\u00b2" is a
# superscript two). Presumably one per entry of FEATURES, in the same order --
# confirm against the column order of the plotted CSVs.
FEATURE_NAMES = ["x", "2x", "x\u00b2", "y", "2y", "y\u00b2", "c", "2c", "c\u00b2"]

# Font size for axis labels and tick labels, and for plot titles.
FONTSIZE = 24
TITLESIZE = 30

def boxplot_error(error_df, fig_title, outfile):
    """Save a box plot of *error_df* to *outfile*, then clear the current figure.

    The axes are labelled "Feature" / "Error", the plot is titled *fig_title*,
    and the x tick labels are replaced with FEATURE_NAMES.

    :param error_df: DataFrame passed to ``seaborn.boxplot(data=...)``; assumes its
        columns line up positionally with FEATURE_NAMES -- TODO confirm.
    :param fig_title: title text for the plot.
    :param outfile: path the figure image is written to.
    """
    ax = seaborn.boxplot(data=error_df)
    x_label = ax.set_xlabel("Feature", fontsize=FONTSIZE)
    ax.set_ylabel("Error", fontsize=FONTSIZE)
    ax.set_title(fig_title, fontsize=TITLESIZE)
    ax.set_xticklabels(FEATURE_NAMES, fontsize=FONTSIZE)
    # The x-axis label is listed as an extra artist for the tight bounding box.
    ax.figure.savefig(outfile, bbox_extra_artists=(x_label,), bbox_inches="tight")
    plt.clf()

def barplot_error(error_df, fig_title, outfile, ylim=(0.0, 1.5)):
    """Save a bar plot of *error_df* to *outfile*, then clear the current figure.

    The axes are labelled "Feature" / "Error", the plot is titled *fig_title*,
    the y range is fixed to *ylim*, and the x tick labels are replaced with
    FEATURE_NAMES.

    :param error_df: DataFrame passed to ``seaborn.barplot(data=...)``; assumes its
        columns line up positionally with FEATURE_NAMES -- TODO confirm.
    :param fig_title: title text for the plot.
    :param outfile: path the figure image is written to.
    :param ylim: ``(bottom, top)`` y-axis limits. Defaults to ``(0.0, 1.5)``,
        the previously hard-coded range. Pass ``None`` to keep matplotlib's autoscaling.
    """
    ax = seaborn.barplot(data=error_df)
    x_label = ax.set_xlabel("Feature", fontsize=FONTSIZE)
    ax.set_ylabel("Error", fontsize=FONTSIZE)
    ax.set_title(fig_title, fontsize=TITLESIZE)
    if ylim is not None:
        # seaborn returns the Axes it drew on, so no separate plt.gca() lookup is needed.
        ax.set_ylim(ylim)
    ax.set_xticklabels(FEATURE_NAMES, fontsize=FONTSIZE)
    # The x-axis label is listed as an extra artist for the tight bounding box.
    ax.figure.savefig(outfile, bbox_extra_artists=(x_label,), bbox_inches="tight")
    plt.clf()

def get_reconstruction_filename(stub, featurename):
    """Return the CSV path for one feature: *stub*, then *featurename*, then ".csv"."""
    # str.join (like +) rejects non-str parts with TypeError instead of converting them.
    return "".join((stub, featurename, ".csv"))

def avg_perinstance(filename_stub, features):
    """Return a DataFrame with one column per feature holding each row's mean error.

    For every name in *features*, the CSV at
    ``get_reconstruction_filename(filename_stub, name)`` is read. Its rows are
    averaged across columns (``mean(axis=1)``), and the result is stored as a
    column named after the feature. Columns are assigned one at a time, so pandas
    aligns each on the frame's index.

    :param filename_stub: path prefix for the per-feature CSV files.
    :param features: feature names, one file per name.
    :return: DataFrame of per-row means; empty if *features* is empty.
    """
    row_means = pandas.DataFrame()
    for name in features:
        per_feature = pandas.read_csv(get_reconstruction_filename(filename_stub, name))
        row_means[name] = per_feature.mean(axis=1)
    return row_means

def combine_error_percol(filename_stub, features):
    """Stack the per-feature reconstruction-error CSVs into a single DataFrame.

    For every name in *features*, the CSV at
    ``get_reconstruction_filename(filename_stub, name)`` is read. The frames are
    concatenated row-wise with ``pandas.concat`` using its defaults. Each frame
    keeps its own row index, so index labels may repeat in the result.

    :param filename_stub: path prefix for the per-feature CSV files.
    :param features: feature names, one file per name.
    :return: the concatenated DataFrame.
    :raises ValueError: if *features* is empty (pandas.concat has nothing to join).
    """
    frames = [
        pandas.read_csv(get_reconstruction_filename(filename_stub, feature))
        for feature in features
    ]
    return pandas.concat(frames)

# sqrt[mse(p,phat)/var(p)]
def normalize(df, original_data):
    """Return the per-feature normalized RMSE, sqrt(mse(p, phat) / var(p)).

    :param df: per-instance errors, presumably p - phat (per the formula above),
        one column per feature. Squared and averaged over rows to give the MSE.
    :param original_data: original dataset. The variance of each column, minus
        CLASS_LABEL, is the normalizer.
    :return: Series indexed by feature name, with CLASS_LABEL excluded. Features
        present in only one of the two inputs come out as NaN (index alignment).
    """
    # Drop the label before taking the variance. It is not a feature, and
    # DataFrame.var() raises on non-numeric columns in pandas >= 2.0. This matches
    # what normalize_dist does. errors="ignore" tolerates a missing label column.
    var = original_data.drop(columns=CLASS_LABEL, errors="ignore").var()
    mse = (df * df).mean()
    normalized = (mse / var) ** 0.5
    # df may carry the label column itself; it is not a feature either.
    return normalized.drop(CLASS_LABEL, errors="ignore")

# (p-phat)^2/var(p)
def normalize_dist(df, original_data):
    """Scale every squared error by the variance of its feature in the original data.

    No averaging is done: the result has one value per instance per feature.

    :param df: per-instance errors, presumably p - phat (per the formula above),
        one column per feature.
    :param original_data: original dataset. CLASS_LABEL is removed before the
        column variances are taken (KeyError if it is absent).
    :return: DataFrame of ``df**2 / var``, with the variances aligned to df's
        columns by name.
    """
    feature_variance = original_data.drop(columns=CLASS_LABEL).var()
    return (df * df) / feature_variance

# ---- Script body: load the error CSVs and write one figure per error type. ----

# Figure.savefig does not create missing directories. Create the output folder
# up front so a fresh checkout doesn't fail after all the data has been loaded.
os.makedirs("figures", exist_ok=True)

original_data_df = pandas.read_csv(ORIGINAL_FILENAME)

# Prediction error: raw values as a box plot.
prediction_error_df = pandas.read_csv(PREDICTION_ERROR_FILENAME)
boxplot_error(prediction_error_df, "Prediction Error", "figures/prediction_error.png")

# Reconstruction error: per-feature files stacked row-wise, then box-plotted.
reconstruction_error_df = combine_error_percol(RECONSTRUCTION_ERROR_FILENAME_STUB, FEATURES)
boxplot_error(reconstruction_error_df, "Reconstruction Error", "figures/reconstruction_error.png")

# Discriminator ("disentanglement") error: a variance-normalized bar plot and a
# box plot of the raw values.
discriminator_error_df = pandas.read_csv(DISCRIMINATOR_ERROR_FILENAME)
normalized = normalize_dist(discriminator_error_df, original_data_df)
barplot_error(normalized, "Disentanglement Error", "figures/discriminator_error.png")
boxplot_error(discriminator_error_df, "Disentanglement Error", "figures/discriminator_box_error.png")

