from __future__ import print_function
from PIL import Image
import os
import os.path
import errno
import copy
import numpy as np
import sys
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import torch.utils.data as data
from random import randint
import random
import os


# Each entry is [file name inside `base_folder`, 32-hex-digit string]. Only the
# file name (index 0) is read by the loader in this file; the second field looks
# like an MD5 checksum of the file (not verified here).
train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]

test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]

# Sub-directory of the dataset root that holds the pickled batch files.
base_folder = 'cifar-100-python'


def compare_two_images(a, b):
    """Count the element positions at which two images hold equal values.

    The comparison is element-wise over the whole array, so for (32, 32, 3)
    images the result is in [0, 3072]. Any pair of arrays with the same shape
    is accepted, not only 32x32x3.

    Args:
        a: first image (anything ``np.asarray`` accepts).
        b: second image, same shape as ``a``.

    Returns:
        Number of positions where ``a == b`` (Python int).

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError('image shapes differ: {} vs {}'.format(a.shape, b.shape))
    # Vectorized equivalent of the per-element triple loop; runs in native code.
    return int(np.count_nonzero(a == b))


def compare(confusion_data, actual_data, duplicate_num, num_images=101):
    """Print, for each duplicate, the fraction of elements that differ from its source image.

    ``confusion_data[i * duplicate_num + j]`` is compared with
    ``actual_data[i]``, so the duplicates of each original image are assumed to
    be stored contiguously (TODO confirm against whatever produced the .npz).
    For every pair a line ``j \\t ratio`` is printed. ``ratio`` is the
    number of mismatching elements divided by the image size.

    Args:
        confusion_data: indexable collection of generated images.
        actual_data: indexable collection of original images (numpy arrays).
        duplicate_num: number of generated copies per original image.
        num_images: how many original images to inspect. The default of 101
            matches the previous hard-coded loop (indices 0 through 100).
            The count is clamped to ``len(actual_data)``.
    """
    # Clamp so a short dataset no longer raises IndexError.
    limit = min(num_images, len(actual_data))
    for i in range(limit):
        raw_ = actual_data[i]
        total = raw_.size  # 32*32*3 for CIFAR images
        for j in range(duplicate_num):
            generate_ = confusion_data[i * duplicate_num + j]
            matched = compare_two_images(raw_, generate_)
            print(j, '\t', 1. * (total - matched) / total)


def load_actual_train_data_and_label(root):
    """Load the training images and labels from the pickled batch files.

    Each file listed in ``train_list`` is read from
    ``<root>/<base_folder>/<name>``. Its ``'data'`` arrays are concatenated.
    Each row is reshaped to channel-first (3, 32, 32) and then transposed to
    channel-last.

    Args:
        root: dataset root directory; ``~`` is expanded.

    Returns:
        (images, labels): ``images`` has shape (N, 32, 32, 3), and ``labels``
        is a list of N labels. Labels are taken from the ``'labels'`` key if
        present, otherwise from ``'fine_labels'``.
    """
    root = os.path.expanduser(root)

    data_chunks = []
    labels = []
    for fentry in train_list:
        path = os.path.join(root, base_folder, fentry[0])
        # NOTE: pickle can execute arbitrary code while loading; only point
        # `root` at dataset files from a trusted source.
        with open(path, 'rb') as fo:  # closed even if unpickling fails
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                # latin1 lets Python 3 read byte strings pickled by Python 2.
                entry = pickle.load(fo, encoding='latin1')
        data_chunks.append(entry['data'])
        # Batch files may store labels under either key; prefer 'labels'.
        if 'labels' in entry:
            labels += entry['labels']
        else:
            labels += entry['fine_labels']

    images = np.concatenate(data_chunks)
    # Infer N from the data instead of hard-coding 50000.
    images = images.reshape((-1, 3, 32, 32))
    images = images.transpose((0, 2, 3, 1))
    return images, labels


def check_random_data(root='./data', confusion_T=0, zero_out_ratio=0,
                      data_dir='confusion_random_train_label'):
    """Compare a generated "confusion" training set against the real training images.

    The generated set is loaded from
    ``<data_dir>/zero_out_<confusion_T>_<zero_out_ratio>.npz``, which must hold
    the arrays ``training_data`` and ``training_label``. ``compare`` is then run
    with ``confusion_T`` as the number of duplicates per original image.
    Finally, the shapes of both arrays are printed.

    Args:
        root: dataset root passed to ``load_actual_train_data_and_label``.
        confusion_T: number of duplicates per original image. It is also
            part of the .npz file name.
        zero_out_ratio: value embedded in the .npz file name.
        data_dir: directory containing the .npz file. The default matches the
            previously hard-coded path.
    """
    actual_train_data, _ = load_actual_train_data_and_label(root)

    duplicate_num = confusion_T
    print('duplicate {} times'.format(duplicate_num))

    npz_path = os.path.join(
        data_dir, 'zero_out_{}_{}.npz'.format(confusion_T, zero_out_ratio))
    # NpzFile keeps the zip archive open; the context manager closes it.
    # Indexing reads each array fully into memory, so the arrays stay valid after
    # the file is closed. The name `npz` avoids shadowing the module-level `data` import.
    with np.load(npz_path) as npz:
        confusion_data = npz['training_data']
        confusion_labels = npz['training_label']

    compare(confusion_data, actual_train_data, duplicate_num)

    print('Training data size\t', confusion_data.shape)
    print('Training label size\t', confusion_labels.shape)
    print()


if __name__ == '__main__':
    # Check the set generated with 10 duplicates per image and zero-out ratio 0.1.
    check_random_data(zero_out_ratio=0.1, confusion_T=10)
