from torch import nn
from torch.nn.functional import relu


class PrefixSumInputPreprocessor(nn.Module):
    """Lift a single-channel 1-D signal into ``width`` feature channels.

    The module applies one ``nn.Conv1d`` (fixed ``in_channels=1``) and then a ReLU.
    After each forward pass, ``self.out_features`` holds the shape of the output
    with its leading dimension dropped.

    Args:
        width: Number of output channels produced by the convolution.
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv1d``.
        stride: Convolution stride. It is wrapped in a 1-tuple before being
            passed to ``nn.Conv1d``, so it is presumably expected to be a plain
            int (TODO confirm callers never pass a tuple).
        padding: Padding, forwarded unchanged to ``nn.Conv1d``.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, width, kernel_size, stride, padding, bias):
        super().__init__()
        conv_kwargs = dict(
            in_channels=1,
            out_channels=width,
            kernel_size=kernel_size,
            stride=(stride,),
            padding=padding,
            bias=bias,
        )
        self.conv1 = nn.Conv1d(**conv_kwargs)
        # Stays None until forward() runs; then holds the output shape
        # without its first dimension.
        self.out_features = None

    def forward(self, xs):
        """Convolve ``xs``, apply ReLU, record the output shape, and return it.

        ``xs`` presumably has shape (batch, 1, length), given ``in_channels=1``
        (TODO confirm). With unbatched (1, length) input, the recorded shape
        would omit the channel dimension instead of the batch dimension.
        """
        activated = relu(self.conv1(xs))
        # Side effect: remember the per-sample output shape for later use.
        self.out_features = [*activated.shape[1:]]
        return activated
