import torch.nn as nn
from .build import MODELS
from modules.voxelconv import ResVoxelConv, VoxelConv


@MODELS.register_module()
class Discriminator(nn.Module):
    """Discriminator assembled from a fixed stack of voxel convolution blocks.

    The network is a sequential chain of five blocks::

        VoxelConv(in -> base_dim, pooling=True)
        ResVoxelConv(base_dim)
        VoxelConv(base_dim -> 2 * base_dim, pooling=True)
        ResVoxelConv(2 * base_dim)
        VoxelConv(2 * base_dim -> 1, pooling=False)

    All blocks are created with ``with_se=False``. After construction every
    sub-module is passed through ``weights_init``.

    Args:
        config: Configuration object. It must support
            ``config.get('with_color', default)`` and expose a ``base_dim``
            attribute. When ``with_color`` is truthy the first block takes
            4 input channels, otherwise 1.
            NOTE(review): 4 presumably means occupancy + RGB -- confirm
            against the data pipeline.
    """

    def __init__(self, config):
        super().__init__()

        # Input width depends only on whether colour is enabled in the config.
        in_channels = 4 if config.get('with_color', False) else 1
        width = config.base_dim

        # Blocks are instantiated in network order so that parameter
        # creation order (and hence RNG consumption) is unchanged.
        self.blocks = nn.Sequential(
            VoxelConv(in_channels, width, pooling=True, with_se=False),
            ResVoxelConv(width, with_se=False),
            VoxelConv(width, width * 2, pooling=True, with_se=False),
            ResVoxelConv(width * 2, with_se=False),
            # Final block collapses the features to a single output channel.
            VoxelConv(width * 2, 1, pooling=False, with_se=False),
        )

        # Re-initialise the freshly created layers recursively.
        self.apply(weights_init)

    def forward(self, x):
        """Run ``x`` through the block stack and return the result."""
        return self.blocks(x)


def weights_init(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    Modules are selected by class-name substring, so any layer whose class
    name contains the marker is affected, not only the ``torch.nn`` classes:

    * name contains ``'Conv3d'``: weight ~ N(0.0, 0.02); bias is untouched.
    * name contains ``'BatchNorm3d'``: weight ~ N(1.0, 0.02), bias set to 0.

    Parameters that are missing or ``None`` (for example a ``BatchNorm3d``
    created with ``affine=False``) are skipped instead of raising.

    Args:
        m: The module to initialise. Modules whose class name matches
            neither marker are left unchanged.
    """
    classname = m.__class__.__name__
    # getattr with a default tolerates modules that lack the attribute or
    # that registered it as None.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv3d' in classname:
        if weight is not None:
            nn.init.normal_(weight, mean=0.0, std=0.02)
    elif 'BatchNorm3d' in classname:
        if weight is not None:
            nn.init.normal_(weight, mean=1.0, std=0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)
