'''
Implementation of a simple DeepSet module (to be used later): a shared per-element MLP followed by mean pooling over the set.
'''

import torch.nn as nn
import torch.nn.functional as F


class SimpleDeepSet(nn.Module):
    '''
    A very simple DeepSet module: apply a shared two-layer MLP (Linear + ReLU,
    twice) to every element of the set, then average the transformed elements.

    Because the same weights are applied to each element and the pooling is a
    mean, the output does not depend on the order of the set elements.

    Args:
        input_dim: size of each element's feature vector (last dim of the input).
        output_dim: size of the hidden and output feature vectors.
        set_dim: dimension of the input that indexes set elements and is
            averaged away. Defaults to 0, i.e. an unbatched
            (set_size, input_dim) input. Use set_dim=1 (or -2) for batched
            (batch, set_size, input_dim) input so that samples in the batch
            are not averaged together.
    '''
    def __init__(self, input_dim=256, output_dim=256, *, set_dim=0):
        super().__init__()
        # Per-element encoder: nn.Linear acts on the last dimension only, so
        # every set element is transformed independently with shared weights.
        self.lin1 = nn.Linear(input_dim, output_dim)
        self.act1 = nn.ReLU()
        self.lin2 = nn.Linear(output_dim, output_dim)
        self.act2 = nn.ReLU()
        # Plain int attribute (not a parameter/buffer): state_dict is unchanged.
        self.set_dim = set_dim

    def forward(self, x):
        '''
        Encode every set element and mean-pool over ``self.set_dim``.

        Args:
            x: tensor whose last dimension has size ``input_dim`` and whose
                ``self.set_dim`` dimension enumerates the set elements.

        Returns:
            Tensor with the ``self.set_dim`` dimension reduced away and a last
            dimension of size ``output_dim``. An empty set (size 0 along
            ``self.set_dim``) yields NaN, since the mean is over zero elements.
        '''
        # Transform each element of the set independently ...
        x = self.act1(self.lin1(x))
        x = self.act2(self.lin2(x))
        # ... then pool them together (order-invariant).
        return x.mean(dim=self.set_dim)

