import os
import glob
import time
import scipy
import random
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from skimage.io import imread, imsave

import cv2
import argparse

# Command-line configuration. Later code joins these directories to file names by
# plain string concatenation, so --PascalDir and --DTDDir must end with '/', and
# the level letter is appended directly to --SaveDir (e.g. '.../FGDatasetE').
parser = argparse.ArgumentParser(description='ToyFG dataset')
# `choices` rejects unknown levels up front; without it an unknown level skipped
# every preset branch and the script later died with a NameError.
parser.add_argument('--level', default='E', type=str, choices=['E', 'N', 'H'],
                    help='level preset; selects texture split, minimum object '
                         'area and rotation ranges')
parser.add_argument('--PascalDir', default='/xxx/VOC/SegmentationObject/', type=str,
                    help="directory of VOC instance masks (*.png); must end with '/'")
parser.add_argument('--DTDDir', default='/xxx/DTD dataset/dtd/images/', type=str,
                    help="DTD image directory with one sub-folder of .jpg textures "
                         "per entry; must end with '/'")
parser.add_argument('--SaveDir', default='/xxx/VOC/FGDataset', type=str,
                    help='output prefix; the level letter is appended directly')

args = parser.parse_args()

# Fixed colour table: 20 integer colour triplets, one row per colour.
# NOTE(review): these look like the standard Pascal VOC colours for classes 1-20
# (background omitted), in RGB order — confirm. cv2.imread yields BGR, so masks
# loaded with it would need their channels reversed before matching these rows.
palette = np.array([
    [128, 0, 0],
    [0, 128, 0],
    [128, 128, 0],
    [0, 0, 128],
    [128, 0, 128],
    [0, 128, 128],
    [128, 128, 128],
    [64, 0, 0],
    [192, 0, 0],
    [64, 128, 0],
    [192, 128, 0],
    [64, 0, 128],
    [192, 0, 128],
    [64, 128, 128],
    [192, 128, 128],
    [0, 64, 0],
    [128, 64, 0],
    [0, 192, 0],
    [128, 192, 0],
    [0, 64, 128],
])


def rotate_bound(image, angle):
    """Rotate ``image`` clockwise by ``angle`` degrees without clipping corners.

    The output canvas is enlarged to the bounding box of the rotated image and
    the affine matrix is shifted so the rotated content is centred in it.
    Pixels not covered by the source use ``cv2.warpAffine``'s default border.

    Args:
        image: array whose first two dimensions are (rows, cols).
        angle: rotation in degrees; positive values rotate clockwise.

    Returns:
        The rotated image with the enlarged (rows, cols) size.
    """
    height, width = image.shape[0], image.shape[1]
    centre_x, centre_y = width // 2, height // 2

    # OpenCV treats positive angles as counter-clockwise, hence the negation.
    rot = cv2.getRotationMatrix2D((centre_x, centre_y), -angle, 1.0)
    abs_cos = np.abs(rot[0, 0])
    abs_sin = np.abs(rot[0, 1])

    # Size of the axis-aligned box that contains the rotated image.
    out_w = int((height * abs_sin) + (width * abs_cos))
    out_h = int((height * abs_cos) + (width * abs_sin))

    # Translate so the rotation centre lands in the middle of the new canvas.
    rot[0, 2] += (out_w / 2) - centre_x
    rot[1, 2] += (out_h / 2) - centre_y

    return cv2.warpAffine(image, rot, (out_w, out_h))


def mask_to_onehot(mask, pallete):
    """Convert a colour-coded mask into a one-hot map, one channel per colour.

    Args:
        mask: array whose last axis holds a pixel's colour (presumably
            H x W x C; C must match the length of each palette entry).
        pallete: sequence of colours; output channel ``i`` marks the pixels
            equal to ``pallete[i]``. (Name kept as spelled for keyword callers.)

    Returns:
        float32 array of shape ``mask.shape[:-1] + (len(pallete),)`` holding
        0.0 / 1.0.

    Raises:
        ValueError: if ``pallete`` is empty (``np.stack`` needs one array).
    """
    semantic_map = []
    # Bug fix: this loop used to iterate the module-level ``palette`` and
    # silently ignored the ``pallete`` argument.
    for colour in pallete:
        equality = np.equal(mask, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    return semantic_map


def mask_to_ins(mask):
    """Split a colour-coded instance mask into one binary map per instance.

    Distinct colours are collected in raster-scan order of first appearance,
    and the first two are discarded. NOTE(review): this presumably drops the
    background and the VOC boundary/void colour, which usually appear first in
    raster order. It misfires if an object colour appears before either of
    them. Consider excluding those colours explicitly.

    Args:
        mask: 3-D array (rows, cols, channels) of pixel colours.

    Returns:
        float32 array of shape (rows, cols, k) with one 0/1 map per remaining
        colour, in first-appearance order. If no colours remain, a (1, 1, 3)
        float64 array of zeros is returned instead (kept for compatibility).
    """
    _, _, c = mask.shape
    pixels = mask.reshape((-1, c))

    # Unique colours in first-appearance order. np.unique sorts its output, so
    # re-order by each colour's first pixel index. This replaces the previous
    # O(pixels * colours) pure-Python scan with a vectorised one.
    if pixels.shape[0] > 0:
        _, first_idx = np.unique(pixels, axis=0, return_index=True)
        colours = pixels[np.sort(first_idx)]
    else:
        colours = pixels
    colours = colours[2:]

    if len(colours) == 0:
        return np.zeros((1, 1, 3))

    instance_map = np.stack(
        [np.all(np.equal(mask, colour), axis=-1) for colour in colours],
        axis=-1,
    )
    return instance_map.astype(np.float32)


def cut(img, crop_rows=240, crop_cols=300):
    """Return a centre crop of ``img`` of up to ``crop_rows`` x ``crop_cols``.

    Args:
        img: array whose first two axes are (rows, cols); any trailing channel
            axis is kept unchanged.
        crop_rows: height of the crop (default 240, as before).
        crop_cols: width of the crop (default 300, as before).

    Returns:
        A view of ``img`` centred on its middle pixel. Along an axis shorter
        than the requested crop, the whole axis is returned. Previously a
        negative start index silently selected the image's tail instead.
    """
    rows, cols = img.shape[:2]
    top = rows // 2 - crop_rows // 2
    left = cols // 2 - crop_cols // 2
    # Clamp the start at 0; the end may run past the edge, which slicing tolerates.
    return img[max(top, 0):top + crop_rows, max(left, 0):left + crop_cols]

# Per-level presets, unpacked into the module-level names used below:
#   split             -> index at which the DTD folder list is divided into
#                        training / testing textures
#   region            -> objects whose mask area (pixel count) is <= region are skipped
#   trainR_b, trainR_t -> inclusive range of the random rotation angle (degrees)
#                        applied to the background texture for training samples
#   testR_b, testR_t  -> the same range for testing samples
LEVEL_PRESETS = {
    #      split, region, trainR_b, trainR_t, testR_b, testR_t
    'E': (40, 10000, 1, 90, 1, 90),
    'N': (10, 1500, 30, 90, 1, 60),
    'H': (3, 0, 60, 90, 1, 30),
}
if args.level not in LEVEL_PRESETS:
    # Previously an unknown level fell through silently and crashed later
    # with a NameError; fail immediately with a usage message instead.
    parser.error('unknown --level %r; expected one of %s'
                 % (args.level, ', '.join(LEVEL_PRESETS)))
split, region, trainR_b, trainR_t, testR_b, testR_t = LEVEL_PRESETS[args.level]

root_dir = args.PascalDir            # Pascal VOC instance-mask dir (must end with '/')
root_save = args.SaveDir             # output prefix; the level letter is appended directly
root_texture_dir = args.DTDDir       # texture source (must end with '/')

label_ext = '.png'                   # label extension

save_gt_dirT = root_save + args.level + '/TrainDataset/GT/'
save_img_dirT = root_save + args.level + '/TrainDataset/Imgs/'
save_gt_dirE = root_save + args.level + '/TestDataset/ToyFG/GT/'
save_img_dirE = root_save + args.level + '/TestDataset/ToyFG/Imgs/'

# cv2.imwrite does not create missing directories (it just fails), so the
# output tree must exist before any sample is written.
for _out_dir in (save_gt_dirT, save_img_dirT, save_gt_dirE, save_img_dirE):
    os.makedirs(_out_dir, exist_ok=True)

# Sort so the train/test texture split is reproducible (os.listdir order is
# arbitrary), and ignore stray files so they do not shift the split index.
texture_list = sorted(
    d for d in os.listdir(root_texture_dir)
    if os.path.isdir(os.path.join(root_texture_dir, d))
)

print('training used texture:')
print(texture_list[:split])
print('testing used texture:')
print(texture_list[split:])

# Flat lists of .jpg paths. The split now matches the lists printed above;
# previously `k <= split` also put folder `split` into the training pool.
texture_list_T = [f for t in texture_list[:split]
                  for f in glob.glob(root_texture_dir + t + '/*' + '.jpg')]
texture_list_E = [f for t in texture_list[split:]
                  for f in glob.glob(root_texture_dir + t + '/*' + '.jpg')]

NUM_TRAIN = 2000   # samples 0 .. NUM_TRAIN-1 go to the training split
NUM_TOTAL = 2500   # samples NUM_TRAIN .. NUM_TOTAL-1 go to the testing split


def _synthesize_sample(obj_mask, texture_pool, ang_lo, ang_hi, img_path, gt_path):
    """Render one sample from a binary object mask and write image + GT.

    The object region shows a random texture from ``texture_pool`` resized to
    the mask. The background shows the same texture rotated by a random angle
    in [ang_lo, ang_hi] degrees, centre-cropped and resized back. The GT is
    the mask as a 3-channel 0/255 image.

    Raises:
        IOError: if the texture cannot be read or a file cannot be written.
    """
    texture_path = random.choice(texture_pool)
    texture = cv2.imread(texture_path)
    if texture is None:
        raise IOError('cannot read texture: ' + texture_path)
    gt = np.repeat(obj_mask[:, :, np.newaxis], 3, axis=2)
    rows, cols = gt.shape[:2]
    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    fg = cv2.resize(texture, (cols, rows), interpolation=cv2.INTER_CUBIC)
    bg = rotate_bound(fg, random.randint(ang_lo, ang_hi))
    bg = cut(bg)
    bg = cv2.resize(bg, (cols, rows), interpolation=cv2.INTER_CUBIC)
    img = fg * gt + bg * (1 - gt)
    # gt is 0/1, so both arrays hold exact integers; cast explicitly to 8-bit.
    for path, data in ((img_path, img), (gt_path, gt * 255)):
        if not cv2.imwrite(path, data.astype(np.uint8)):
            raise IOError('cv2.imwrite failed for ' + path)


# Sorted for a reproducible sample order (glob order is arbitrary).
label_list = sorted(glob.glob(root_dir + '*' + label_ext))

c = 0  # number of samples written so far
t = time.time()
for k, label in enumerate(label_list):
    if c >= NUM_TOTAL:
        break  # quota reached; skip decoding the remaining masks
    print("[ %.2f %%]" % (k * 100 / len(label_list)), end="\r", flush=True)
    lab = cv2.imread(label)
    if lab is None:
        raise IOError('cannot read label mask: ' + label)
    SM = mask_to_ins(lab)
    for i in range(SM.shape[2]):
        if c >= NUM_TOTAL:
            break
        temp = SM[:, :, i]
        if temp.sum() <= region:
            continue  # object area too small for this level
        if c < NUM_TRAIN:
            _synthesize_sample(temp, texture_list_T, trainR_b, trainR_t,
                               save_img_dirT + str(c) + '.jpg',
                               save_gt_dirT + str(c) + '.png')
        else:
            # Previously `c > 2000` left sample 2000 in neither split.
            _synthesize_sample(temp, texture_list_E, testR_b, testR_t,
                               save_img_dirE + str(c) + '.jpg',
                               save_gt_dirE + str(c) + '.png')
        c = c + 1

print(c)
print(time.time() - t)
