from __future__ import print_function

from collections import defaultdict
import json


def check_train_file(filename, synset_to_index=None):
    """Print per-label and per-subdirectory line counts for a training list.

    Each non-blank line of *filename* is expected to be ``<path> <label>``,
    where the last whitespace-separated token is an integer label. The
    function prints:

    1. how many lines carry each integer label, then
    2. for each label, how many of its lines share the same leading path
       component (the text before the first ``/``, presumably the synset
       directory), together with the index *synset_to_index* assigns to it.

    Args:
        filename: Path to the list file (e.g. ``train.txt``).
        synset_to_index: Mapping from leading path component to a class
            index. Defaults to the module-level ``filename2label`` built in
            the ``__main__`` section, which is what the original read.

    Raises:
        ValueError: If a non-blank line has no trailing label token, or the
            label is not an integer.
    """
    if synset_to_index is None:
        # Backward-compatible fallback: the original always read this global.
        synset_to_index = filename2label

    label_count = defaultdict(int)
    # label -> {leading path component -> count}. The inner defaultdict(int)
    # keeps the printed repr identical to the original.
    sublabel_count = defaultdict(lambda: defaultdict(int))

    with open(filename, 'r') as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            # Split off the *last* token so paths containing spaces (or a
            # tab separator) still parse correctly.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError('{}:{}: expected "<path> <label>", got {!r}'
                                 .format(filename, line_no, line))
            path, label_str = parts
            label = int(label_str)

            label_count[label] += 1
            sublabel_count[label][path.split('/')[0]] += 1

    print('first-level label set: {}'.format(label_count))
    print()
    print()
    print()

    for label in label_count:
        per_synset = sublabel_count[label]
        print('first-level label: {}'.format(label))
        print('sub-level label: {}'.format(per_synset))
        for synset, count in per_synset.items():
            # '?' marks a path component missing from the mapping instead of
            # aborting the whole report with a KeyError.
            print('[{}: {}]'.format(synset_to_index.get(synset, '?'), count),
                  end='\t')
        print()
    print()
    print()
    print()


def check_test_file(filename):
    """Print how many lines of a validation list file carry each label.

    Each non-blank line of *filename* is expected to be ``<path> <label>``,
    where the last whitespace-separated token is an integer label. The
    per-label counts are printed as a single ``defaultdict`` repr.

    Args:
        filename: Path to the list file (e.g. ``val.txt``).

    Raises:
        ValueError: If a non-blank line has no trailing label token, or the
            label is not an integer.
    """
    label_count = defaultdict(int)

    with open(filename, 'r') as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            # Take the *last* token so paths with spaces or a tab separator
            # still parse correctly.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError('{}:{}: expected "<path> <label>", got {!r}'
                                 .format(filename, line_no, line))
            label_count[int(parts[1])] += 1

    print('first-level label set: {}'.format(label_count))
    print()


if __name__ == '__main__':
    with open('imagenet_class_index.json', 'r') as f:
        # Expected shape (from how entries are unpacked below):
        # {"<index>": [<name used as path prefix>, <label name>], ...}
        raw_mapping_dict = json.load(f)

    # Map each entry's first element to its (string) index key. This stays a
    # module-level name because check_train_file reads `filename2label`.
    # Iterating the entries avoids hard-coding the number of classes.
    filename2label = {
        synset: index
        for index, (synset, _label_name) in raw_mapping_dict.items()
    }

    check_train_file('train.txt')
    check_test_file('val.txt')
