import math
import warnings
import random
import torch
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision


def load_image_in_PIL(path, mode='RGB'):
    """Read the image at ``path`` and return it converted to ``mode``.

    The file is opened in a context manager so its handle is released before
    returning. This matters for multi-frame formats, where Pillow keeps the file
    open after ``load()``, and for the case where decoding raises. ``convert``
    returns a new, fully loaded image, so the result stays usable after the
    source file is closed.
    """
    with Image.open(path) as img:
        img.load()
        return img.convert(mode)

def _border_widths(values, threshold):
    """Count leading and trailing entries of a 1-D array that fall below ``threshold``.

    Returns ``(leading, trailing)`` as Python ints. If every entry is below the
    threshold, returns ``(0, 0)``: there is no content to crop towards, and
    cropping both ends would leave nothing.
    """
    # ``~(values < threshold)`` rather than ``values >= threshold`` so a NaN mean
    # counts as content, exactly like a plain ``value < threshold`` test would.
    content = np.flatnonzero(~(values < threshold))
    if content.size == 0:
        return 0, 0
    return int(content[0]), int(len(values) - 1 - content[-1])


def remove_borders(image, borders=None, threshold=1):
    """Crop dark borders from an image.

    Args:
        image: PIL image or array-like of shape (H, W) or (H, W, C).
        borders: optional dict with keys "top", "bottom", "left" and "right"
            giving how many rows/columns to drop from each side. When given,
            detection is skipped, so a crop measured on one image can be
            re-applied to another image of the same size.
        threshold: a row/column whose mean value (taken over the other axes,
            channels included) is strictly below this is treated as border.
            Defaults to 1.

    Returns:
        (cropped image as a uint8 PIL image, borders dict used for the crop).
        When detecting, an axis on which every row (or column) is below
        ``threshold`` is left uncropped instead of being reduced to zero size.
    """
    image = np.array(image)  # (H, W) or (H, W, C)

    if borders is None:
        channel_axes = tuple(range(2, image.ndim))
        row_means = image.mean(axis=(1,) + channel_axes)     # (H,)
        column_means = image.mean(axis=(0,) + channel_axes)  # (W,)
        top, bottom = _border_widths(row_means, threshold)
        left, right = _border_widths(column_means, threshold)
        borders = {"top": top, "bottom": bottom, "left": left, "right": right}
    else:
        top, bottom = borders["top"], borders["bottom"]
        left, right = borders["left"], borders["right"]

    # ``-n or None`` turns a zero trailing border into an open-ended slice,
    # because ``x[a:-0]`` would be empty.
    image = image[top:-bottom or None, left:-right or None]

    return Image.fromarray(np.uint8(image)), borders


class To_One_Hot(object):
    """Convert an integer label mask into a stack of per-object binary masks.

    Channel 0 is the background: every pixel not covered by a selected object.
    Channels ``1..len(obj_list)`` hold one binary mask per object label, in
    ``obj_list`` order. Any remaining channels stay zero.
    """

    def __init__(self, max_obj_n, shuffle):
        # Number of output channels, background included, so at most
        # ``max_obj_n - 1`` objects are kept when labels are discovered.
        self.max_obj_n = max_obj_n
        # Whether discovered labels are randomly reordered before truncation.
        self.shuffle = shuffle

    def __call__(self, mask, obj_list=None):
        """Encode ``mask`` as a one-hot uint8 tensor.

        Args:
            mask: numpy array of labels, any shape. Only labels >= 1 are
                treated as objects when discovering them.
            obj_list: labels to encode, in channel order. When None or empty,
                the labels >= 1 present in ``mask`` are discovered in ascending
                order, shuffled if ``self.shuffle``, and truncated to
                ``max_obj_n - 1``.

        Returns:
            (uint8 tensor of shape ``(max_obj_n, *mask.shape)``, list of labels
            used, where entry ``i`` is encoded in channel ``i + 1``).

        Raises:
            IndexError: if an explicit ``obj_list`` has more than
                ``max_obj_n - 1`` labels.
        """
        new_mask = np.zeros((self.max_obj_n, *mask.shape), np.uint8)

        if obj_list is None or len(obj_list) == 0:
            # np.unique replaces a scan over range(1, mask.max() + 1). That scan
            # compared the full mask once per candidate label. It also relied on
            # ``mask.max() + 1``, which under NumPy 2 scalar promotion wraps to 0
            # for a uint8 mask containing 255, silently yielding no objects.
            labels = np.unique(mask)
            obj_list = labels[labels >= 1].tolist()

            if self.shuffle:
                random.shuffle(obj_list)
            obj_list = obj_list[:self.max_obj_n - 1]

        for channel, label in enumerate(obj_list, start=1):
            new_mask[channel] = mask == label
        # Background = pixels claimed by no object channel. Using ``any``
        # instead of ``1 - sum`` cannot underflow if labels overlap or repeat.
        new_mask[0] = ~np.any(new_mask[1:], axis=0)

        return torch.from_numpy(new_mask), obj_list

    def __repr__(self):
        return self.__class__.__name__ + '(max_obj_n={})'.format(self.max_obj_n)