import torch
import torch.distributed as dist
import time
from torchvision.ops import nms
import random
import numpy as np
from PIL import Image, ImageDraw
import pdb
from maskrcnn_benchmark.structures.bounding_box import BoxList
from .modulated_coco import ConvertCocoPolysToMask
from .tsv import ODTSVDataset, TSVYamlDataset
from .od_to_grounding import sanity_check_target_after_processing

class CaptionTSV(TSVYamlDataset):
    def __init__(self,
                 yaml_file,
                 transforms,
                 return_tokens,
                 return_masks,
                 tokenizer,
                 caption_min_box=1,
                 replace_clean_label=False,
                 further_screen=False,
                 caption_conf=0.5,
                 caption_nms=-1,
                 pack_random_caption_number=0,
                 inference_caption=False,
                 sample_negative_for_grounding_data=-1,
                 random_pack_prob=-1.0,
                 no_random_pack_probability=0.0,
                 safeguard_positive_caption=True,
                 mlm_obj_for_only_positive=False,
                 caption_format_version="v1",
                 local_debug=False,
                 max_query_len=256,
                 **kwargs
                 ):
        super(CaptionTSV, self).__init__(yaml_file, None, replace_clean_label)
        self.yaml_file = yaml_file
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks=return_masks,
                                              return_tokens=return_tokens,
                                              tokenizer=tokenizer,
                                              max_query_len=max_query_len)
        self.tokenizer = tokenizer
        self.caption_min_box = caption_min_box
        self.replace_clean_label = replace_clean_label
        self.further_screen = further_screen
        self.pack_random_caption_number = pack_random_caption_number
        self.caption_format_version = caption_format_version

        self.caption_conf = caption_conf
        self.caption_nms = caption_nms
        self.inference_caption = inference_caption
        self.sample_negative_for_grounding_data = sample_negative_for_grounding_data
        self.random_pack_prob = random_pack_prob
        self.no_random_pack_probability = no_random_pack_probability
        self.safeguard_positive_caption = safeguard_positive_caption
        self.mlm_obj_for_only_positive = mlm_obj_for_only_positive
        try:
            self.rank = dist.get_rank()
        except:
            self.rank = 0

    def __len__(self):
        return super(CaptionTSV, self).__len__()

    def pack_caption(self, positive_caption, negative_captions, original_tokens_positive):
        if len(negative_captions) == 0:
            return positive_caption, original_tokens_positive, [(0, len(positive_caption))]
        if self.safeguard_positive_caption:
            length_of_each_caption = []
            for caption in negative_captions + [positive_caption]:
                tokenized = self.tokenizer(caption, return_tensors="pt")
                length_of_each_caption.append(tokenized.input_ids.size(-1))
            max_length = self.max_query_len - length_of_each_caption[-1]
            indexes = list(range(len(negative_captions)))
            random.shuffle(indexes)
            new_caption_list = [positive_caption]
            for i in indexes:
                if length_of_each_caption[i] < max_length:
                    new_caption_list.append(negative_captions[i])
                    max_length -= length_of_each_caption[i]
        else:
            new_caption_list = [positive_caption] + negative_captions
        random.shuffle(new_caption_list)

        new_caption = ''

        for i in new_caption_list:
            if i == positive_caption:
                start_position = len(new_caption)
            new_caption += i
            if not i.endswith("."):
                new_caption += "."
            new_caption += " "

        # shift the token positions the boxes are aligned to
        for index, i in enumerate(original_tokens_positive):
            original_tokens_positive[index] = [tuple(j) for j in i]
        for i in original_tokens_positive:
            for index, j in enumerate(i):
                i[index] = (j[0] + start_position, j[1] + start_position)

        return new_caption, original_tokens_positive, [(start_position, start_position + len(positive_caption))]

    def __get_negative_captions__(self, idx, negative_size=7):
        negative_captions = []
        for i in range(negative_size):
            img, anno, _, scale = super(CaptionTSV, self).__getitem__(np.random.choice(len(self)))
            caption = anno["caption"]
            negative_captions.append(caption)

        return negative_captions

    def __getitem__(self, idx):
        try:
            img, anno, _, scale = super(CaptionTSV, self).__getitem__(idx)
            if self.inference_caption:
                caption = None
                if isinstance(anno, list):
                    caption = anno[0]["caption"]  # inference mode for bing
                    anno = []
                elif len(anno) == 1:
                    caption = anno["caption"]  # inference mode for googlecc
                    anno = []
                else:
                    caption = " ".join(anno["captions"])
                    anno = []
            else:
                '''
                An example
                {'img_h': 1154, 'img_w': 1600, 'caption': 'xxx', 'tokens_positive': [[[47, 50], [51, 53], [54, 59]], [[32, 35], [36, 41]], [[32, 35], [36, 41]], [[0, 3], [3, 6], [6, 10], [11, 16], [17, 19], [20, 23]], [[32, 35], [36, 41]], [[32, 35], [36, 41]]], 'bboxes': [[7.344961166381836, 10.479412078857422, 1592.2679443359375, 1090.0028076171875], [950.32861328125, 346.572021484375, 1333.2373046875, 679.3215942382812], [927.44140625, 342.7712707519531, 1389.833984375, 719.5758666992188], [90.48786163330078, 363.67572021484375, 1381.8631591796875, 1078.687744140625], [122.84217071533203, 422.6786193847656, 507.845703125, 667.2651977539062], [80.62384033203125, 416.500244140625, 563.1666259765625, 734.603271484375]], 'scores': [0.7966700196266174, 0.8952182531356812, 0.8186006546020508, 0.9995516538619995, 0.8021856546401978, 0.8923134803771973]}
                '''
                if len(anno["bboxes"]) < self.caption_min_box:  # Retry triggered!
                    return self[np.random.choice(len(self))]

                if self.caption_format_version == "v2":
                    anno = self.convert_anno_from_v2_to_v1(anno)

                try:
                    if self.further_screen:
                        conf = self.caption_conf
                        nms_thre = self.caption_nms

                        bboxes = torch.as_tensor(anno["bboxes"]).float()
                        scores = torch.as_tensor(anno["scores"])
                        tokens_positive = anno["tokens_positive"]

                        # print("\n\n\n\n tokens_positive in original data", tokens_positive)

                        keep = scores > conf
                        scores = scores[keep]
                        bboxes = bboxes[keep]
                        tokens_positive = [i for index, i in enumerate(tokens_positive) if keep[index]]

                        assert (len(tokens_positive) == len(bboxes) == len(scores))

                        if len(bboxes) < self.caption_min_box:  # Retry triggered!
                            return self[np.random.choice(len(self))]

                        if nms_thre > 0:
                            keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre)
                            scores = scores[keep]
                            bboxes = bboxes[keep]
                            tokens_positive = [tokens_positive[i] for i in keep]
                            assert (len(tokens_positive) == len(bboxes) == len(scores))

                        # Write back
                        anno["bboxes"] = bboxes.tolist()
                        anno["scores"] = scores.tolist()
                        anno["tokens_positive"] = tokens_positive

                    boxes = torch.as_tensor(anno["bboxes"])

                    if len(boxes) < self.caption_min_box:  # Retry triggered!
                        return self[np.random.choice(len(self))]

                    target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy")
                    target = target.clip_to_image(remove_empty=True)

                    caption = anno["caption"]
                    # print("original caption", caption)
                    empty_everything = False
                    if self.sample_negative_for_grounding_data != -1:
                        if random.random() < self.sample_negative_for_grounding_data:
                            empty_everything = True

                    if empty_everything:
                        caption = self.__get_negative_captions__(idx, negative_size=1)[0]

                    if self.pack_random_caption_number != 0:
                        if self.random_pack_prob != -1.0:
                            if random.random() < self.no_random_pack_probability:
                                negative_pack_number = 0
                            elif random.random() < self.random_pack_prob:
                                negative_pack_number = self.pack_random_caption_number
                            else:
                                negative_pack_number = np.random.choice(self.pack_random_caption_number)
                        else:
                            negative_pack_number = self.pack_random_caption_number

                        negative_captions = self.__get_negative_captions__(idx, negative_size=negative_pack_number)

                        caption, anno["tokens_positive"], greenlight_span_for_masked_lm_objective = self.pack_caption(
                            caption, negative_captions, anno["tokens_positive"])
                    else:
                        greenlight_span_for_masked_lm_objective = [(0, len(caption))]

                    if not self.mlm_obj_for_only_positive:
                        greenlight_span_for_masked_lm_objective = [(0, len(caption))]
                    
                    new_anno = []
                    areas = target.area()
                    for i in range(len(target)):
                        new_anno_i = {}
                        new_anno_i["area"] = areas[i]
                        new_anno_i["iscrowd"] = 0
                        new_anno_i["image_id"] = idx
                        new_anno_i["category_id"] = 1  # following vg and others
                        new_anno_i["id"] = None
                        new_anno_i['bbox'] = target.bbox[i].numpy().tolist()
                        new_anno_i["tokens_positive"] = anno["tokens_positive"][i]
                        new_anno.append(new_anno_i)

                except:
                    return self[np.random.choice(len(self))]

                anno = new_anno
                if empty_everything:
                    anno = []

            annotations = {"image_id": idx, "annotations": anno, "caption": caption}
            annotations["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
            img, annotations = self.prepare(img, annotations, box_format="xyxy")

            if self._transforms is not None:
                img, target = self._transforms(img, target)

            # add additional property
            for ann in annotations:
                target.add_field(ann, annotations[ann])
        except:
            print("Outter Retry triggered!!")
            return self[np.random.choice(len(self))]

        sanity_check_target_after_processing(target)
        
        return img, target, idx

    def convert_anno_from_v2_to_v1(self, anno):
        flatterned_bboxes = []
        flatterned_tokens_positive = []
        flatterned_bboxes_scores = []
        for i in range(len(anno["bboxes"])):
            # i is the index for entity
            for j in range(len(anno["bboxes"][i])):
                # j is the index for each box
                flatterned_bboxes.append(anno["bboxes"][i][j])
                flatterned_tokens_positive.append(
                    anno["tokens_positive"][i])  # Assume this box corresponds to all the token_spans for this entity
                flatterned_bboxes_scores.append(anno["scores"][i][j])
        anno["bboxes"] = flatterned_bboxes
        anno["tokens_positive"] = flatterned_tokens_positive
        anno["scores"] = flatterned_bboxes_scores
        return anno


    def get_raw_image(self, idx):
        image, *_ = super(CaptionTSV, self).__getitem__(idx)
        return image

    def get_img_id(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv.seek(line_no)
            img_id = row[0]
            return img_id
