import os
import argparse
from collections import defaultdict

import matplotlib.pyplot as plt
# Select the non-interactive 'agg' backend: this script only saves figures to
# image files (see plt.savefig in render_log) and never opens a window.
plt.switch_backend('agg')


def _loss_names_for_mode(mode):
    """Return the ordered list of loss names plotted for ``mode``."""
    if mode == 'executor':
        return [
            'loss',
            'glob_cont_loss',
            'unit_loss',
            'cmd_type_loss',
            'move_loss',
            'attack_loss',
            'gather_loss',
            'build_unit_loss',
            'build_building_loss',
        ]
    # Any other mode value uses the shorter three-loss layout.
    return [
        'loss',
        'cont_loss',
        'lang_loss',
    ]


def _parse_log(filename, loss_names):
    """Collect training and validation loss values from one log file.

    A line starting with ``eval`` or ``train`` selects which series the
    following loss lines belong to. A line starting with one of
    ``loss_names`` contributes its last whitespace-separated token, parsed
    as a float, to the currently selected series.

    Args:
        filename: Path of the log file to read.
        loss_names: Loss names to look for at the start of each line.

    Returns:
        A ``(train_loss, valid_loss)`` pair of ``defaultdict(list)`` objects
        mapping loss name to the values in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a matching line does not end in a float.
    """
    train_loss = defaultdict(list)
    valid_loss = defaultdict(list)
    current = None

    # The context manager guarantees the handle is closed even if parsing fails.
    with open(filename, 'r') as log_file:
        for raw_line in log_file:
            line = raw_line.strip()
            if line.startswith('eval'):
                current = valid_loss
            elif line.startswith('train'):
                current = train_loss

            if current is None:
                # Loss lines seen before any 'train'/'eval' header cannot be
                # attributed to a series, so they are ignored.
                continue

            for name in loss_names:
                if line.startswith(name):
                    current[name].append(float(line.split()[-1]))

    return train_loss, valid_loss


def render_log(filename1, modelname1, filename2, modelname2, omit_first_k, mode):
    """Plot the loss curves of up to two training logs into one PNG file.

    One subplot is drawn per loss name (the set of names depends on
    ``mode``). For every log that is given, the training series is drawn as
    a solid line and the validation series as a dashed line. The figure is
    saved next to the first log that was provided, as
    ``<log path>-render.png``.

    Args:
        filename1: Path of the first log file, or ``None`` to skip it.
        modelname1: Legend label prefix for the first log.
        filename2: Path of the second log file, or ``None`` to skip it.
        modelname2: Legend label prefix for the second log.
        omit_first_k: Number of leading values dropped from every series;
            the x axis starts at this index.
        mode: ``'executor'`` selects the executor losses; any other value
            selects ``loss``, ``cont_loss`` and ``lang_loss``.

    Raises:
        ValueError: If both ``filename1`` and ``filename2`` are ``None``, or
            a loss line does not end in a float.
        OSError: If a log file cannot be opened.
    """
    inputs = [(filename1, modelname1), (filename2, modelname2)]
    provided = [(fname, mname) for fname, mname in inputs if fname is not None]
    if not provided:
        raise ValueError('at least one log file must be provided')

    loss_names = _loss_names_for_mode(mode)

    # Parse everything before touching matplotlib so that an unreadable or
    # malformed log does not leave a half-built figure open.
    parsed = [
        (mname, _parse_log(fname, loss_names)) for fname, mname in provided
    ]

    _fig, ax = plt.subplots(
        len(loss_names), 1, figsize=(10, 6 * len(loss_names)))

    for modelname, (train_loss, valid_loss) in parsed:
        for axis, lossname in zip(ax, loss_names):
            train = train_loss[lossname][omit_first_k:]
            valid = valid_loss[lossname][omit_first_k:]

            # Each series gets its own x range so that a validation series
            # longer than the training series still plots correctly.
            train_x = list(range(omit_first_k, omit_first_k + len(train)))
            valid_x = list(range(omit_first_k, omit_first_k + len(valid)))

            axis.plot(train_x, train, label=f'{modelname}-train', linestyle='-')
            axis.plot(valid_x, valid, label=f'{modelname}-valid', linestyle='--')
            axis.legend(loc='upper right')
            axis.set_title(lossname)

    plt.tight_layout()
    # Name the output after the first log that was actually supplied, so
    # passing only the second log works too.
    figname = provided[0][0] + '-render.png'
    print('writing to ', figname)
    plt.savefig(figname)
    plt.close()


if __name__ == '__main__':
    # Command-line entry point: read the log paths, legend names, number of
    # leading points to drop and the mode, then render the figure.
    cli_options = (
        ('--log1', {'type': str}),
        ('--name1', {'type': str, 'default': 'model1'}),
        ('--log2', {'type': str}),
        ('--name2', {'type': str, 'default': 'model2'}),
        ('--omit-first-k', {'type': int, 'default': 1}),
        ('--mode', {'type': str, 'default': 'executor'}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    render_log(
        filename1=args.log1,
        modelname1=args.name1,
        filename2=args.log2,
        modelname2=args.name2,
        omit_first_k=args.omit_first_k,
        mode=args.mode,
    )
