from torch import nn, optim
import torch
import torch.nn.functional as F
import numpy as np


# Define model
class CNN(nn.Module):
    """Three-stage convolutional network with two parallel output heads.

    The feature extractor is three ``Conv2d -> BatchNorm2d -> ReLU ->
    MaxPool2d`` stages (16, 32 and 64 channels with 3x3, 5x5 and 7x7
    kernels).  Every convolution uses stride 1 and "same" padding, so only
    the 2x2 / stride-2 pooling layers change the spatial size (halving it,
    rounding down).  A global average pool then collapses the feature map to
    one value per channel, which feeds a small MLP (64 -> 128 -> 64) shared
    by the ``fc_arousal`` and ``fc_valence`` linear heads.

    Args:
        input_shape: Accepted for interface compatibility but not used.
            Because of the global average pooling, the fully connected
            layers do not depend on the spatial size of the input.
        num_classes: Number of outputs produced by *each* of the two heads.
        in_channels: Number of channels of the input images.  Defaults to
            ``1``, the value that was previously hard-coded.
        dropout: Dropout probability applied after the first fully
            connected layer.  Defaults to ``0.5``, the previous fixed value.
    """

    def __init__(self, input_shape, num_classes: int, *,
                 in_channels: int = 1, dropout: float = 0.5):
        super().__init__()

        # NOTE: layers are created in the same order and under the same
        # attribute names as before, so seeded initialisation and existing
        # ``state_dict`` checkpoints remain compatible.

        # Conv Layer 1
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=(3, 3), stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv Layer 2
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(5, 5), stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv Layer 3
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(7, 7), stride=1, padding=3)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global Average Pooling: reduces any HxW map to 1x1 per channel,
        # which is why fc1 can take a fixed 64 input features.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected layers
        self.fc1 = nn.Linear(64, 128)
        self.drop = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, 64)

        # Fully connected layers for arousal and valence
        self.fc_arousal = nn.Linear(64, num_classes)
        self.fc_valence = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor):
        """Run the network on a batch of images.

        Args:
            x: Batched input of shape ``(N, in_channels, H, W)``.  ``H`` and
                ``W`` must each be at least 8 so that the three 2x2 pooling
                stages leave a non-empty feature map.

        Returns:
            A tuple ``(arousal, valence)``; each is the raw output of its
            linear head (no activation applied) with shape
            ``(N, num_classes)``.
        """
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))

        x = self.gap(x)  # Global Average Pooling -> (N, 64, 1, 1)
        x = x.view(x.size(0), -1)  # Flatten to (N, 64)

        x = self.drop(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))

        # Both heads read the same shared 64-dimensional representation.
        arousal = self.fc_arousal(x)
        valence = self.fc_valence(x)

        return arousal, valence
