# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color
import numpy as np
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
#==========================dataset load==========================
import scipy.io as sio
import random
import skimage
import skimage.filters.rank as sfr


class Randomflip(object):
    """Randomly mirror a sample left-to-right.

    The sample is a dict with 'image', 'label' and 'edge' arrays laid out as
    H x W x C (the layout ``CObjDataset`` produces). All three are flipped
    together so they stay aligned.

    Args:
        p: probability of applying the flip (default 0.5).
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        image, label, edge = sample['image'], sample['label'], sample['edge']

        if random.random() < self.p:
            # Axis 1 is the width axis of an H x W x C array. Reversing axis 2
            # would only permute colour channels (and do nothing to 1-channel
            # masks). .copy() gives positive strides, which torch.from_numpy
            # needs further down the pipeline.
            image = image[:, ::-1].copy()
            label = label[:, ::-1].copy()
            edge = edge[:, ::-1].copy()

        return {'image': image, 'label': label, 'edge': edge}

class RescaleT(object):
    """Resize the image, label and edge maps of a sample to a fixed size.

    Args:
        output_size: an int ``s`` resizes to ``s x s`` (the aspect ratio is
            NOT preserved); a ``(height, width)`` tuple resizes to exactly
            that shape.

    The image uses ``skimage.transform.resize`` defaults, so an integer image
    in [0, 255] comes back as floats in [0, 1]. Label and edge use
    nearest-neighbour interpolation (``order=0``) with ``preserve_range=True``
    so their values are not rescaled or blended.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, tuple):
            assert len(output_size) == 2
        self.output_size = output_size

    def __call__(self, sample):
        image, label, edge = sample['image'], sample['label'], sample['edge']

        if isinstance(self.output_size, int):
            # Square output: existing callers rely on int meaning s x s.
            target = (self.output_size, self.output_size)
        else:
            target = (int(self.output_size[0]), int(self.output_size[1]))

        img = transform.resize(image, target, mode='constant')
        lbl = transform.resize(label, target, mode='constant', order=0, preserve_range=True)
        edg = transform.resize(edge, target, mode='constant', order=0, preserve_range=True)

        return {'image': img, 'label': lbl, 'edge': edg}

class RandomCrop(object):
    """Cut the same random window out of a sample's image, label and edge.

    Args:
        output_size: int ``s`` for an ``s x s`` crop, or a ``(height, width)``
            tuple.

    Raises:
        ValueError: (on call) if the crop is larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, label, edge = sample['image'], sample['label'], sample['edge']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size (%d, %d) exceeds image size (%d, %d)' % (new_h, new_w, h, w))

        # randint's upper bound is exclusive: +1 makes the bottom/right-most
        # offset reachable and allows a crop equal to the image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        rows = slice(top, top + new_h)
        cols = slice(left, left + new_w)
        return {'image': image[rows, cols],
                'label': label[rows, cols],
                'edge': edge[rows, cols]}

class ToTensorLab(object):
    """Convert an H x W x C numpy sample into channel-first torch tensors.

    Args:
        flag: colour encoding of the output image.
            0 -- RGB: divide by the image maximum, then normalise channels
                 with mean (0.485, 0.456, 0.406) and std (0.229, 0.224,
                 0.225). A single-channel image is replicated to 3 channels
                 using the first mean/std pair for all of them.
            1 -- CIE Lab (``skimage.color.rgb2lab``): each channel min-max
                 scaled to [0, 1] and then z-scored (3 channels).
            2 -- RGB followed by Lab, each channel min-max scaled then
                 z-scored (6 channels).
        state: ``'FG'`` additionally emits two 2 x H x W float32 boundary
            maps computed from the label ('order' and 'conx'); any other
            value emits only 'image' and 'label'.

    The label is divided by its maximum unless that maximum is below 1e-6.
    The sample's 'edge' entry is not used and is dropped from the output.
    All-zero images and constant channels yield finite values (no NaN).
    """

    # Per-channel mean/std applied when flag == 0.
    _RGB_MEAN = (0.485, 0.456, 0.406)
    _RGB_STD = (0.229, 0.224, 0.225)

    def __init__(self, flag=0, state='FG'):
        self.flag = flag
        self.state = state

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # Scale the label to [0, 1]; an (almost) empty label is left as-is to
        # avoid dividing by ~0.
        label_peak = np.max(label)
        if label_peak >= 1e-6:
            label = label / label_peak

        if self.flag == 2:
            tmp_img = self._rgb_lab_image(image)
        elif self.flag == 1:
            tmp_img = self._lab_image(image)
        else:
            tmp_img = self._rgb_image(image)

        result = {
            'image': torch.from_numpy(tmp_img.transpose((2, 0, 1))),
            'label': torch.from_numpy(label.transpose((2, 0, 1))),
        }
        if self.state == 'FG':
            order, conx = self._boundary_maps(label[:, :, 0])
            result['order'] = torch.from_numpy(order)
            result['conx'] = torch.from_numpy(conx)
        return result

    @staticmethod
    def _boundary_maps(mask):
        """Return (order, conx), two 2 x H x W float32 maps derived from ``mask``.

        order[0] = mask - erosion(mask, disk(5))   (band just inside the edge)
        order[1] = dilation(mask, disk(5)) - mask  (band just outside the edge)
        conx[0]  = closing(mask, disk(10)) - mask  (what closing fills in)
        conx[1]  = mask - opening(mask, disk(10))  (what opening removes)
        """
        near = skimage.morphology.disk(5)
        order = np.zeros((2,) + mask.shape, dtype=np.float32)
        order[0] = mask - skimage.morphology.erosion(mask, near)
        order[1] = skimage.morphology.dilation(mask, near) - mask

        far = skimage.morphology.disk(10)
        conx = np.zeros((2,) + mask.shape, dtype=np.float32)
        conx[0] = skimage.morphology.closing(mask, far) - mask
        conx[1] = mask - skimage.morphology.opening(mask, far)
        return order, conx

    @staticmethod
    def _as_rgb(image):
        """Return ``image`` as-is, or a float H x W x 3 copy if it has one channel."""
        if image.shape[2] != 1:
            return image
        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        rgb[:, :, :] = image[:, :, :1]
        return rgb

    @staticmethod
    def _fill_standardized(dst, index, channel):
        """Store ``channel`` into ``dst[:, :, index]`` min-max scaled, then z-scored.

        Both steps write into ``dst`` so the arithmetic uses dst's dtype. A
        constant channel (zero range) becomes all zeros instead of NaN.
        """
        low = np.min(channel)
        span = np.max(channel) - low
        if span == 0:
            dst[:, :, index] = 0
            return
        dst[:, :, index] = (channel - low) / span
        scaled = dst[:, :, index]
        dst[:, :, index] = (scaled - np.mean(scaled)) / np.std(scaled)

    def _rgb_image(self, image):
        """flag == 0: max-scale, then per-channel (x - mean) / std; H x W x 3."""
        peak = np.max(image)
        if peak != 0:
            # An all-zero image would otherwise turn into 0/0 = NaN.
            image = image / peak
        if image.shape[2] == 1:
            # Grayscale: each output channel uses the first channel's stats.
            plan = [(0, self._RGB_MEAN[0], self._RGB_STD[0])] * 3
        else:
            plan = list(zip(range(3), self._RGB_MEAN, self._RGB_STD))
        out = np.zeros((image.shape[0], image.shape[1], 3))
        for dst, (src, mean, std) in enumerate(plan):
            out[:, :, dst] = (image[:, :, src] - mean) / std
        return out

    def _lab_image(self, image):
        """flag == 1: Lab conversion, each channel min-max scaled then z-scored."""
        lab = color.rgb2lab(self._as_rgb(image))
        for c in range(3):
            self._fill_standardized(lab, c, lab[:, :, c])
        return lab

    def _rgb_lab_image(self, image):
        """flag == 2: RGB channels 0-2 plus Lab channels 3-5, each standardised."""
        rgb = self._as_rgb(image)
        lab = color.rgb2lab(rgb)
        out = np.zeros((image.shape[0], image.shape[1], 6))
        for c in range(3):
            self._fill_standardized(out, c, rgb[:, :, c])
            self._fill_standardized(out, c + 3, lab[:, :, c])
        return out

class CObjDataset(Dataset):
    """Dataset yielding ``{'image', 'label', 'edge'}`` samples read from disk.

    Args:
        img_name_list: image file paths; their count is the dataset length.
        lbl_name_list: label file paths aligned with ``img_name_list``, or an
            empty list to use all-zero labels.
        edg_name_list: edge file paths aligned likewise, or empty for
            all-zero edges.
        transform: optional callable applied to the sample dict.

    Files are read with ``skimage.io.imread``. Multi-channel label/edge files
    keep only their first channel. When image, label and edge are 2-D/3-D as
    expected, all three are returned H x W x C (masks get C == 1).
    """

    def __init__(self, img_name_list, lbl_name_list, edg_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.edg_name_list = edg_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    @staticmethod
    def _read_mask(name_list, idx, image_shape):
        """Load mask ``idx`` from ``name_list`` as a 2-D array.

        An empty ``name_list`` yields zeros with the image's height/width.
        """
        # len() rather than truthiness so numpy arrays of paths also work.
        if len(name_list) == 0:
            mask = np.zeros(image_shape)
        else:
            mask = io.imread(name_list[idx])

        if mask.ndim == 3:
            return mask[:, :, 0]
        if mask.ndim == 2:
            return mask
        # NOTE(review): unexpected rank; kept the historical zero fallback.
        return np.zeros(mask.shape[0:2])

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])
        label = self._read_mask(self.label_name_list, idx, image.shape)
        edge = self._read_mask(self.edg_name_list, idx, image.shape)

        # Give every array a trailing channel axis (H x W x C).
        if label.ndim == 2 and edge.ndim == 2 and image.ndim in (2, 3):
            if image.ndim == 2:
                image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]
            edge = edge[:, :, np.newaxis]

        sample = {'image': image, 'label': label, 'edge': edge}

        if self.transform:
            sample = self.transform(sample)

        return sample
   
