import os
import os.path
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList

from PIL import Image, ImageDraw
from torchvision.datasets.vision import VisionDataset

from .modulated_coco import ConvertCocoPolysToMask, has_valid_annotation


class CustomCocoDetection(VisionDataset):
    """Coco-style dataset imported from TorchVision.
        It is modified to handle several image sources

    Args:
        root_coco (string): Path to the coco images
        root_vg (string): Path to the vg images
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
            self,
            root_coco: str,
            root_vg: str,
            annFile: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        super(CustomCocoDetection, self).__init__(root_coco, transforms, transform, target_transform)
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))

        ids = []
        for img_id in self.ids:
            if isinstance(img_id, str):
                ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
            else:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco.loadAnns(ann_ids)
            if has_valid_annotation(anno):
                ids.append(img_id)
        self.ids = ids

        self.root_coco = root_coco
        self.root_vg = root_vg

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        img_info = coco.loadImgs(img_id)[0]
        path = img_info["file_name"]
        dataset = img_info["data_source"]

        cur_root = self.root_coco if dataset == "coco" else self.root_vg
        img = Image.open(os.path.join(cur_root, path)).convert("RGB")
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.ids)


class MixedDataset(CustomCocoDetection):
    """Same as the modulated detection dataset, except with multiple img sources"""

    def __init__(self,
                 img_folder_coco,
                 img_folder_vg,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        super(MixedDataset, self).__init__(img_folder_coco, img_folder_vg, ann_file)
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        img, target = super(MixedDataset, self).__getitem__(idx)

        image_id = self.ids[idx]
        caption = self.coco.loadImgs(image_id)[0]["caption"]
        anno = {"image_id": image_id, "annotations": target, "caption": caption}
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xyxy")
        classes = anno["labels"]
        target.add_field("labels", classes)
        if not self.disable_clip_to_image:
            num_boxes = len(boxes)
            target = target.clip_to_image(remove_empty=True)
            assert len(target.bbox) == num_boxes, "Box removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in anno:
            target.add_field(ann, anno[ann])

        return img, target, idx

    def get_img_info(self, index):
        img_id = self.id_to_img_map[index]
        img_data = self.coco.imgs[img_id]
        return img_data
