import os

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_file_from_google_drive
from torch.utils import data
try:
    from .utils import TransformTwice
except ImportError:
    from utils import TransformTwice



class Cub2011(VisionDataset):
    """`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.

    The images of the selected classes are decoded once in ``__init__`` and kept
    in memory in ``self.imgs``; ``__getitem__`` only applies the transforms.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        target_list (iterable of int, optional): Zero-based class indices to keep.
            Defaults to ``range(5)``, i.e. the first five classes.

    Raises:
        RuntimeError: if the metadata or any image file of the split is missing.
    """
    base_folder = 'CUB_200_2011/images'
    # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False,
                 target_list=None):
        super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)

        self.loader = default_loader
        self.train = train
        if download:
            self._download()

        # Also populates self.data / self.class_names via _load_metadata().
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')

        if target_list is None:
            target_list = range(5)
        # Build the membership set once: O(1) lookups instead of scanning a list per sample.
        wanted = set(target_list)

        self.imgs = []
        self.targets = []
        # Filter on the label *before* decoding, so images of classes that are not
        # requested are never loaded into memory.
        for filepath, raw_target in zip(self.data['filepath'], self.data['target']):
            target = int(raw_target) - 1  # labels in image_class_labels.txt start at 1
            if target not in wanted:
                continue
            path = os.path.join(self.root, self.base_folder, filepath)
            self.imgs.append(self.loader(path))  # PIL Image
            self.targets.append(target)

    def _load_metadata(self):
        """Read the CUB metadata files and set ``self.data`` and ``self.class_names``.

        ``self.data`` is the merge of images.txt, image_class_labels.txt and
        train_test_split.txt on ``img_id``, restricted to the split chosen by
        ``self.train``. ``self.class_names`` is the second column of classes.txt.
        """
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
                                  sep=' ', names=['class_name'], usecols=[1])
        self.class_names = class_names['class_name'].to_list()
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def _check_integrity(self):
        """Return True if the metadata loads and every image of the split exists on disk.

        Any failure while reading the metadata is treated as "dataset not present".
        The path of the first missing image, if any, is printed.
        """
        try:
            self._load_metadata()
        except Exception:
            return False

        for rel_path in self.data['filepath']:
            filepath = os.path.join(self.root, self.base_folder, rel_path)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        """Download and extract the archive into ``self.root`` unless already present."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)

        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            # Use the 'data' extraction filter (rejects absolute paths, '..' members and
            # links escaping root) when the interpreter supports it (PEP 706).
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=self.root, filter='data')
            else:
                tar.extractall(path=self.root)

    def __len__(self):
        """Number of kept samples (images of the classes in ``target_list``)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(img, target, idx)``, where ``target`` is a ``torch.long`` tensor.

        ``transform`` / ``target_transform`` are applied when set.
        """
        img = self.imgs[idx]
        target = self.targets[idx]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if isinstance(target, torch.Tensor):
            target = target.clone().detach().long()
        else:
            target = torch.tensor(target).long()

        return img, target, idx


def CubData(root, split='train', aug=None, target_list=range(80)):
    """Build a ``Cub2011`` dataset with an ImageNet-normalised transform pipeline.

    Args:
        root: Dataset root directory, forwarded to ``Cub2011``.
        split: ``'train'`` selects the training split; any other value selects the test split.
        aug: ``None`` for resize + normalise only, ``'once'`` for random flip/crop
            augmentation, ``'twice'`` to wrap that augmentation in ``TransformTwice``.
        target_list: Zero-based class indices to keep.

    Returns:
        Cub2011: the constructed dataset.

    Raises:
        ValueError: if ``aug`` is not ``None``, ``'once'`` or ``'twice'``.
    """
    # An unknown value used to fall through every branch and crash with UnboundLocalError.
    if aug not in (None, 'once', 'twice'):
        raise ValueError("aug must be None, 'once' or 'twice', got {!r}".format(aug))

    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if aug is None:
        transform = T.Compose([
            T.Resize(size=(224, 224)),
            T.ToTensor(),
            normalize,
        ])
    else:
        # NOTE(review): RandomCrop to the same 224x224 size produced by Resize is a
        # no-op, so the only effective augmentation is the horizontal flip. Kept as-is
        # to avoid changing training behaviour; confirm intent.
        augmented = T.Compose([
            T.Resize(size=(224, 224)),
            T.RandomHorizontalFlip(),
            T.RandomCrop(size=(224, 224)),
            T.ToTensor(),
            normalize,
        ])
        transform = TransformTwice(augmented) if aug == 'twice' else augmented
    return Cub2011(root, train=split == 'train', transform=transform, target_list=target_list)

def CubLoader(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True, target_list=range(100)):
    """Wrap ``CubData(root, split, aug, target_list)`` in a pinned-memory ``DataLoader``.

    The last incomplete batch is dropped only when ``split == 'train'``.
    """
    is_train = split == 'train'
    cub_dataset = CubData(root, split=split, aug=aug, target_list=target_list)
    return data.DataLoader(
        cub_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=is_train,
    )

def CubLoaderMix(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True, labeled_list=range(140), unlabeled_list=range(140, 200)):
    """Return one ``DataLoader`` over the labeled classes followed by the unlabeled ones.

    Two ``CubData`` datasets are built with the same split and augmentation. The
    second one's images and targets are appended to the first, and the first is then
    wrapped. After the merge, ``targets`` on the returned dataset is a NumPy array.
    """
    merged = CubData(root, split=split, aug=aug, target_list=labeled_list)
    extra = CubData(root, split=split, aug=aug, target_list=unlabeled_list)

    merged.targets = np.concatenate([merged.targets, extra.targets])
    merged.imgs = [*merged.imgs, *extra.imgs]

    # Images and labels must stay aligned after the merge.
    assert merged.targets.shape[0] == len(merged.imgs) == len(merged)

    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle,
                         num_workers=num_workers, pin_memory=True)
    return data.DataLoader(merged, **loader_kwargs)


if __name__ == '__main__':
    # Smoke test: build both splits. The dataset root may be passed as the first
    # command-line argument; it defaults to the previously hard-coded location.
    import sys

    cub_root = sys.argv[1] if len(sys.argv) > 1 else '/repository2/zxw'
    train_dataset = Cub2011(cub_root, train=True, download=False)
    test_dataset = Cub2011(cub_root, train=False, download=False)
    print('train: {} samples, test: {} samples'.format(len(train_dataset), len(test_dataset)))
