import argparse
import numpy
import os
import shutil
import time
import sys
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader


def get_dataloader(args):
    """Build the training and validation data loaders.

    Args:
        args: Namespace-like object. The attributes read here are ``mode``,
            ``cuda``, ``DA_for_train`` and ``batch_size``; in
            ``adversarial_init`` mode also ``Is_Init`` and, when that is
            truthy, ``zero_out_ratio`` and ``confusion_T``.

    Returns:
        A ``(train_loader, test_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects. The training loader is
        shuffled, the test loader is not.

    Raises:
        ValueError: If ``args.mode`` starts with neither ``'random_init'``
            nor ``'adversarial_init'``.
    """
    mode = args.mode
    use_random_init = mode.startswith('random_init')
    use_adversarial_init = mode.startswith('adversarial_init')
    # Fail fast, before any dataset work, and name the offending mode.
    if not (use_random_init or use_adversarial_init):
        raise ValueError('Mode not included: {}'.format(mode))

    loader_kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

    # The same tensor conversion + normalization ends every pipeline below.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    if not args.DA_for_train:
        print('No Data Augmentation.')
        train_transform = transforms.Compose([to_tensor, normalize])
    else:
        print('Use Data Augmentation.')
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            to_tensor,
            normalize,
        ])

    if use_random_init:
        print('Apply random initialization.')
        train_dataset = CustomizedImageNetFolder('./restricted_imagenet_data/restricted_imagenet/train',
                                                 transform=train_transform, confusion=False)
    else:
        print('Apply adversarial initialization.')
        if args.Is_Init:
            print('In Adversarial Initialization / Pre Training.')
            confusion_root = './restricted_imagenet_data/confusion_{}_zero_out_{}/train'
            label_file = './restricted_imagenet_data/randomized_label_{}.npz'
            # The first confusion set is the base; sets 2..confusion_T are
            # merged onto it one at a time.
            train_dataset = CustomizedImageNetFolder(confusion_root.format(1, args.zero_out_ratio),
                                                     confusion_label_file=label_file.format(1),
                                                     transform=train_transform, confusion=True)
            for confusion_t in range(2, args.confusion_T + 1):
                train_dataset = ConcatedConfusionFolder(confusion_root.format(confusion_t, args.zero_out_ratio),
                                                        confusion_label_file=label_file.format(confusion_t),
                                                        transform=train_transform, image_folder=train_dataset)
            print('Confusion set size:', len(train_dataset))
        else:
            print('In Main Training / Fine Tuning.')
            train_dataset = CustomizedImageNetFolder('./restricted_imagenet_data/restricted_imagenet/train',
                                                     transform=train_transform, confusion=False)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, **loader_kwargs)

    test_dataset = CustomizedImageNetFolder('./restricted_imagenet_data/restricted_imagenet/val',
                                            transform=transforms.Compose([to_tensor, normalize]))
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                              shuffle=False, **loader_kwargs)
    return train_loader, test_loader


class ConcatedConfusionFolder(datasets.ImageFolder):
    """Image folder relabelled from a confusion-label file and merged with another dataset.

    The images found under ``root`` get their labels replaced by the entries
    of the ``randomized_label_list`` array stored in ``confusion_label_file``
    (matched by position). The ``(path, label)`` items of ``image_folder``
    are then appended after them.

    Args:
        root: Directory handed to ``torchvision.datasets.ImageFolder``.
        confusion_label_file: Path of an ``.npz`` archive containing
            ``randomized_label_list``.
        image_folder: Dataset exposing an ``imgs`` sequence of
            ``(path, label)`` pairs to append.
        transform: Forwarded to ``ImageFolder``.
        target_transform: Forwarded to ``ImageFolder``.
        loader: Forwarded to ``ImageFolder``.
    """

    def __init__(self, root, confusion_label_file, image_folder, transform=None, target_transform=None, loader=default_loader):
        super(ConcatedConfusionFolder, self).__init__(root, loader=loader, transform=transform,
                                                      target_transform=target_transform)
        print('Merge with image folder, data from {}, and confusion label from {}.'.format(root, confusion_label_file))
        # np.load on an .npz archive keeps a file handle open; close it once
        # the label array has been materialized.
        with np.load(confusion_label_file) as confused_data:
            confused_label = confused_data['randomized_label_list'].astype(int)

        # Plain Python ints keep the label type uniform across merged items.
        merged = [(path, int(confused_label[i])) for i, (path, _) in enumerate(self.imgs)]
        merged.extend((path, label) for path, label in image_folder.imgs)

        # Recent torchvision releases read items from ``self.samples`` in
        # __getitem__/__len__ (older ones read ``self.imgs``); updating both,
        # plus ``targets``, makes the new labels take effect either way.
        self.samples = merged
        self.imgs = self.samples
        self.targets = [label for _, label in merged]


class CustomizedImageNetFolder(datasets.ImageFolder):
    """Image folder that can optionally swap its labels for confusion labels.

    With ``confusion=False`` this behaves like
    ``torchvision.datasets.ImageFolder`` and only logs the data root. With
    ``confusion=True`` every image's label is replaced by the entry at the
    same position of the ``randomized_label_list`` array stored in
    ``confusion_label_file``.

    Args:
        root: Directory handed to ``ImageFolder``.
        confusion_label_file: Path of an ``.npz`` archive containing
            ``randomized_label_list``; required when ``confusion`` is true.
        transform: Forwarded to ``ImageFolder``.
        target_transform: Forwarded to ``ImageFolder``.
        loader: Forwarded to ``ImageFolder``.
        confusion: Whether to replace the folder-derived labels.

    Raises:
        ValueError: If ``confusion`` is true but no label file is given.
    """

    def __init__(self, root, confusion_label_file=None, transform=None, target_transform=None,
                 loader=default_loader, confusion=False):
        super(CustomizedImageNetFolder, self).__init__(root, loader=loader, transform=transform,
                                                       target_transform=target_transform)
        if not confusion:
            print('Load actual data and label from {}'.format(root))
            return

        if confusion_label_file is None:
            raise ValueError('confusion_label_file is required when confusion=True.')

        print('Load confusion data from {}, confusion labels from {}.'.format(root, confusion_label_file))
        # np.load on an .npz archive keeps a file handle open; close it once
        # the label array has been materialized.
        with np.load(confusion_label_file) as confused_data:
            confused_label = confused_data['randomized_label_list'].astype(int)

        relabelled = [(path, int(confused_label[i])) for i, (path, _) in enumerate(self.imgs)]

        # Recent torchvision releases read items from ``self.samples`` in
        # __getitem__/__len__ (older ones read ``self.imgs``); updating both,
        # plus ``targets``, makes the new labels take effect either way.
        self.samples = relabelled
        self.imgs = self.samples
        self.targets = [label for _, label in relabelled]


class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the image file path.

    Each item is whatever ``torchvision.datasets.ImageFolder`` returns for
    the index, with the path of the image appended as a final element.
    """

    def __getitem__(self, index):
        """Return the parent's item for ``index`` with its file path appended."""
        base_item = super(ImageFolderWithPaths, self).__getitem__(index)

        # Debug output: the second element of the item and its Python type.
        label = base_item[1]
        print(label, type(label))

        # The first field of the matching ``imgs`` entry is the file path.
        return base_item + (self.imgs[index][0],)


def check_file_path():
    """Smoke-check ``ImageFolderWithPaths`` on the tiny-imagenet training folder.

    Prints the class names found under the training directory, then loads a
    single batch of 100 items and prints the size, label and file path of
    its first element.
    """
    train_dataset = ImageFolderWithPaths(
        '././tiny_imagenet_data/tiny-imagenet-200/train',
        transforms.Compose([
            # RandomResizedCrop replaces RandomSizedCrop, which torchvision
            # deprecated and later removed.
            transforms.RandomResizedCrop(224),  # 224 , 299
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]))
    print(train_dataset.classes)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, num_workers=1)

    # Only the first batch is inspected.
    for idx, (x, y, path) in enumerate(train_loader):
        print(idx, x[0].size(), y[0], path[0])
        break


def check_confusion(confusion_label_file=None):
    """Smoke-check confusion-label loading on the tiny-imagenet training folder.

    Loads one batch of 100 items from a ``CustomizedImageNetFolder`` built
    with ``confusion=True`` and prints the batch size and its labels.

    Args:
        confusion_label_file: Path of the ``.npz`` archive holding
            ``randomized_label_list`` for this folder.

    Raises:
        ValueError: If ``confusion_label_file`` is not provided.
    """
    if confusion_label_file is None:
        raise ValueError('check_confusion needs the path of a confusion label file.')

    # Keyword arguments matter here: the second positional parameter of
    # CustomizedImageNetFolder is ``confusion_label_file``, not ``transform``.
    train_dataset = CustomizedImageNetFolder(
        '././tiny_imagenet_data/tiny-imagenet-200/train',
        confusion_label_file=confusion_label_file,
        transform=transforms.Compose([
            # RandomResizedCrop replaces RandomSizedCrop, which torchvision
            # deprecated and later removed.
            transforms.RandomResizedCrop(224),  # 224 , 299
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]),
        confusion=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, num_workers=1)

    # Only the first batch is inspected.
    for idx, (x, y) in enumerate(train_loader):
        print(idx, x.size())
        print(y)
        break


def _find_classes(root):
    """Return ``(classes, class_to_idx)`` for the immediate sub-directories of ``root``.

    ``classes`` is the sorted list of sub-directory names and ``class_to_idx``
    maps each name to its position in that list.
    """
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def check_class(root):
    """Print and return the class names and class-to-index mapping under ``root``.

    Args:
        root: Directory whose immediate sub-directories are treated as classes.

    Returns:
        The ``(classes, class_to_idx)`` pair that was printed.
    """
    classes, class_to_idx = _find_classes(root)
    print(classes)
    print(class_to_idx)
    return classes, class_to_idx


def check_val():
    """Smoke-check loading of the tiny-imagenet validation folder.

    Prints the class names found under the validation directory, then
    iterates over every batch of 100 items and prints the batch index with
    the input and target tensor sizes.
    """
    # NOTE(review): the original referenced CustomizedImageNetValFolder, which
    # is not defined in this file. CustomizedImageNetFolder is used instead;
    # like ImageFolder it expects one sub-directory per class -- confirm the
    # val directory is laid out that way.
    test_dataset = CustomizedImageNetFolder('././tiny_imagenet_data/tiny-imagenet-200/val',
                                            transform=transforms.Compose([
                                                # Resize replaces Scale, which torchvision
                                                # deprecated and later removed.
                                                transforms.Resize(256),
                                                transforms.CenterCrop(224),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                     (0.2023, 0.1994, 0.2010)),
                                            ]))
    print('test_dataset\t', test_dataset.classes)

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, num_workers=4)

    for batch_idx, (inputs, target) in enumerate(test_loader):
        print(batch_idx, '\t', inputs.size(), '\t', target.size())


if __name__ == '__main__':
    # Manual smoke checks for this module; un-comment the one to run.
    # check_file_path()
    # check_confusion()
    # check_class('././tiny_imagenet_data/tiny-imagenet-200/train')
    check_val()