import torch.nn as nn


class SmallCNN(nn.Module):
    """Small convolutional classifier: four conv blocks plus a two-layer head.

    The backbone is ``conv-conv-pool-conv-conv-pool-flatten`` where every
    convolution is a 3x3, stride-1, unpadded ``nn.Conv2d`` followed by batch
    normalisation and ReLU (and dropout when ``drop_rate > 0``).  The head
    expects 512 flattened features, i.e. 32 channels x 4 x 4, which is what
    the backbone produces for 3-channel 28x28 inputs.

    Args:
        num_classes: Number of output logits.
        drop_rate: Dropout probability applied after each backbone ReLU.
            Values ``<= 0`` disable dropout entirely (no ``nn.Dropout``
            modules are created).
    """

    def __init__(self, num_classes: int, drop_rate: float):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate

        # Dropout modules are only inserted when drop_rate > 0, so the
        # positions of the layers inside the nn.Sequential (and therefore the
        # state_dict keys) depend on drop_rate.
        layers = []
        layers += self._conv_block(3, 16, self.drop_rate)
        layers += self._conv_block(16, 16, self.drop_rate)
        layers.append(nn.MaxPool2d((2, 2), (2, 2)))
        layers += self._conv_block(16, 32, self.drop_rate)
        layers += self._conv_block(32, 32, self.drop_rate)
        layers.append(nn.MaxPool2d((2, 2), (2, 2)))
        layers.append(nn.Flatten())
        self.backbone = nn.Sequential(*layers)

        self.linear = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Linear(256, self.num_classes),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, drop_rate):
        """Return the layers of one conv -> batch-norm -> ReLU [-> dropout] block."""
        block = [
            nn.Conv2d(in_channels, out_channels, (3, 3), (1, 1)),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        ]
        if drop_rate > 0.0:
            block.append(nn.Dropout(p=drop_rate))
        return block

    def get_grad_cam_target_layer(self) -> nn.Module:
        """Return the last ReLU of the backbone as the Grad-CAM target layer.

        The layer is located by type rather than by a fixed index because the
        optional dropout modules shift positions inside the backbone; a fixed
        ``[-3]`` would select an ``nn.Dropout`` whenever dropout is enabled.

        Raises:
            RuntimeError: If the backbone contains no ``nn.ReLU`` module.
        """
        for layer in reversed(list(self.backbone)):
            if isinstance(layer, nn.ReLU):
                return layer
        raise RuntimeError("backbone contains no nn.ReLU layer")

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; the 512-feature head implies 3-channel 28x28
                images, i.e. shape ``(N, 3, 28, 28)``.

        Returns:
            Tensor of logits with ``num_classes`` entries per sample.
        """
        features = self.backbone(x)
        logits = self.linear(features)

        return logits
