import numpy as np

# Colour palette: each name is bound to its integer code (black = 0 ... maroon = 9).
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def _fill_gap(source_line: np.ndarray, target_line: np.ndarray, color: int) -> None:
    """Paint the black cells between the outermost ``color`` cells of one line.

    Args:
        source_line: A single row or column of the input grid (read only).
        target_line: The matching row/column *view* of the output grid; it is
            modified in place.
        color: The colour whose outermost occurrences delimit the span.
    """
    hits = np.flatnonzero(source_line == color)
    if hits.size < 2:
        # A lone pixel (or none) has no partner to connect to.
        return
    # Filling between the two outermost pixels covers exactly the union of the
    # gaps between every pair of same-coloured pixels on this line.
    inner = slice(hits[0] + 1, hits[-1])
    gap = source_line[inner] == black
    # ``target_line[inner]`` is a basic-slice view, so the masked assignment
    # writes through to the output grid.
    target_line[inner][gap] = color


def main(input_grid: np.ndarray) -> np.ndarray:
    """Connect same-coloured pixels that share a row or column.

    The first non-black cell in row-major order is taken as the "line colour".
    For every other non-black colour, the black cells lying between two pixels
    of that colour in the same row or the same column are painted with that
    colour. Only cells that are black in the input are ever painted, so the
    line-coloured cells (and every other coloured cell) are carried over
    unchanged from the input.

    Args:
        input_grid: 2-D integer array of colour codes. It is not modified.

    Returns:
        A new array of the same shape and dtype with the gaps filled in. A grid
        without any non-black cell is returned as an unchanged copy.
    """
    output_grid = np.copy(input_grid)

    # Locate the first non-black cell (np.argwhere yields row-major order).
    colored_cells = np.argwhere(input_grid != black)
    if colored_cells.shape[0] == 0:
        # Nothing but black: there is no line colour and nothing to connect.
        return output_grid
    first_row, first_col = colored_cells[0]
    # NOTE(review): this assumes the first non-black cell belongs to the lines;
    # a stray coloured pixel that precedes them would be picked instead.
    line_color = input_grid[first_row, first_col]

    # No explicit repainting of a full line row/column is needed: such a line
    # is already present in the copy, and indexing with a missing (None)
    # row/column would act as np.newaxis and overwrite the whole grid.

    # Iterating the set keeps the colour order of the original implementation,
    # so overlapping fills of two different colours resolve the same way
    # (the colour processed last wins).
    other_colors = set(np.unique(input_grid)) - {black, line_color}
    for color in other_colors:
        rows, cols = np.nonzero(input_grid == color)
        for r in np.unique(rows):
            _fill_gap(input_grid[r], output_grid[r], color)
        for c in np.unique(cols):
            _fill_gap(input_grid[:, c], output_grid[:, c], color)

    return output_grid