import torch
import torch.nn.functional as F
import numpy as np
import random

def average_missing_pixels(masked_image, mask, window_size=(4, 4)):
    """Fill the missing pixels of an image with the average of nearby observed pixels.

    Args:
        masked_image: tensor of shape (batch_size, num_channels, height, width).
            The neighborhood sum runs over every pixel in the window, so missing
            pixels only drop out of the average if they are already zero here.
        mask: tensor that can be assigned into the image shape; 1 marks an
            observed pixel and 0 a missing one.
        window_size: (window_height, window_width) of the averaging window.
    Returns:
        Tensor with the same shape as masked_image. Observed pixels are copied
        unchanged; missing pixels receive (sum of the window) / (number of
        observed pixels in the window), with the divisor clamped to at least 1.
    """
    window_size = tuple(window_size)
    pad_h, pad_w = window_size[0] // 2, window_size[1] // 2
    num_channels = masked_image.shape[1]
    height, width = masked_image.shape[-2], masked_image.shape[-1]

    padded_image = F.pad(masked_image, (pad_w, pad_w, pad_h, pad_h), mode='constant')
    # The padded border stays 0 in the mask so it never counts as observed.
    padded_mask = torch.zeros_like(padded_image)
    padded_mask[:, :, pad_h: padded_mask.shape[2] - pad_h, pad_w: padded_mask.shape[3] - pad_w] = mask

    # One all-ones filter per channel (depthwise convolution) yields window sums.
    kernel = torch.ones((num_channels, 1) + window_size, device=masked_image.device, dtype=masked_image.dtype)
    neighborhood_sum = F.conv2d(padded_image, kernel, stride=1, padding='same', groups=num_channels)
    neighborhood_mask_sum = F.conv2d(padded_mask, kernel, stride=1, padding='same', groups=num_channels)
    # pointwise maximum with 1 to avoid division by zero
    neighborhood_mask_sum = torch.max(torch.ones_like(neighborhood_mask_sum), neighborhood_mask_sum)

    # Calculate the average of the neighborhood pixels
    avg_tensor = neighborhood_sum / neighborhood_mask_sum

    # Keep observed pixels as they are and use the average only where pixels are missing.
    avg_tensor = avg_tensor * (1 - padded_mask) + padded_image * padded_mask
    # Remove the padding, using the vertical/horizontal pad and size for their own axis.
    filled_image = avg_tensor[:, :, pad_h: pad_h + height, pad_w: pad_w + width]
    return filled_image


def get_random_mask(image_shape, survival_probability, mask_full_rgb=False, same_for_all_batch=False, device='cuda', seed=None):
    """Sample a random binary mask where each entry is 1 with probability survival_probability.

    Args:
        image_shape: (batch_size, num_channels, height, width)
        survival_probability: probability of an entry being 1 (unmasked)
        mask_full_rgb: if True, channel 0 of the sampled mask is copied to 3 channels
        same_for_all_batch: if True, one mask of shape image_shape[1:] is sampled
            and repeated for every image in the batch
        device: device to use for the mask
        seed: if not None, it is passed to np.random.seed before sampling; note
            that this reseeds NumPy's global random state
    Returns:
        float32 mask tensor on the requested device
    """
    if seed is not None:
        np.random.seed(seed)

    # Draw either one mask per image or a single mask shared by the whole batch.
    sample_shape = image_shape[1:] if same_for_all_batch else image_shape
    sampled = np.random.binomial(1, survival_probability, size=sample_shape).astype(np.float32)
    mask = torch.tensor(sampled, device=device, dtype=torch.float32)
    if same_for_all_batch:
        mask = mask.repeat([image_shape[0], 1, 1, 1])

    if not mask_full_rgb:
        return mask

    # Broadcast channel 0 to three channels: (3, B, H, W) -> (B, 3, H, W).
    first_channel = mask[:, 0]
    return first_channel.repeat([3, 1, 1, 1]).transpose(1, 0)

def get_mask_subset(mask, survival_probability, same_for_all_batch=False, mask_full_rgb=False):
    """
        Returns a mask that only keeps int(survival_probability * N) of the pixels that are
        unmasked (== 1) in the input mask, chosen uniformly at random with random.sample.
        N is the number of entries of the mask after the optional reductions below.

        Args:
            mask: (batch_size, num_channels, height, width) mask, 1 = unmasked
            survival_probability: fraction of entries that should stay unmasked
            same_for_all_batch: if True, the subset is drawn on mask[0] and repeated for the batch
            mask_full_rgb: if True, the subset is drawn on channel 0 and copied to 3 channels
        Returns:
            The subset mask. The input mask itself is returned, untouched, when it
            already has fewer unmasked pixels than requested.
    """
    original_mask = mask
    # if we want more pixels than what we already have, just return what we already have.
    if survival_probability * mask.numel() > mask.sum():
        return mask

    batch_size = mask.shape[0]

    # The input is only read from here on, so plain indexing (no clone) is enough.
    if mask_full_rgb:
        mask = mask[:, 0]

    if same_for_all_batch:
        mask = mask[0]

    # reshape (unlike view) also accepts non-contiguous masks, e.g. transposed ones.
    flattened_mask = mask.reshape(-1)
    num_on_locations = int(survival_probability * flattened_mask.numel())
    unmasked_indices = torch.nonzero(flattened_mask == 1).reshape(-1).tolist()

    # The reduced mask can hold fewer unmasked pixels than the check above saw.
    if num_on_locations > len(unmasked_indices):
        return original_mask

    chosen = random.sample(unmasked_indices, num_on_locations)
    # An explicit long dtype keeps an empty selection usable as an index tensor.
    indices_to_keep = torch.tensor(chosen, dtype=torch.long, device=mask.device)

    subset_mask = torch.zeros_like(flattened_mask)
    subset_mask[indices_to_keep] = 1.0
    subset_mask = subset_mask.reshape(mask.shape)

    if same_for_all_batch:
        subset_mask = subset_mask.repeat(batch_size, *([1] * subset_mask.dim()))

    if mask_full_rgb:
        # (3, B, H, W) -> (B, 3, H, W)
        subset_mask = subset_mask.repeat(3, 1, 1, 1).transpose(1, 0)

    return subset_mask


def refresh_mask(mask_1, mask_2, refresh_rate, mask_full_rgb=False, same_for_all_batch=False, replace=True):
    """
        Replace refresh_rate of the unmasked_pixels in mask_1 with masked_pixels from mask_2.

        Args:
            mask_1: (batch_size, num_channels, height, width) mask to refresh, 1 = unmasked
            mask_2: mask of the same shape; positions where it is 0 are the candidates to unmask
            refresh_rate: fraction of the entries of the (reduced) mask_1 to modify
            mask_full_rgb: if True, work on channel 0 and copy the result to 3 channels
            same_for_all_batch: if True, work on the first image and repeat the result for the batch
            replace: if True, the same number of previously unmasked pixels of mask_1 is masked
        Returns:
            (refreshed_mask, num_pixels_modified), or (None, 0) when no pixel can be modified.
            The input tensors are never modified.
    """
    batch_size = mask_1.shape[0]

    if mask_full_rgb:
        mask_1 = mask_1[:, 0]
        mask_2 = mask_2[:, 0]

    if same_for_all_batch:
        mask_1 = mask_1[0]
        mask_2 = mask_2[0]

    # Always edit a private copy so the caller's mask_1 is never modified in place,
    # whichever flags are set. reshape also accepts non-contiguous inputs.
    flattened_mask_1 = mask_1.reshape(-1).clone()
    flattened_mask_2 = mask_2.reshape(-1)

    masked_indices = torch.nonzero(flattened_mask_2 == 0).reshape(-1).tolist()
    unmasked_indices_in_subset_mask = torch.nonzero(flattened_mask_1 == 1).reshape(-1).tolist()

    # Cannot swap more pixels than are available on either side.
    num_pixels_to_modify = min(
        int(refresh_rate * flattened_mask_1.numel()),
        len(masked_indices),
        len(unmasked_indices_in_subset_mask),
    )

    if num_pixels_to_modify == 0:
        return None, 0

    # mask refresh_rate pixels that were previously unmasked
    selected_indices_to_mask = torch.tensor(
        random.sample(unmasked_indices_in_subset_mask, k=num_pixels_to_modify),
        dtype=torch.long, device=mask_1.device)
    # unmask refresh_rate pixels that were previously masked
    selected_indices_to_unmask = torch.tensor(
        random.sample(masked_indices, k=num_pixels_to_modify),
        dtype=torch.long, device=mask_1.device)
    if replace:
        flattened_mask_1[selected_indices_to_mask] = 0.0
    flattened_mask_1[selected_indices_to_unmask] = 1.0

    subset_mask = flattened_mask_1.reshape(mask_1.shape)

    if same_for_all_batch:
        subset_mask = subset_mask.repeat(batch_size, *([1] * subset_mask.dim()))

    if mask_full_rgb:
        # (3, B, H, W) -> (B, 3, H, W)
        subset_mask = subset_mask.repeat(3, 1, 1, 1).transpose(1, 0)

    return subset_mask, num_pixels_to_modify



def get_grid_positions(mask, comp_value=0):
    """
    Takes as input a PyTorch tensor mask (num_channels, image_size, image_size), flattens it, checks
    in which indices the mask == comp_value and returns their grid positions.

    Returns:
        Integer tensor of shape (num_matches, 2) on the device of mask. Each row is
        (row_idx, col_idx), listed in flattened-index order. The channel index is dropped,
        so a pixel that matches in several channels appears once per channel. The shape is
        (0, 2) when nothing matches.
    """
    height, width = mask.shape[1], mask.shape[2]
    # reshape(-1) keeps the result 1-D even when there are zero or one matches.
    flat_indices = torch.nonzero(mask.reshape(-1) == comp_value).reshape(-1)
    # Strip the channel offset, then split the remainder into row and column.
    within_channel = flat_indices % (height * width)
    row_idx = within_channel // width
    col_idx = within_channel % width
    return torch.stack((row_idx, col_idx), dim=1)




def refresh_mask_set_distance(mask_1, mask_2, refresh_rate, mask_full_rgb=False, same_for_all_batch=False, replace=True):
    """
        Unmask in mask_1 a fraction refresh_rate of the pixels that are masked in mask_2.
        It works by finding the masked_pixels in mask_2 that are closer to the unmasked pixels in mask_1:
        candidates are ranked by the sum of their L1 grid distances to every unmasked pixel of mask_1
        and the ones with the smallest total are unmasked.

        Args:
            mask_1: (batch_size, num_channels, height, width) mask to refresh, 1 = unmasked
            mask_2: mask of the same shape; positions where it is 0 are the candidates
            refresh_rate: fraction of the candidates to unmask
            mask_full_rgb: if True, work on channel 0 and copy the result to 3 channels
            same_for_all_batch: must be True; the first image is used and the result is repeated
            replace: currently unused by this function
        Returns:
            (refreshed_mask, num_refreshed). The input tensors are not modified.
        Raises:
            NotImplementedError: if same_for_all_batch is False.
    """
    if not same_for_all_batch:
        raise NotImplementedError("The function refresh_mask_set_distance only works for same_for_all_batch=True for now.")

    batch_size = mask_1.shape[0]

    if mask_full_rgb:
        # Keep a channel axis of size 1: get_grid_positions expects (channels, height, width).
        mask_1 = mask_1[:, :1]
        mask_2 = mask_2[:, :1]

    mask_1 = torch.clone(mask_1[0])
    mask_2 = torch.clone(mask_2[0])

    flat_mask_1 = mask_1.reshape(-1)
    flat_mask_2 = mask_2.reshape(-1)
    # reshape(-1) keeps this 1-D even when only a single pixel is masked.
    masked_indices = torch.nonzero(flat_mask_2 == 0).reshape(-1)
    num_refreshed = int(masked_indices.numel() * refresh_rate)

    # Skip the pairwise-distance computation when there is nothing to refresh.
    if num_refreshed > 0:
        masked_grid_positions = get_grid_positions(mask_2, comp_value=0)
        unmasked_grid_positions = get_grid_positions(mask_1, comp_value=1)

        # Find the pairwise differences: (num_masked, num_unmasked, 2)
        diffs = masked_grid_positions[:, None, :] - unmasked_grid_positions[None, :, :]

        # L1 distance for every pair, then the total distance of each masked pixel to all unmasked pixels.
        distances_to_unmasked = torch.abs(diffs).sum(dim=-1).sum(dim=-1)

        # Smallest total distance first; select num_refreshed of them and make them 1 in flat_mask_1.
        sorted_masked_indices = masked_indices[torch.argsort(distances_to_unmasked)]
        flat_mask_1[sorted_masked_indices[:num_refreshed]] = 1.0

    subset_mask = flat_mask_1.reshape(mask_1.shape)

    if mask_full_rgb:
        # (1, H, W) -> (3, H, W)
        subset_mask = subset_mask.repeat(3, 1, 1)

    subset_mask = subset_mask.repeat(batch_size, *([1] * subset_mask.dim()))

    return subset_mask, num_refreshed




def get_box_mask(image_shape, survival_probability, same_for_all_batch=False, device='cuda'):
    """Creates a mask that is 1 everywhere except inside a randomly placed box, where it is 0.

        The box is ceil((1 - survival_probability) * height) rows by
        ceil((1 - survival_probability) * width) columns. Its top-left corner is drawn
        uniformly over the whole image, so the box may be cut off by the bottom/right border.

        Args:
            image_shape: (batch_size, num_channels, height, width)
            survival_probability: controls the box size as described above
            same_for_all_batch: if True, the same mask is applied to all images in the batch
            device: device to use for the mask
        Returns:
            mask: (batch_size, num_channels, height, width)
    """
    batch_size = image_shape[0]
    num_channels = image_shape[1]
    height = image_shape[2]
    width = image_shape[3]

    # One box per image, or a single box that broadcasts over the whole batch.
    num_boxes = 1 if same_for_all_batch else batch_size
    box_start_row = torch.randint(0, height, (num_boxes, 1, 1), device=device).view(num_boxes, 1, 1, 1)
    box_start_col = torch.randint(0, width, (num_boxes, 1, 1), device=device).view(num_boxes, 1, 1, 1)
    box_height = int(torch.ceil(torch.tensor((1 - survival_probability) * height)).int())
    box_width = int(torch.ceil(torch.tensor((1 - survival_probability) * width)).int())

    rows = torch.arange(height, device=device).view(1, 1, -1, 1)
    cols = torch.arange(width, device=device).view(1, 1, 1, -1)

    inside_box_rows = (rows >= box_start_row) & (rows < (box_start_row + box_height))
    inside_box_cols = (cols >= box_start_col) & (cols < (box_start_col + box_width))

    # start with everything unmasked and zero out the box in every channel
    mask = torch.ones((batch_size, num_channels, height, width), device=device)
    inside_box = (inside_box_rows & inside_box_cols).expand_as(mask)
    mask[inside_box] = 0.0

    return mask


def get_patch_mask(image_shape, crop_size, same_for_all_batch=False, device='cuda'):
    """Creates a mask that is 1 inside a randomly placed crop_size x crop_size patch and 0 elsewhere.

        The patch always lies fully inside the image, so crop_size must not exceed
        the image height or width.

        Args:
            image_shape: (batch_size, num_channels, height, width)
            crop_size: side length, in pixels, of the square patch
            same_for_all_batch: if True, the same mask is applied to all images in the batch
            device: device to use for the mask
        Returns:
            mask: (batch_size, num_channels, height, width)
    """
    batch_size = image_shape[0]
    num_channels = image_shape[1]
    height = image_shape[2]
    width = image_shape[3]

    # create a mask with the same size as the image
    mask = torch.zeros((batch_size, num_channels, height, width), device=device)

    # Largest valid top-left corner: rows are limited by the height, columns by the width.
    max_row = height - crop_size
    max_col = width - crop_size
    # One patch per image, or a single patch that broadcasts over the whole batch.
    num_patches = 1 if same_for_all_batch else batch_size
    # randint excludes its upper bound, hence the +1 so the last valid corner can be drawn.
    box_start_row = torch.randint(0, max_row + 1, (num_patches, 1, 1, 1), device=device)
    box_start_col = torch.randint(0, max_col + 1, (num_patches, 1, 1, 1), device=device)

    rows = torch.arange(height, device=device).view(1, 1, -1, 1)
    cols = torch.arange(width, device=device).view(1, 1, 1, -1)

    inside_box_rows = (rows >= box_start_row) & (rows < (box_start_row + crop_size))
    inside_box_cols = (cols >= box_start_col) & (cols < (box_start_col + crop_size))
    inside_box = (inside_box_rows & inside_box_cols).expand_as(mask)
    mask[inside_box] = 1.0

    return mask


def get_hat_patch_mask(patch_mask, crop_size, hat_crop_size, same_for_all_batch=False, device='cuda'):
    """Place a smaller random patch inside the patch marked by patch_mask.

        A (batch_size, num_channels, crop_size, crop_size) mask with a hat_crop_size patch is
        sampled with get_patch_mask, and its flattened values are written, in order, over the
        positions where patch_mask == 1. The number of such positions therefore has to equal
        the number of entries of the sampled mask.

        Args:
            patch_mask: (batch_size, num_channels, height, width) mask, 1 inside the outer patch
            crop_size: side length of the outer patch
            hat_crop_size: side length of the inner patch
            same_for_all_batch: forwarded to get_patch_mask
            device: device on which the inner mask is sampled
        Returns:
            A new mask with the shape of patch_mask; patch_mask itself is not modified.
    """
    batch_size, num_channels = patch_mask.shape[0], patch_mask.shape[1]
    inner_mask = get_patch_mask(
        (batch_size, num_channels, crop_size, crop_size),
        hat_crop_size,
        same_for_all_batch=same_for_all_batch,
        device=device,
    )

    # Flat positions of the outer patch, in the same order as the flattened inner mask.
    outer_positions = torch.nonzero(patch_mask.view(-1) == 1).squeeze()
    inner_values = inner_mask.view(-1).expand_as(outer_positions)

    result = torch.clone(patch_mask)
    flat_result = result.view(-1)  # shares storage with result
    flat_result[outer_positions] = inner_values
    return result.reshape(patch_mask.shape)