R"""Script for converting raw synthetic data experiment output to CSV.

Prints the CSV to stdout.
"""

import collections
import csv
import io
import json
import os
import re
from statistics import median
from typing import Dict

from absl import app
from absl import flags
from absl import logging

import numpy as np


# Global absl flag registry; the functions below read their settings from it.
FLAGS = flags.FLAGS

# Root directory holding one subdirectory of result files per method.
flags.DEFINE_string(
    name='data_dir',
    default=None,
    help='Path to directory to read from.',
)

# Per-method subdirectories. An empty value drops that method from the output.
flags.DEFINE_string(
    name='gls_subdir',
    default='gls',
    help=(
        'Subdirectory of data_dir containing results from GLS experiments. '
        'Set to empty string if not present.'
    ),
)

flags.DEFINE_string(
    name='gradient_descent_subdir',
    default='gradient_descent',
    help=(
        'Subdirectory of data_dir containing results from gradient descent '
        'experiments. Set to empty string if not present.'
    ),
)

flags.DEFINE_string(
    name='random_vertex_subdir',
    default='random_vertex',
    help=(
        'Subdirectory of data_dir containing results from random vertex '
        'experiments. Set to empty string if not present.'
    ),
)

# How each file's list of final losses is collapsed into a single score.
flags.DEFINE_enum(
    name='reduction',
    default='median',
    enum_values=['mean', 'median'],
    help=(
        'The reduction to perform to reduce a list of values to a single '
        'value representing it.'
    ),
)


# Identifies one experiment configuration; used as the key when joining the
# results of different methods. Field order also defines the sort order.
DatasetParams = collections.namedtuple('DatasetParams', ['d', 'm_gen', 'm_train'])


def read_file(filepath):
    """Load one experiment result file.

    The file must contain a JSON object with a 'final_losses' entry and a
    'config' object providing 'd', 'm_gen' and 'm_train'.

    Args:
        filepath: Path to the JSON result file; a leading '~' is expanded.

    Returns:
        A `(ds_params, final_losses)` tuple, where `ds_params` is a
        `DatasetParams` built from the file's config and `final_losses` is
        the value stored under 'final_losses', returned unmodified.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If any of the expected entries is missing.
    """
    filepath = os.path.expanduser(filepath)
    # JSON is UTF-8 by specification; pin the encoding so that decoding does
    # not depend on the platform's locale default.
    with open(filepath, 'r', encoding='utf-8') as f:
        results = json.load(f)

    final_losses = results['final_losses']
    config = results['config']

    ds_params = DatasetParams(
        d=config['d'],
        m_gen=config['m_gen'],
        m_train=config['m_train'],
    )
    return ds_params, final_losses


def process_final_losses(final_losses):
    """Summarise a list of final losses as a (score, spread) pair.

    Args:
        final_losses: Sequence of numeric loss values.

    Returns:
        A `(score, std)` tuple. `score` is the mean (via `np.mean`) or the
        median (via `statistics.median`) of the losses, as selected by
        `FLAGS.reduction`; `std` is `np.std` of the losses with numpy's
        default settings.

    Raises:
        ValueError: If `FLAGS.reduction` is neither 'mean' nor 'median'.
    """
    # Reject unknown reductions up front, before any computation happens.
    if FLAGS.reduction not in ('mean', 'median'):
        raise ValueError(FLAGS.reduction)

    reducer = np.mean if FLAGS.reduction == 'mean' else median
    central = reducer(final_losses)
    spread = np.std(final_losses)
    return central, spread


def get_subdirs_dict():
    """Return the enabled methods and their result subdirectories.

    Returns:
        A dict mapping method name ('gls', 'gradient_descent',
        'random_vertex') to the value of its `*_subdir` flag. Methods whose
        flag value is empty (falsy) are omitted.
    """
    configured = (
        ('gls', FLAGS.gls_subdir),
        ('gradient_descent', FLAGS.gradient_descent_subdir),
        ('random_vertex', FLAGS.random_vertex_subdir),
    )
    # An empty subdir flag means that method has no results to report.
    return {method: subdir for method, subdir in configured if subdir}


def process_method_results(subdirs: Dict[str, str], method: str):
    """Read and summarise every result file belonging to one method.

    Args:
        subdirs: Mapping from method name to its subdirectory name, relative
            to `FLAGS.data_dir`.
        method: Key into `subdirs` selecting which method to process.

    Returns:
        A dict mapping each file's `DatasetParams` to the `(score, std)`
        pair produced by `process_final_losses`. If several files share the
        same parameters, the one listed last by `os.listdir` wins.
    """
    method_dir = os.path.join(os.path.expanduser(FLAGS.data_dir), subdirs[method])

    # Lazily parse each directory entry; every entry is treated as a result
    # file, so reading and reducing stay interleaved file by file.
    parsed = (
        read_file(os.path.join(method_dir, entry))
        for entry in os.listdir(method_dir)
    )
    return {params: process_final_losses(losses) for params, losses in parsed}


def main(_):
    """Collect the results of every enabled method and print them as CSV.

    The output has one row per distinct `DatasetParams` found across all
    methods, sorted by (d, m_gen, m_train). Each method contributes two
    columns: its reduced score and its standard deviation. Cells are left
    empty where a method has no result for that parameter combination.

    Args:
        _: Unused positional command-line arguments passed by `app.run`.
    """
    subdirs = get_subdirs_dict()
    methods = list(subdirs)

    # method name -> {DatasetParams: (score, std)}
    results_by_method = {
        method: process_method_results(subdirs, method) for method in methods
    }

    # Union of the parameter combinations seen by any method, in sorted order.
    all_keys = set()
    for method_results in results_by_method.values():
        all_keys.update(method_results)
    all_keys = sorted(all_keys)

    header = ['d', 'm_gen', 'm_train']
    for method in methods:
        header.append(f'{method} ({FLAGS.reduction})')
        header.append(f'{method} (stddev)')

    rows = [header]
    for key in all_keys:
        row = [key.d, key.m_gen, key.m_train]
        for method in methods:
            # Two empty cells keep the columns aligned when a method has no
            # result for this parameter combination.
            row.extend(results_by_method[method].get(key, ('', '')))
        rows.append(row)

    output = io.StringIO()
    csv.writer(output).writerows(rows)
    # The csv writer already terminates every row, so suppress print's own
    # newline; otherwise the output ends with a stray blank line.
    print(output.getvalue(), end='')


if __name__ == "__main__":
    # data_dir has no usable default (it is None), and the processing code
    # passes it straight to os.path.expanduser. Requiring it here turns a
    # missing flag into a clear usage error instead of a TypeError later on.
    flags.mark_flag_as_required('data_dir')
    app.run(main)
