import os
import json
import argparse
from collections import defaultdict

import matplotlib.pyplot as plt
plt.switch_backend('agg')


def get_span_stats(entrys, long_span_threshold=50):
    """Measure runs of consecutive entries that share the same instruction.

    Walks ``entrys`` in order and groups adjacent entries whose
    ``'instruction'`` values compare equal. The length of each such run is
    its "span".

    Args:
        entrys: Sequence of mappings, each providing an ``'instruction'``
            key. The first entry of any run that gets reported (see
            ``long_span_threshold``) must also provide ``'unique_id'``.
        long_span_threshold: Runs strictly longer than this value are
            reported on stdout as ``<span> <unique_id of first entry>``.
            Pass ``None`` to disable the report. Defaults to 50.

    Returns:
        A ``(span_stats, spans)`` tuple. ``span_stats`` is a
        ``defaultdict(int)`` mapping a span length to the number of runs of
        that length; ``spans`` lists every run length in input order.
    """
    span_stats = defaultdict(int)
    spans = []
    total = len(entrys)
    start = 0
    while start < total:
        # Look the run's instruction up once instead of on every comparison.
        instruction = entrys[start]['instruction']
        # Scan from the *next* entry: a run always contains its first entry,
        # so every pass advances by at least one. Comparing the first entry
        # with itself would stall forever on a value that is not equal to
        # itself (e.g. a NaN parsed from the JSON input).
        end = start + 1
        while end < total and instruction == entrys[end]['instruction']:
            end += 1
        span = end - start
        if long_span_threshold is not None and span > long_span_threshold:
            print(span, entrys[start]['unique_id'])
        span_stats[span] += 1
        spans.append(span)
        start = end

    return span_stats, spans


def render_graph(datapath):
    """Print the span statistics of a dataset and save them as a histogram.

    Loads the JSON file at ``datapath``, computes the instruction span
    statistics with ``get_span_stats``, prints one ``span count`` line per
    distinct span (ascending), and writes a 100-bin histogram of all spans
    to ``datapath + '_inst_span.png'``.

    Args:
        datapath: Path of the JSON dataset file. It is also used as the
            prefix of the output image path.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(datapath, 'r') as data_file:
        entrys = json.load(data_file)
    span_stats, spans = get_span_stats(entrys)

    print('span, count:')
    for span in sorted(span_stats):
        print(span, span_stats[span])

    img_name = datapath + '_inst_span.png'
    # Draw on a dedicated figure rather than pyplot's implicit current one,
    # so a figure the caller already has open is neither drawn on nor closed.
    fig, ax = plt.subplots()
    try:
        ax.hist(spans, bins=100)
        print('saving img to:', img_name)
        fig.savefig(img_name)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: read the dataset path from the command line and
    # render its instruction-span histogram.
    cli = argparse.ArgumentParser(description='filter dataset')
    cli.add_argument('--dataset',
                     type=str,
                     default='data/fix2_prev_ins_dev.json_min10')
    options = cli.parse_args()

    # render_graph prints the "span, count" table itself before saving the
    # plot, so no separate statistics pass is needed here.
    render_graph(options.dataset)
