from __future__ import print_function

import re
import os
import copy
import numpy as np
import glob
import sys

import skimage
from skimage import io
import argparse


# Root of the local Restricted ImageNet copy (relative to the working directory).
actual_dir = './restricted_imagenet'
# Training split read by load_filenames_labels and generate_random_data; each
# immediate sub-directory is treated as one class (see _find_classes).
actual_train_dir = './restricted_imagenet/train'


def _find_classes(dir):
    """Discover the class sub-directories of ``dir``.

    Returns:
        A ``(classes, class_to_idx)`` tuple. ``classes`` is the sorted list of
        names of the immediate sub-directories of ``dir``. ``class_to_idx``
        maps each of those names to its position in ``classes``.
    """
    if sys.version_info < (3, 5):
        # os.scandir does not exist before Python 3.5.
        names = [entry for entry in os.listdir(dir)
                 if os.path.isdir(os.path.join(dir, entry))]
    else:
        # scandir is the faster option where it is available.
        names = [entry.name for entry in os.scandir(dir) if entry.is_dir()]
    classes = sorted(names)
    class_to_idx = dict((name, position) for position, name in enumerate(classes))
    return classes, class_to_idx


def load_filenames_labels():
    """List every training file together with its integer class label.

    Walks each class directory under ``actual_train_dir``, with classes in
    sorted order and directories and file names sorted within each class.
    Returns a list of ``(path, label)`` tuples, where ``label`` comes from
    the ``class_to_idx`` mapping built by ``_find_classes``.
    """
    class_names, class_to_idx = _find_classes(actual_train_dir)
    print('label dict: {}, class_to_idx: {}'.format(class_names, class_to_idx))

    pairs = []
    for class_name in sorted(class_to_idx):
        class_dir = '{}/{}'.format(actual_train_dir, class_name)
        if os.path.isdir(class_dir):
            label = class_to_idx[class_name]
            # Sort os.walk output so the file order is deterministic.
            for dirpath, _, files in sorted(os.walk(class_dir)):
                pairs.extend((os.path.join(dirpath, f), label) for f in sorted(files))
    return pairs


def _zero_out_pixels(image, zero_out_ratio):
    """Return a copy of ``image`` with a random fraction of its entries set to 0.

    ``int(image.size * zero_out_ratio)`` distinct flat positions are drawn
    with ``np.random`` and zeroed. ``image`` itself is not modified.
    """
    total = image.size
    n = int(total * zero_out_ratio)
    # Same draw sequence as np.arange(total) followed by np.random.shuffle.
    sample_index = np.random.permutation(total)[:n]
    data = image.copy()
    # Decode flat positions in Fortran order (axis 0 varies fastest). This keeps
    # the index-to-pixel mapping used for 32x32(x3) images, so a given seed
    # zeroes the same entries, and it works for any image shape.
    data[np.unravel_index(sample_index, image.shape, order='F')] = 0
    return data


def generate_random_data(confusion_train_dir_, confusion_T, zero_out_ratio):
    """Write a copy of the training set with randomly zeroed image entries.

    Args:
        confusion_train_dir_: output directory template, formatted with
            ``(confusion_T, zero_out_ratio)``.
        confusion_T: index of this confusion copy. It is used only to build
            the output path and the log line.
        zero_out_ratio: fraction of each image's entries to set to 0.

    Every file in ``actual_train_dir/<class>/`` is read with ``io.imread``,
    masked by ``_zero_out_pixels``, and written to ``<output dir>/<class>/``
    under the same file name.
    """
    print('confusion size: {}\tzero out ratio: {}'.format(confusion_T, zero_out_ratio))
    confusion_train_dir = confusion_train_dir_.format(confusion_T, zero_out_ratio)
    print(confusion_train_dir)

    # os.listdir order depends on the filesystem, and each file's mask depends
    # on processing order. Sorting makes a fixed seed reproducible everywhere.
    for class_name in sorted(os.listdir(actual_train_dir)):
        src_dir = os.path.join(actual_train_dir, class_name)
        if not os.path.isdir(src_dir):
            continue  # Skip stray files that sit next to the class folders.
        dst_dir = os.path.join(confusion_train_dir, class_name)
        if not os.path.isdir(dst_dir):
            os.makedirs(dst_dir)

        for filename in sorted(os.listdir(src_dir)):
            image = io.imread(os.path.join(src_dir, filename))
            data = _zero_out_pixels(image, zero_out_ratio)

            save_kwargs = {}
            if os.path.splitext(filename)[1].lower() in ('.jpg', '.jpeg'):
                # Forwarded to the image writer plugin. JPEG is still lossy at
                # quality 100, so zeroed entries may not remain exactly 0 on disk.
                save_kwargs['quality'] = 100
            io.imsave(os.path.join(dst_dir, filename), data, **save_kwargs)


def generate_confusion_label(outputname, confusion_T, num_classes=9):
    """Save a uniformly random label for every training image.

    Args:
        outputname: output path template, formatted with ``confusion_T``.
            ``np.savez_compressed`` appends ``.npz`` if it is missing.
        confusion_T: index of this confusion copy. It is used only in the path.
        num_classes: labels are drawn uniformly from ``range(num_classes)``.
            The default of 9 presumably matches the number of Restricted
            ImageNet classes (TODO confirm against ``_find_classes``).

    The archive holds one array, ``randomized_label_list``, with one label
    per entry of ``load_filenames_labels()``, in the same order. Labels are
    stored as strings.
    """
    outputname = outputname.format(confusion_T)

    filenames_labels = load_filenames_labels()
    print(filenames_labels[:100])
    # One draw per file, in file order, so a fixed seed reproduces the labels.
    randomized_label_list = ['{}'.format(np.random.randint(num_classes))
                             for _ in filenames_labels]
    print(randomized_label_list[:100])
    np.savez_compressed(outputname,
                        randomized_label_list=randomized_label_list)


if __name__ == '__main__':
    # Build T corrupted copies of the training set (randomly zeroed entries)
    # and, for each copy, a file of uniformly random labels.
    parser = argparse.ArgumentParser(description='PyTorch Restricted Imagenet Example')
    parser.add_argument('--T', type=int, default=10,
                        help='number of confusion copies to generate (default: 10)')
    parser.add_argument('--R', type=float, default=0.1,
                        help='fraction of each image\'s entries to zero out (default: 0.1)')
    parser.add_argument('--seed', type=int, default=1,
                        help='numpy random seed (default: 1)')
    args = parser.parse_args()

    T = args.T
    R = args.R
    seed = args.seed
    # Seed once. Data and labels for all copies are drawn from this single stream.
    np.random.seed(seed)

    print('Regenerating confusion data: T={}, R={}.\nRandom seed: {}.'.format(T, R, seed))

    for i in range(1, 1 + T):
        generate_random_data(confusion_train_dir_='./confusion_{}_zero_out_{}/train',
                             confusion_T=i,
                             zero_out_ratio=R)
        generate_confusion_label(outputname='randomized_label_{}', confusion_T=i)
