from collections import defaultdict
from copy import deepcopy
import numpy as np
import argparse
import json
from collections import defaultdict, deque

import global_consts as gc


# there seems to be no empty frame in the dataset
def remove_empty_frame(dataset):
    """Return the frames that have both own units and an instruction.

    A frame is dropped, and a message naming its ``unique_id`` is printed,
    when its ``my_units`` collection is empty or, failing that, when its
    ``instruction`` is empty.  The input list itself is left untouched; the
    surviving frame objects are returned in their original order in a new
    list.
    """
    kept_frames = []
    for frame in dataset:
        # Units are checked first; 'instruction' is only looked at for
        # frames that do have units.
        if len(frame['my_units']) == 0:
            print('removing %s due to no our units' % frame['unique_id'])
            continue
        if len(frame['instruction']) == 0:
            print('removing %s due to no instruction' % frame['unique_id'])
            continue
        kept_frames.append(frame)
    return kept_frames


# def filter_beginning(dataset):
#     print('filtering all-cont at beginning')

#     filtered_count = []
#     new_dataset = []
#     i = 0
#     while i < len(dataset):
#         instruction = dataset[i]['instruction']
#         j = i

#         # filter beginning
#         must_include = False
#         while j < len(dataset) and dataset[j]['instruction'] == instruction:
#             if must_include:
#                 new_dataset.append(dataset[j])
#                 j += 1
#                 continue

#             all_cont = True
#             for unit in dataset[j]['my_units']:
#                 target_cmd_type = unit['target_cmd']['cmd_type']
#                 if target_cmd_type != gc.CmdTypes.CONT.value:
#                     all_cont = False

#             if not all_cont:
#                 must_include = True
#                 new_dataset.append(dataset[j])
#                 filtered_count.append(j-i)
#                 # if j - i > 10:
#                 #     print(j, dataset[j]['instruction'])

#             j += 1

#         i = j

#     print('before filter, dataset len:', len(dataset))
#     print('after filter, dataset len:', len(new_dataset))
#     print('percent: %.2f' % (len(new_dataset) / len(dataset)))
#     print('avg filtered frames per instruction, %.2f' % np.mean(filtered_count))
#     print('----------------')
#     return new_dataset


# def label_all_cont(dataset):
#     num_all_cont = 0
#     for d in dataset:
#         all_cont = True
#         for unit in d['my_units']:
#             if unit['target_cmd']['cmd_type'] != gc.CmdTypes.CONT.value:
#                 all_cont = False

#         d['glob_cont'] = int(all_cont)
#         num_all_cont += all_cont

#     print('percent of all_cont frames: %.2f' % (num_all_cont / len(dataset)))
#     return dataset


# def split_into_replays(dataset):
#     replays = []
#     current_rep_name = ''
#     rep_frames = []
#     for d in dataset:
#         rep_name = d['unique_id']
#         rep_name = rep_name.split('.rep')[0]
#         # inst = entry['instruction']
#         # print('current:', current_rep_name)
#         if rep_name != current_rep_name:
#             if len(rep_frames):
#                 replays.append(rep_frames)
#             current_rep_name = rep_name
#             rep_frames = [d]
#         else:
#             rep_frames.append(d)

#     replays.append(rep_frames)
#     return replays


# def add_prev_instruction(dataset):
#     replays = split_into_replays(dataset)
#     for replay in replays:
#         replay[0]['prev_instruction'] = ''
#         for i in range(1, len(replay)):
#             replay[i]['prev_instruction'] = replay[i-1]['instruction']
#             # print(replay[i]['prev_instruction'])

#     new_dataset = [d for replay in replays for d in replay]
#     print('before adding prev instruction:', len(dataset))
#     print('after adding prev instruction :', len(new_dataset))
#     assert(len(dataset) == len(new_dataset))
#     return new_dataset


# this feature is not used anywhere
def mark_current_cmd_continue(dataset):
    """Flag every unit whose previous-frame target command was CONT.

    For each unit in each frame's ``my_units`` the key ``current_cmd_cont``
    is set to 1 when the same ``unit_id`` appeared in the immediately
    preceding frame of the same replay with a ``target_cmd`` whose
    ``cmd_type`` equals ``gc.CmdTypes.CONT.value``; otherwise it is set to 0.

    Frames are treated as belonging to the same replay while the part of
    ``unique_id`` before the first ``'.rep'`` stays the same.  The replay
    boundary is detected inline, so this function does not rely on a separate
    replay-splitting helper.  The previous-frame state is reset at every
    replay boundary so that a unit id reused by a different replay is never
    matched against the last frame of the preceding replay.

    The unit dicts are modified in place; the totals of flagged and unflagged
    units are printed and the same ``dataset`` object is returned.
    """
    num_cont = 0
    num_not_cont = 0
    prev_targets = {}
    current_replay = None

    for frame in dataset:
        replay_name = frame['unique_id'].split('.rep')[0]
        if replay_name != current_replay:
            # New replay: nothing from the previous replay may carry over.
            current_replay = replay_name
            prev_targets = {}

        new_targets = {}
        for unit in frame['my_units']:
            unit_id = unit['unit_id']
            was_cont = (
                unit_id in prev_targets
                and prev_targets[unit_id]['cmd_type'] == gc.CmdTypes.CONT.value
            )
            unit['current_cmd_cont'] = 1 if was_cont else 0
            if was_cont:
                num_cont += 1
            else:
                num_not_cont += 1

            new_targets[unit_id] = unit['target_cmd']

        # Only the directly preceding frame is ever consulted.
        prev_targets = new_targets

    print('num_cont:', num_cont)
    print('num_not_cont:', num_not_cont)
    return dataset


# def add_prev_cmd_and_base_frame(dataset):
#     """decide the base frame and set up prev_cmd"""
#     instruction_count = 0
#     cmd_count = 0

#     # recent_instructions = deque(maxlen=5)
#     i = 0
#     instruction_span = []
#     unit2pre_prev_cmds = defaultdict(list)
#     prev_replay_name = ''
#     pre_ins_base_frame_idx = 0
#     while i < len(dataset):
#         replay_name = dataset[i]['unique_id'].split('.rep')[0]
#         # new file, clear pre_cmds
#         if prev_replay_name != replay_name:
#             unit2pre_prev_cmds = defaultdict(list)
#             prev_replay_name = replay_name
#             pre_ins_base_frame_idx = i

#         instruction = dataset[i]['instruction']
#         j = i
#         unit2prev_cmds = defaultdict(list)
#         while j < len(dataset) and dataset[j]['instruction'] == instruction:
#             dataset[j]['base_frame_idx'] = i
#             dataset[j]['pre_ins_base_frame_idx'] = pre_ins_base_frame_idx
#             if i == j:
#                 pre_ins_base_frame_idx = i

#             for unit in dataset[j]['my_units']:
#                 unit_id = unit['unit_id']
#                 unit['pre_ins_prev_cmd'] = deepcopy(unit2pre_prev_cmds[unit_id])
#             if i == j:
#                 unit2pre_prev_cmds = defaultdict(list)

#             for unit in dataset[j]['my_units']:
#                 unit_id = unit['unit_id']
#                 unit['prev_cmd'] = deepcopy(unit2prev_cmds[unit_id])

#                 if 'target_cmd' not in unit:
#                     continue

#                 target_cmd = unit['target_cmd']

#                 if (target_cmd['cmd_type'] == gc.CmdTypes.IDLE.value
#                     or target_cmd['cmd_type'] == gc.CmdTypes.CONT.value):
#                     continue

#                 unit2prev_cmds[unit_id].append(target_cmd)
#                 unit2pre_prev_cmds[unit_id].append(target_cmd)

#             j += 1

#         instruction_span.append(j-i)
#         i = j

#     print('average span: %.2f' % np.mean(instruction_span))
#     return dataset


def process_dataset(dataset):
    """Run the preprocessing stages over ``dataset`` in a fixed order.

    Each stage is called with the result of the previous stage, and the
    result of the last stage is returned.

    NOTE(review): in this file ``filter_beginning``, ``label_all_cont``,
    ``add_prev_cmd_and_base_frame`` and ``add_prev_instruction`` appear only
    as commented-out code.  Unless those names are provided some other way,
    calling this function raises ``NameError`` right after
    ``remove_empty_frame`` has run -- confirm which stages are still meant to
    be part of the pipeline.
    """
    # Drop frames without own units or without an instruction.
    dataset = remove_empty_frame(dataset)
    dataset = filter_beginning(dataset)
    dataset = label_all_cont(dataset)
    dataset = add_prev_cmd_and_base_frame(dataset)
    dataset = add_prev_instruction(dataset)
    # Adds the per-unit 'current_cmd_cont' flag.
    dataset = mark_current_cmd_continue(dataset)
    return dataset


# def check_same_inst_diff_replay(entrys):
#     prev_inst = ''
#     prev_rep = ''

#     false_count = 0
#     false = []
#     for entry in entrys:
#         rep = entry['unique_id']
#         rep = rep.split('.rep')[0]
#         inst = entry['instruction']
#         if inst == prev_inst:
#             assert rep == prev_rep
#             if rep != prev_rep:
#                 false.append((rep, inst))
#         if inst != prev_inst:
#             prev_inst = inst
#         if rep != prev_rep:
#             prev_rep = rep

#         # print(prev_rep, prev_inst)
#     return false


# if __name__ == '__main__':
#     # this part is merely for creating dev dataset
#     parser = argparse.ArgumentParser(description='filter dataset')
#     parser.add_argument('--input-dataset',
#                         type=str,
#                         default='data/fix2_prev_ins_valid.json_min10')
#     parser.add_argument('--input-dataset2',
#                         type=str,
#                         default='data/new_inst_valid.json_min10')
#     parser.add_argument('--output-dataset',
#                         type=str,
#                         default='test_processed.json')
#     parser.add_argument('--first-k', type=int, default=0)
#     args = parser.parse_args()

#     with open(args.input_dataset, 'r') as f:
#         dataset = json.loads(f.read())

#     # mark_current_cmd_continue(dataset)

#     # with open(args.input_dataset2, 'r') as f:
#     #     dataset2 = json.loads(f.read())

#     # assert(len(dataset2) == len(dataset))
#     # for i in range(len(dataset2)):
#     #     assert(len(dataset[i]['my_units']) == len(dataset2[i]['my_units']))
#     #     for j in range(len(dataset[i]['my_units'])):
#     #         prev1 = dataset[i]['my_units'][j]['prev_cmd']
#     #         prev2 = dataset2[i]['my_units'][j]['prev_cmd']

#     #         if prev1 != prev2:
#     #             import pdb
#     #             pdb.set_trace()

#     #         pre_prev2 = dataset2[i]['my_units'][j]['pre_ins_prev_cmd']
#     #         same_prev = (pre_prev2 == prev2)
#     #         same_instruction = (
#     #             dataset2[i]['instruction'] == dataset2[i]['prev_instruction'])
#     #         if same_instruction:
#     #             if not same_prev:
#     #                 import pdb
#     #                 pdb.set_trace()
#     #         if len(prev2) and not same_prev:
#     #             if same_instruction:
#     #                 import pdb
#     #                 pdb.set_trace()

#     if args.first_k > 0:
#         dataset = dataset[ : args.first_k]

#     with open(args.output_dataset, 'w') as f:
#         json.dump(dataset, f)
