
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net_2out(nn.Module):
    """Small fully connected network mapping CANN activity to 2 outputs (x, y).

    Architecture with the default sizes: 512 -> 100 -> 20 -> 2.

    * ``layer1``: linear layer followed by ReLU.
    * ``layer2``: linear layer followed by tanh.
    * ``readout``: plain linear layer; no activation is applied to its output.

    Args:
        num_CANN: Number of neurons in the CANN, i.e. the input feature size.
        num_hidden1: Width of the first hidden layer (keyword-only).
        num_hidden2: Width of the second hidden layer (keyword-only).
    """

    # NOTE(review): an earlier comment said "the label is a scalar", while this
    # network emits 2 values per input -- confirm the target shape against the
    # training code.

    def __init__(self, num_CANN=512, *, num_hidden1=100, num_hidden2=20):
        super().__init__()
        # Attribute names and creation order are kept as-is so that existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.layer1 = nn.Linear(num_CANN, num_hidden1, bias=True)
        self.layer2 = nn.Linear(num_hidden1, num_hidden2, bias=True)
        # The readout always produces 2 values (x, y) for each input vector.
        self.readout = nn.Linear(num_hidden2, 2)

    def forward(self, x):
        """Compute the (x, y) prediction for a batch of CANN activity.

        Args:
            x: CANN output at each time step. The last dimension must have
                size ``num_CANN``; ``nn.Linear`` acts on the last dimension
                only, so any leading dimensions are carried through.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 2.
        """
        x = F.relu(self.layer1(x))
        x = torch.tanh(self.layer2(x))
        # Raw linear readout -- intentionally no activation here.
        x = self.readout(x)
        return x
