import random
import numpy as np


def gen_input(n_max=10, m_max=10, *, cell_max=15):
    """Generate one random test case: a size header followed by a grid.

    The returned text has the form::

        n m
        c11 c12 ... c1m
        ...
        cn1 cn2 ... cnm

    where ``n`` is drawn uniformly from ``[1, n_max]``, ``m`` from
    ``[1, m_max]``, and every cell from ``[0, cell_max]`` (all inclusive).

    Note: the dimensions come from the stdlib ``random`` module, while the
    cells come from ``numpy.random``. These are independent generators, so
    reproducible output requires seeding both (``random.seed`` and
    ``np.random.seed``).

    Args:
        n_max: Largest possible number of rows (must be >= 1).
        m_max: Largest possible number of columns (must be >= 1).
        cell_max: Largest possible cell value, inclusive (must be >= 0).
            Defaults to 15, the original hard-coded range.

    Returns:
        The test case as a newline-separated string with no trailing newline.

    Raises:
        ValueError: if ``n_max < 1``, ``m_max < 1`` or ``cell_max < 0``
            (raised by the underlying RNG calls).
    """
    n = random.randint(1, n_max)
    m = random.randint(1, m_max)
    # numpy's upper bound is exclusive (unlike random.randint), hence the +1.
    grid = np.random.randint(0, cell_max + 1, size=(n, m))
    grid_str = '\n'.join(' '.join(map(str, row)) for row in grid)
    return f'{n} {m}\n{grid_str}'

def batch_gen_inputs(batch_size):
    """Return a list of ``batch_size`` test cases, each built by ``gen_input()``.

    Every case uses ``gen_input``'s default size limits. A non-positive
    ``batch_size`` yields an empty list.
    """
    return [gen_input() for _ in range(batch_size)]
