import numpy as np

from tensorflow.keras import backend as K

import scipy.ndimage as ndi
import random


def transform_matrix_offset_center(matrix, x, y):
    """Re-express a 3x3 homogeneous transform about the image centre.

    The result is ``T(c) @ matrix @ T(-c)``, where ``T`` is a translation and
    ``c = (x / 2 + 0.5, y / 2 + 0.5)``. Rotation/shear/zoom therefore pivot
    around the centre of an ``x`` by ``y`` image instead of the origin.

    Args:
        matrix: 3x3 homogeneous transform.
        x: Extent of the first image axis.
        y: Extent of the second image axis.

    Returns:
        The re-centred 3x3 transform as a numpy array.
    """
    cx = float(x) * 0.5 + 0.5
    cy = float(y) * 0.5 + 0.5
    to_center = np.array([[1, 0, cx],
                          [0, 1, cy],
                          [0, 0, 1]])
    from_center = np.array([[1, 0, -cx],
                            [0, 1, -cy],
                            [0, 0, 1]])
    return to_center @ matrix @ from_center

def apply_transform(x, transform_matrix, channel_index=0, fill_mode='nearest', cval=0.):
    """Apply a 3x3 homogeneous affine transform to every channel of an image.

    Each 2-D channel slice is resampled with ``scipy.ndimage.affine_transform``
    using nearest-neighbour interpolation (``order=0``). As in SciPy, the matrix
    maps *output* coordinates to *input* coordinates.

    Args:
        x: Image array with one channel axis.
        transform_matrix: 3x3 homogeneous matrix. Its top-left 2x2 block is the
            linear part and its first two rows of the last column are the offset.
        channel_index: Axis of ``x`` holding channels; negative values are allowed.
        fill_mode: SciPy boundary mode for points outside the input.
        cval: Fill value used when ``fill_mode == 'constant'``.

    Returns:
        The transformed image, with the same axis layout as ``x``.
    """
    # Iterate over channels by moving the channel axis to the front.
    channels = np.moveaxis(x, channel_index, 0)
    final_affine_matrix = transform_matrix[:2, :2]
    final_offset = transform_matrix[:2, 2]
    # Use the public ``ndi.affine_transform``; the ``ndi.interpolation``
    # namespace is deprecated.
    channel_images = [
        ndi.affine_transform(channel, final_affine_matrix, final_offset,
                             order=0, mode=fill_mode, cval=cval)
        for channel in channels
    ]
    stacked = np.stack(channel_images, axis=0)
    # moveaxis (unlike rollaxis with ``channel_index + 1``) restores the layout
    # correctly for negative channel indices too.
    return np.moveaxis(stacked, 0, channel_index)

def add_gaussian_noise(img, level):
    """Return ``img`` plus random non-negative noise.

    The noise standard deviation is ``level * u`` with ``u ~ U(0, 1)``, drawn
    once per call. The noise is the absolute value of zero-mean Gaussian samples,
    i.e. half-normal, so it only ever increases pixel values.
    NOTE(review): confirm the one-sided noise is intentional. The result is
    not clipped to any range.

    Args:
        img: Image array; it is not modified.
        level: Upper bound of the noise standard deviation.

    Returns:
        A new array with the same shape as ``img``.
    """
    sigma = level * np.random.uniform(0, 1)
    half_normal = np.abs(np.random.normal(0, sigma, img.shape))
    return img + half_normal

def change_contrast(img, level):
    """Randomly rescale intensities, with an independent factor per channel.

    Each factor is ``1 + level - 2 * level * u`` with ``u ~ U(0, 1)``, i.e. it
    lies in ``(1 - level, 1 + level]``. For inputs with at least 3 dimensions,
    one factor is drawn per index of axis 2 (assumed to be the channel axis,
    channels-last). Inputs with fewer than 3 dimensions (scalars, grayscale
    2-D images) get a single factor.

    Args:
        img: Image array; it is not modified.
        level: Maximum relative change of the intensity scale.

    Returns:
        A new scaled array. Floating/complex inputs keep their dtype; other
        dtypes (e.g. uint8) are promoted to float64, because an in-place float
        multiply on them is a casting error.
    """
    img = np.asarray(img)
    if img.ndim < 3:
        factor = 1 + level - 2 * level * np.random.uniform(0, 1)
        return img * factor

    out_dtype = img.dtype if np.issubdtype(img.dtype, np.inexact) else np.float64
    # Work on a copy so the caller's array (possibly a dataset view) is untouched.
    out = np.array(img, dtype=out_dtype, copy=True)
    for channel in range(out.shape[2]):
        factor = 1 + level - 2 * level * np.random.uniform(0, 1)
        out[:, :, channel] *= factor
    return out
def add_salt_and_pepper(img, level):
    """Set randomly chosen pixels to 1 ("salt" only; no dark "pepper" pixels).

    ``ceil(level * img.size * u)`` positions (``u ~ U(0, 1)``) are sampled
    uniformly over the first two axes, with replacement. Every trailing axis
    (e.g. every channel) of a chosen pixel is set to 1.
    NOTE(review): the value 1 assumes images normalised to [0, 1]; confirm
    for integer images.

    Args:
        img: Image array; it is not modified.
        level: Upper bound of the salted fraction, relative to ``img.size``.

    Returns:
        A salted copy of ``img``.
    """
    img = np.array(img, copy=True)
    num_salt = int(np.ceil(level * img.size * np.random.uniform(0, 1)))
    # ``randint``'s upper bound is exclusive, so pass the full extent. The
    # previous ``dim - 1`` never reached the last row/column and raised for
    # dimensions of size 1.
    coords = tuple(np.random.randint(0, dim, num_salt) for dim in img.shape[:2])
    img[coords] = 1
    return img

def translate(img, level, fill_mode):
    """Shift an image by a random non-negative amount along its first two axes.

    Both shifts are drawn independently with ``random.randint(0, nb)``, where
    ``nb = int(max(img.shape[0], img.shape[1]) * level)``. Any trailing axes
    (e.g. channels) are not shifted.

    Args:
        img: Image array with at least 2 dimensions.
        level: Maximum shift as a fraction of the larger of the first two axes.
        fill_mode: SciPy boundary mode passed to ``ndi.shift``.

    Returns:
        The shifted image (``ndi.shift`` with its default spline order).
    """
    max_shift = int(max(img.shape[0], img.shape[1]) * level)
    dx = random.randint(0, max_shift)
    dy = random.randint(0, max_shift)
    # Pad with zeros so 2-D (grayscale) and 3-D+ images both match the
    # input rank that ndi.shift requires.
    shift = (dx, dy) + (0,) * (img.ndim - 2)
    return ndi.shift(img, shift, mode=fill_mode)
           
def random_crop_img(img, size_crop):
    """Zero out a random ``size_crop`` x ``size_crop`` square (cutout-style).

    Despite the name, the image keeps its shape; a square patch over the first
    two axes (all trailing axes included) is set to 0.

    Args:
        img: Image array with at least 2 dimensions; it is not modified.
        size_crop: Side length of the square; must not exceed either of the
            first two dimensions (``np.random.randint`` raises ``ValueError``
            otherwise).

    Returns:
        A copy of ``img`` with the patch zeroed.
    """
    img = np.array(img, copy=True)
    # ``randint``'s upper bound is exclusive; the +1 lets the patch reach the
    # bottom/right edge and allows ``size_crop`` equal to the image side.
    x = np.random.randint(0, img.shape[0] - size_crop + 1)
    y = np.random.randint(0, img.shape[1] - size_crop + 1)
    # ``...`` covers any trailing axes, so 2-D grayscale input also works.
    img[x:x + size_crop, y:y + size_crop, ...] = 0.
    return img
class ImageTransformer:
    """Random geometric and pixel-level augmentation of a single image.

    Geometric parameters (shift, rotation, shear, zoom) are composed into one
    3x3 affine matrix and applied with ``apply_transform``. Pixel-level
    operations (salt noise, contrast, Gaussian noise, square cutout) are then
    applied with the module-level helpers. The axis layout follows
    ``K.image_data_format()``.
    """

    def __init__(self,
                 rotation_range=0.,
                 shear_range=0.,
                 zoom_range=0.,
                 fill_mode='reflect',
                 contrast_level=0,
                 gaussian_noise=0,
                 salt_and_pepper=0,
                 height_shift_range=0,
                 width_shift_range=0,
                 random_crop=0,
                 flip_horizontal=False,
                 cval=0.):
        """Store augmentation settings; a value of 0/False disables an operation.

        Args:
            rotation_range: Maximum rotation in degrees; the angle is drawn
                uniformly from [-rotation_range, rotation_range].
            shear_range: Maximum shear; used directly inside sin/cos, i.e. in
                radians (unlike ``rotation_range``).
            zoom_range: Scalar ``z`` meaning [1 - z, 1 + z], or an explicit
                ``[low, high]`` pair. One factor is drawn and used for both axes.
            fill_mode: Boundary mode passed to ``apply_transform``.
            contrast_level: ``level`` for ``change_contrast``.
            gaussian_noise: ``level`` for ``add_gaussian_noise``.
            salt_and_pepper: ``level`` for ``add_salt_and_pepper``.
            height_shift_range: Maximum shift along the row axis, as a
                fraction of the number of rows.
            width_shift_range: Maximum shift along the column axis, as a
                fraction of the number of columns.
            random_crop: Side length passed to ``random_crop_img``.
            flip_horizontal: If True, flip along the column axis with
                probability 0.5.
            cval: Fill value passed to ``apply_transform``.
        """
        self.data_format = K.image_data_format()
        # Axis indices refer to a batched (4-D) tensor; random_transform
        # subtracts 1 to get the axes of a single image.
        if self.data_format == 'channels_first':
            self.channel_axis = 1
            self.row_axis = 2
            self.col_axis = 3
        elif self.data_format == 'channels_last':
            self.channel_axis = 3
            self.row_axis = 1
            self.col_axis = 2
        self.flip_horizontal = flip_horizontal
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.rotation_range = rotation_range
        self.random_crop = random_crop
        self.shear_range = shear_range
        self.zoom_range = zoom_range
        if np.isscalar(zoom_range):
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        self.fill_mode = fill_mode
        self.cval = cval
        self.contrast_level = contrast_level
        self.gaussian_noise = gaussian_noise
        self.salt_and_pepper = salt_and_pepper

    def _random_affine_matrix(self, h, w):
        """Draw random shift/rotation/shear/zoom and compose them.

        Random draws happen in the fixed order shift-rows, shift-cols, rotation,
        shear, zoom, so results are reproducible under a fixed numpy seed.

        Args:
            h: Number of rows of the image.
            w: Number of columns of the image.

        Returns:
            The composed 3x3 matrix (not yet centred), or None when every drawn
            component is the identity.
        """
        if self.height_shift_range:
            tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * h
        else:
            tx = 0
        if self.width_shift_range:
            ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * w
        else:
            ty = 0
        if self.rotation_range:
            theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range)
        else:
            theta = 0
        if self.shear_range:
            shear = np.random.uniform(-self.shear_range, self.shear_range)
        else:
            shear = 0
        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx = zy = 1
        else:
            # Two values are drawn but only the first is used (isotropic zoom);
            # the second draw is kept so the random stream stays unchanged.
            zx = zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)[0]

        matrices = []
        if theta != 0:
            matrices.append(np.array([[np.cos(theta), -np.sin(theta), 0],
                                      [np.sin(theta), np.cos(theta), 0],
                                      [0, 0, 1]]))
        if tx != 0 or ty != 0:
            matrices.append(np.array([[1, 0, tx],
                                      [0, 1, ty],
                                      [0, 0, 1]]))
        if shear != 0:
            matrices.append(np.array([[1, -np.sin(shear), 0],
                                      [0, np.cos(shear), 0],
                                      [0, 0, 1]]))
        if zx != 1 or zy != 1:
            matrices.append(np.array([[zx, 0, 0],
                                      [0, zy, 0],
                                      [0, 0, 1]]))

        if not matrices:
            return None
        # Multiply left-to-right: rotation . shift . shear . zoom.
        transform_matrix = matrices[0]
        for matrix in matrices[1:]:
            transform_matrix = np.dot(transform_matrix, matrix)
        return transform_matrix

    def random_transform(self, x):
        """Return a randomly augmented copy of a single image.

        Args:
            x: One image (no batch axis), laid out per ``self.data_format``.
                It is never modified.

        Returns:
            The augmented image, in the same layout as ``x``.
        """
        img_row_axis = self.row_axis - 1
        img_col_axis = self.col_axis - 1
        img_channel_axis = self.channel_axis - 1

        h, w = x.shape[img_row_axis], x.shape[img_col_axis]
        transform_matrix = self._random_affine_matrix(h, w)
        if transform_matrix is not None:
            transform_matrix = transform_matrix_offset_center(transform_matrix, h, w)
            x = apply_transform(x, transform_matrix, img_channel_axis,
                                fill_mode=self.fill_mode, cval=self.cval)

        if self.flip_horizontal and random.choice([True, False]):
            # Flip the column axis explicitly: np.fliplr always flips axis 1,
            # which is the row axis (a vertical flip) for channels_first.
            x = np.flip(x, axis=img_col_axis)

        # The pixel-level helpers may write into their argument, and x may still
        # be (a view of) the caller's array, so work on a private copy.
        x = np.array(x, copy=True)

        # The helpers index rows/cols as axes 0/1 and channels as axis 2, so give
        # them a channels-last view for channels_first input.
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            x = np.moveaxis(x, img_channel_axis, -1)

        if self.salt_and_pepper != 0:
            x = add_salt_and_pepper(x, self.salt_and_pepper)
        if self.contrast_level != 0:
            x = change_contrast(x, self.contrast_level)
        if self.gaussian_noise != 0:
            x = add_gaussian_noise(x, self.gaussian_noise)
        if self.random_crop != 0:
            x = random_crop_img(x, self.random_crop)

        if channels_first:
            x = np.moveaxis(x, -1, img_channel_axis)
        return x