import pickle
import os
import json
import argparse
import subprocess
import random
from tqdm import tqdm
import pprint

import global_consts as gc
from process_dataset import process_dataset
from process_instruction import correct_entrys
from inst_dict import create_dictionary


def generate_dataset(state_files,
                     trainset_path,
                     validset_path,
                     validset_ratio,
                     min_num_targets,
                     min_num_instructions,
                     dict_only):
    """Split state files into train/valid sets, post-process them, and save.

    Each state file is turned into a list of frames by ``process_state_file``.
    All frames of one non-empty state file go to the same split: the file is
    sent to the validation set when ``random.random() < validset_ratio`` and to
    the training set otherwise. The global ``random`` state is used, so seed it
    beforehand for a reproducible split.

    Both splits are then passed through ``process_dataset`` and
    ``correct_entrys``. The corrections returned for the training set are
    passed in when correcting the validation set. A dictionary built by
    ``create_dictionary`` from the training set is pickled to
    ``trainset_path + '_dict.pt'``.

    Args:
        state_files: paths of state files, processed in the given order.
        trainset_path: output path of the training-set JSON; also the prefix
            of the dictionary path.
        validset_path: output path of the validation-set JSON.
        validset_ratio: probability that a state file goes to the validation
            set.
        min_num_targets: forwarded to ``process_state_file``.
        min_num_instructions: forwarded to ``process_state_file``.
        dict_only: if true, return right after writing the dictionary and do
            not write the JSON splits.
    """
    trainset, validset = [], []
    num_frames = 0
    # Single process, files visited in the given order: the sequence of
    # random.random() draws, and hence the split, is reproducible for a seed.
    # NOTE(review): process_state_file is neither defined nor imported in the
    # visible file -- confirm where it is provided.
    for state_file in tqdm(state_files):
        data = process_state_file(state_file, min_num_targets, min_num_instructions)
        if len(data) == 0:
            # Empty results are skipped without consuming a random draw.
            continue
        # Only the total is reported, so frames are counted, not collected
        # into a third list alongside trainset/validset.
        num_frames += len(data)

        # The whole state file goes to one split, so frames from the same
        # file never end up in both sets.
        if random.random() < validset_ratio:
            validset.extend(data)
        else:
            trainset.extend(data)

    print('total frames processed:', num_frames)

    # process dataset
    print('========================')
    print('processing dataset......')
    trainset = process_dataset(trainset)
    print('======')
    validset = process_dataset(validset)

    # Correct instructions: the corrections produced for the training set are
    # handed to the validation-set call.
    print('==========================')
    print('correcting instructions......')
    trainset, corrections = correct_entrys(trainset, None)
    validset, _ = correct_entrys(validset, corrections)

    # Build the instruction dictionary from the training set only.
    # NOTE(review): the meaning of the positional 5 and 0 is defined by
    # inst_dict.create_dictionary -- consider passing them by keyword.
    print('===================')
    print('creating dictionary......')
    inst_dict = create_dictionary(trainset, corrections, 5, 0)
    dict_path = trainset_path + '_dict.pt'
    print('writing dictionary to ', dict_path)
    # Context manager guarantees the pickle is flushed and the file closed.
    with open(dict_path, 'wb') as fout:
        pickle.dump(inst_dict, fout)

    if dict_only:
        return

    if len(validset) > 0:
        print('validset size: %d' % len(validset))
        print('writing validset to file:', validset_path)
        with open(validset_path, 'w') as fout:
            json.dump(validset, fout)

    if len(trainset) > 0:
        print('trainset size: %d' % len(trainset))
        print('writing trainset to file:', trainset_path)
        with open(trainset_path, 'w') as fout:
            json.dump(trainset, fout)


def main(argv=None):
    """Command-line entry point: parse arguments and run ``generate_dataset``.

    Args:
        argv: list of argument strings to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) keeps the normal command-line behaviour.
    """
    parser = argparse.ArgumentParser(description='generate dataset from states')
    # Required: without it get_all_files would receive None and fail later
    # with a far less helpful error than argparse's usage message.
    parser.add_argument('--states-root', type=str, required=True)
    parser.add_argument('--trainset-path', type=str, default='train.json')
    parser.add_argument('--validset-path', type=str, default='valid.json')
    parser.add_argument('--validset-ratio', type=float, default=0.1)
    parser.add_argument('--state-file-extension', type=str, default='.json')
    parser.add_argument('--min-num-targets', type=int, default=0)
    parser.add_argument('--min-num-instructions', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--dict-only', action='store_true')
    args = parser.parse_args(argv)

    # Append the filtering thresholds to the output paths so runs with
    # different thresholds write to different files.
    suffix = '_min_tar_%d_min_inst_%d' % (args.min_num_targets, args.min_num_instructions)
    args.trainset_path += suffix
    args.validset_path += suffix

    # Seed before generate_dataset draws the per-file train/valid split.
    random.seed(args.seed)

    # Sorting fixes the file order so the seeded split is reproducible.
    # NOTE(review): get_all_files is neither defined nor imported in the
    # visible file -- confirm where it is provided.
    state_files = sorted(get_all_files(args.states_root, args.state_file_extension))
    generate_dataset(state_files,
                     args.trainset_path,
                     args.validset_path,
                     args.validset_ratio,
                     args.min_num_targets,
                     args.min_num_instructions,
                     args.dict_only)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
