"""
COCO-style TSV detection dataset that converts detection targets into grounding
captions; each returned target carries the image_id used for evaluation.

Largely adapted from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""

import torch
import json
from PIL import Image, ImageDraw

from .modulated_coco import ConvertCocoPolysToMask
from .tsv import ODTSVDataset
from pycocotools.coco import COCO
from maskrcnn_benchmark.structures.bounding_box import BoxList
import random
from .od_to_grounding import convert_object_detection_to_grounding_optimized_for_od, check_for_positive_overflow, sanity_check_target_after_processing


class CocoDetectionTSV(ODTSVDataset):
    """TSV-backed detection dataset that emits grounding-style targets.

    Each sample's boxes are turned into a text caption plus annotation
    metadata by ``convert_object_detection_to_grounding_optimized_for_od``,
    prepared by ``ConvertCocoPolysToMask`` and attached to the target as
    extra fields.
    """

    def __init__(self,
                 name,
                 yaml_file,
                 transforms,
                 return_tokens,
                 tokenizer,
                 extra_fields,
                 random_sample_negative=-1,
                 add_detection_prompt=False,
                 add_detection_prompt_advanced=False,
                 use_od_data_aug=False,
                 control_probabilities=None,
                 disable_shuffle=False,
                 prompt_engineer_version="v2",
                 prompt_limit_negative=-1,
                 positive_question_probability=0.6,
                 negative_question_probability=0.8,
                 full_question_probability=0.5,
                 disable_clip_to_image=False,
                 separation_tokens=" ",
                 no_mask_for_od=False,
                 max_num_labels=-1,
                 max_query_len=256,
                 **kwargs
                 ):
        """Build the dataset.

        Args:
            name: Dataset name, stored as ``self.name``.
            yaml_file, extra_fields, **kwargs: Forwarded to ``ODTSVDataset``.
            transforms: Optional callable ``(img, target) -> (img, target)``
                applied in ``__getitem__``; ``None`` disables it.
            return_tokens, tokenizer, max_query_len: Forwarded to
                ``ConvertCocoPolysToMask``. ``tokenizer`` and
                ``max_query_len`` are also used for caption-length checks.
            control_probabilities: Dict forwarded to the grounding
                conversion. ``None`` (the default) means an empty dict,
                created fresh for each instance.
            random_sample_negative, add_detection_prompt,
            add_detection_prompt_advanced, disable_shuffle,
            separation_tokens, max_num_labels: Forwarded to
                ``convert_object_detection_to_grounding_optimized_for_od``.
            disable_clip_to_image: Skip ``clip_to_image`` on targets.
            no_mask_for_od: Append a ``(-1, -1, -1)`` sentinel span to
                ``greenlight_span_for_masked_lm_objective``.
            use_od_data_aug, prompt_engineer_version, prompt_limit_negative,
            positive_question_probability, negative_question_probability,
            full_question_probability: Stored on the instance. They are not
                read by any method in this class.
        """
        super().__init__(yaml_file, extra_fields, **kwargs)

        self._transforms = transforms
        self.name = name
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks=False,
            return_tokens=return_tokens,
            tokenizer=tokenizer,
            max_query_len=max_query_len
        )
        self.tokenizer = tokenizer

        # Build a fresh dict per instance. A literal ``{}`` default would be
        # one object shared by every dataset built with the default.
        self.control_probabilities = {} if control_probabilities is None else control_probabilities
        self.random_sample_negative = random_sample_negative
        self.add_detection_prompt = add_detection_prompt
        self.add_detection_prompt_advanced = add_detection_prompt_advanced
        self.use_od_data_aug = use_od_data_aug

        self.prompt_engineer_version = prompt_engineer_version
        self.prompt_limit_negative = prompt_limit_negative
        self.positive_question_probability = positive_question_probability
        self.negative_question_probability = negative_question_probability
        self.full_question_probability = full_question_probability
        self.separation_tokens = separation_tokens
        self.disable_clip_to_image = disable_clip_to_image
        self.disable_shuffle = disable_shuffle
        self.no_mask_for_od = no_mask_for_od
        self.max_num_labels = max_num_labels

    def __len__(self):
        """Return the number of samples reported by ``ODTSVDataset``."""
        return super().__len__()

    def categories(self, no_background=True):
        """Return a ``{category_id: category_name}`` dict.

        Categories come from ``self.coco.dataset["categories"]``. When
        ``no_background`` is true, entries named ``"__background__"`` or
        with id 0 are skipped.
        """
        # NOTE(review): ``self.coco`` is not assigned in this class;
        # presumably set by the parent or a caller. Confirm.
        return {
            cat["id"]: cat["name"]
            for cat in self.coco.dataset["categories"]
            if not no_background or (cat["name"] != "__background__" and cat["id"] != 0)
        }

    def __getitem__(self, idx):
        """Load sample ``idx`` and attach grounding annotations to its target.

        Returns:
            ``(img, target, idx)``. ``target`` is the (optionally clipped and
            transformed) box target, with every key of the prepared annotation
            dict (``image_id``, ``caption``, ...) added via ``add_field``.
        """
        # target is a BoxList
        img, target, _, _scale = super().__getitem__(idx)
        image_id = self.get_img_id(idx)
        restricted_negative_list = None
        # Reserve two positions of the query budget for the tokenizer's
        # special tokens.
        max_seq_length = self.max_query_len - 2

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        original_box_num = len(target)
        # Drop boxes whose label text would not fit in the caption budget.
        target, positive_caption_length = check_for_positive_overflow(
            target, self.ind_to_class, self.tokenizer, max_seq_length
        )
        if len(target) < original_box_num:
            print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))

        annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
            target=target,
            image_id=image_id,
            ind_to_class=self.ind_to_class,
            disable_shuffle=self.disable_shuffle,
            add_detection_prompt=self.add_detection_prompt,
            add_detection_prompt_advanced=self.add_detection_prompt_advanced,
            random_sample_negative=self.random_sample_negative,
            control_probabilities=self.control_probabilities,
            restricted_negative_list=restricted_negative_list,
            separation_tokens=self.separation_tokens,
            max_num_labels=self.max_num_labels,
            positive_caption_length=positive_caption_length,
            tokenizer=self.tokenizer,
            max_seq_length=max_seq_length
        )

        anno = {
            "image_id": image_id,
            "annotations": annotations,
            "caption": caption,
            "label_to_positions": label_to_positions,
            "greenlight_span_for_masked_lm_objective": greenlight_span_for_masked_lm_objective,
        }

        if self.no_mask_for_od:
            # Sentinel span appended when masking is disabled for OD data.
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno, box_format="xyxy")

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # Attach every prepared annotation entry as an extra target field.
        for key, value in anno.items():
            target.add_field(key, value)

        sanity_check_target_after_processing(target)

        return img, target, idx

    def get_raw_image(self, idx):
        """Return only the image for ``idx``, with no grounding processing."""
        image, *_ = super().__getitem__(idx)
        return image

    def get_img_id(self, idx):
        """Return an integer image id for sample ``idx``.

        The id is read from the first column of the label TSV row for this
        sample. Falls back to ``idx`` when no label TSV is loaded or when the
        stored value cannot be converted with ``int()``.
        """
        if self.label_tsv is None:
            # No label TSV to read an id from. Use the same ``idx`` fallback
            # as the unparseable case, so callers always get an int.
            return idx
        line_no = self.get_line_no(idx)
        row = self.label_tsv.seek(line_no)
        try:
            return int(row[0])
        except (TypeError, ValueError):
            return idx
