import numpy as np

def main(input_grid):
    """Crop the grid to the bounding box of its most frequent non-zero color.

    Cells equal to 0 are treated as background. Among the remaining values, the
    one occupying the most cells is selected. Ties are broken in favour of the
    smallest value. The rectangular sub-grid spanning every cell of that value
    is returned. Cells of other values that fall inside the box are kept.

    Args:
        input_grid: 2-D array-like of cell values (anything ``np.asarray``
            accepts, e.g. an ndarray or a list of lists).

    Returns:
        A new 2-D ndarray (a copy, not a view into ``input_grid``) holding the
        cropped region.

    Raises:
        ValueError: if the grid has no non-zero cells.
    """
    grid = np.asarray(input_grid)

    # Exclude the background by value. Slicing ``np.unique(...)[1:]`` would
    # instead drop the *smallest* value, which is wrong when the grid contains
    # no 0 or contains negative values.
    colors, counts = np.unique(grid[grid != 0], return_counts=True)
    if colors.size == 0:
        raise ValueError("input_grid contains no non-zero cells")

    # np.unique returns colors sorted ascending, and np.argmax returns the first
    # maximum, so ties resolve to the smallest color value.
    dominant = colors[np.argmax(counts)]

    rows, cols = np.nonzero(grid == dominant)
    # Copy so that callers mutating the result do not alter the input grid.
    return grid[rows.min():rows.max() + 1, cols.min():cols.max() + 1].copy()