import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import othermodels


class Normalize(nn.Module):
    """Scale each row of ``x`` to unit L-``power`` norm along dim 1.

    Args:
        power: the ``p`` of the Lp norm (default 2).
        eps: lower bound applied to the norm so that all-zero rows map to
            zeros instead of NaN.
    """
    def __init__(self, power=2, eps=1e-12):
        super(Normalize, self).__init__()
        self.power = power
        self.eps = eps

    def forward(self, x):
        # abs() makes the norm correct for odd powers (e.g. p=1) when entries are
        # negative; for even powers it leaves the result unchanged.
        norm = x.abs().pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        return x.div(norm.clamp_min(self.eps))


class Flatten(nn.Module):
    """A shape adaptation layer to patch certain networks.

    Collapses every axis after the first: ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``.
    """
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs such as the
        # result of a transpose/permute; it still returns a view whenever possible.
        return x.reshape(x.shape[0], -1)


class Unsqueeze(nn.Module):
    """A shape adaptation layer to patch certain networks.

    Appends a trailing singleton axis: ``(..., d) -> (..., d, 1)``.
    """

    def forward(self, x):
        return torch.unsqueeze(x, dim=-1)


class Identity(nn.Module):
    """Pass-through module: returns its input object unchanged."""

    def forward(self, x):
        return x


def random_weight_init(model):
    """Re-initialize the 3-D convolution and batch-norm layers of ``model`` in place.

    * ``nn.Conv3d``: Kaiming-normal weights (``mode='fan_out'``); bias zeroed if present.
    * ``nn.BatchNorm3d``: weight set to 1 and bias to 0. Non-affine layers, whose
      ``weight``/``bias`` are ``None``, are skipped instead of raising.

    All other module types are left untouched.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv3d):
            # kaiming_normal_ works in place on the existing Parameter.
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm3d):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


def get_video_feature_extractor(vid_base_arch='r2plus1d_18', pretrained=False):
    """Build a video backbone whose final ``fc`` layer is replaced by an identity.

    * ``'r2plus1d_18'``: torchvision model built with ``pretrained``; when not
      pretrained, its Conv3d/BatchNorm3d layers are re-initialized with
      ``random_weight_init``. Its ``fc`` becomes ``nn.Identity``.
    * ``'r2plus1d_34'`` / ``'r2plus1d_50'``: built from ``othermodels``.
    * any other name: falls back to ``othermodels.S3DG``.

    For the ``othermodels`` backbones ``pretrained`` is ignored and ``fc`` becomes
    this file's ``Identity``.
    """
    if vid_base_arch == 'r2plus1d_18':
        backbone = torchvision.models.video.__dict__[vid_base_arch](pretrained=pretrained)
        if not pretrained:
            print("Randomy initializing models")
            random_weight_init(backbone)
        backbone.fc = nn.Identity()
        return backbone

    # Resolve the factory by name so only the requested attribute of
    # othermodels is looked up.
    factory_name = vid_base_arch if vid_base_arch in ('r2plus1d_34', 'r2plus1d_50') else 'S3DG'
    backbone = getattr(othermodels, factory_name)()
    backbone.fc = Identity()
    return backbone


def get_audio_feature_extractor(aud_base_arch='resnet18', pretrained=False):
    """Build an audio (spectrogram) backbone with its classifier removed.

    ResNet variants are built with random weights, get a single-input-channel
    ``conv1`` (7x7, stride 2, padding 3, no bias), and have ``fc`` replaced by
    ``nn.Identity``. ``'vgg_audio'`` returns ``othermodels.VGG16AudioNet()`` as is.

    Args:
        aud_base_arch: one of 'resnet9', 'resnet18', 'resnet34', 'resnet50',
            'vgg_audio'.
        pretrained: accepted for API symmetry but currently has no effect; the
            ResNets are always built with ``pretrained=False``.

    Returns:
        The backbone ``nn.Module``.

    Raises:
        AssertionError: if ``aud_base_arch`` is not supported.
    """
    supported = ('resnet9', 'resnet18', 'resnet34', 'resnet50', 'vgg_audio')
    assert aud_base_arch in supported, (
        "unsupported aud_base_arch %r; expected one of %s" % (aud_base_arch, supported)
    )

    if aud_base_arch == 'vgg_audio':
        return othermodels.VGG16AudioNet()

    if aud_base_arch == 'resnet9':
        # ResNet-9: one BasicBlock per stage. Use the public ResNet class: the
        # private ``resnet._resnet`` helper changed signature in torchvision 0.13
        # (no ``arch`` / ``pretrained`` arguments), which broke the old call.
        model = torchvision.models.resnet.ResNet(
            torchvision.models.resnet.BasicBlock, [1, 1, 1, 1]
        )
    else:
        model = torchvision.models.__dict__[aud_base_arch](pretrained=False)

    # Spectrograms have a single channel instead of RGB.
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.fc = nn.Identity()
    return model


class VideoBaseNetwork(nn.Module):
    """Video backbone wrapper returning one (optionally L2-normalized) feature vector per clip.

    Args:
        vid_base_arch: backbone name forwarded to ``get_video_feature_extractor``.
        pretrained: forwarded to ``get_video_feature_extractor``.
        norm_feat: if True, L2-normalize the features along the feature axis.
    """
    def __init__(self, vid_base_arch='r2plus1d_18', pretrained=False, norm_feat=False):
        super(VideoBaseNetwork, self).__init__()
        self.base = get_video_feature_extractor(
            vid_base_arch,
            pretrained=pretrained
        )
        self.norm_feat = norm_feat

    def forward(self, x):
        # squeeze() removes every singleton axis, including the batch axis when the
        # batch holds a single clip, so the result may be 1-D.
        x = self.base(x).squeeze()
        if self.norm_feat:
            # Feature axis is dim 1 for a batched (B, D) result and dim 0 when
            # squeeze() collapsed a batch of one to (D,).
            x = F.normalize(x, p=2, dim=1 if x.dim() > 1 else 0)
        return x


class AudioBaseNetwork(nn.Module):
    """Audio backbone wrapper returning one (optionally L2-normalized) feature vector per input.

    Args:
        aud_base_arch: backbone name forwarded to ``get_audio_feature_extractor``.
        pretrained: forwarded to ``get_audio_feature_extractor``.
        norm_feat: if True, L2-normalize the features along the feature axis.
    """
    def __init__(self, aud_base_arch='resnet18', pretrained=False, norm_feat=False):
        super(AudioBaseNetwork, self).__init__()
        self.base = get_audio_feature_extractor(
            aud_base_arch,
            pretrained=pretrained
        )
        self.norm_feat = norm_feat

    def forward(self, x):
        # squeeze() removes every singleton axis, including the batch axis when the
        # batch holds a single example, so the result may be 1-D.
        x = self.base(x).squeeze()
        if self.norm_feat:
            # Feature axis is dim 1 for a batched (B, D) result and dim 0 when
            # squeeze() collapsed a batch of one to (D,).
            x = F.normalize(x, p=2, dim=1 if x.dim() > 1 else 0)
        return x


class AVC(nn.Module):
    """Two-stream audio-visual model: video and audio backbones, each followed by
    one or more projection heads.

    Args:
        vid_base_arch: video backbone name (see ``get_video_feature_extractor``).
        aud_base_arch: audio backbone name (see ``get_audio_feature_extractor``).
        pretrained: forwarded to both backbone builders.
        norm_feat: if True, L2-normalize every head output along dim 1.
        use_mlp: if False each head is an ``nn.Linear``; otherwise an MLP chosen
            by ``mlptype``.
        mlptype: 0 -> ``MLP``, 1 -> ``MLPv2`` (SyncBN variant), any other value ->
            ``MLP_residual``.
        headcount: number of heads per modality. With 1 the heads are ``mlp_v`` /
            ``mlp_a``; with more they are ``mlp_v{i}`` / ``mlp_a{i}``; with less
            than 1 no heads are built and ``forward`` returns backbone features.
        num_classes: output width (int) of every head.
        return_features: if True, ``forward`` returns the squeezed backbone
            features and skips the heads.
    """
    def __init__(self, vid_base_arch='r2plus1d_18', aud_base_arch='resnet18',
                 pretrained=False, norm_feat=False, use_mlp=False,
                 mlptype=0, headcount=1, num_classes=256, return_features=False):
        super(AVC, self).__init__()
        self.video_network = VideoBaseNetwork(
            vid_base_arch,
            pretrained=pretrained
        )
        self.audio_network = AudioBaseNetwork(
            aud_base_arch,
            pretrained=pretrained
        )
        self.use_mlp = use_mlp
        self.hc = headcount
        self.norm_feat = norm_feat
        self.return_features = return_features
        # Assumed backbone output widths (R(2+1)D: 512, otherwise 2048; audio
        # ResNet-50: 2048, otherwise 512).
        encoder_dim = 512 if vid_base_arch in ['r2plus1d_18', 'r2plus1d_34', 'r2plus1d_50'] else 2048
        encoder_dim_a = 2048 if aud_base_arch in ['resnet50'] else 512
        self._build_heads(encoder_dim, encoder_dim_a, use_mlp, mlptype, num_classes)

    @staticmethod
    def _make_head(in_dim, num_classes, use_mlp, mlptype):
        """Return one projection head mapping ``in_dim`` features to ``num_classes``."""
        if not use_mlp:
            return nn.Linear(in_dim, num_classes)
        if mlptype == 0:
            return MLP(in_dim, num_classes)
        if mlptype == 1:
            return MLPv2(in_dim, num_classes)
        return MLP_residual(in_dim, num_classes)

    def _build_heads(self, encoder_dim, encoder_dim_a, use_mlp, mlptype, num_classes):
        """Create and register the projection heads for ``self.hc`` heads.

        Each modality's head takes that modality's feature width. Registration
        order (video then audio for a single head, audio then video per head
        otherwise) is kept stable so parameter order and initialization order
        do not change.
        """
        if self.hc < 1:
            return
        if not use_mlp:
            print("Using Linear Layer")
        elif mlptype == 0:
            print("Using Regular MLP")
        elif mlptype == 1:
            print("Using MLP to be combined with SyncBN")
        else:
            print("Using Residual MLP")

        if self.hc == 1:
            self.mlp_v = self._make_head(encoder_dim, num_classes, use_mlp, mlptype)
            self.mlp_a = self._make_head(encoder_dim_a, num_classes, use_mlp, mlptype)
        else:
            for head in range(self.hc):
                setattr(self, "mlp_a%d" % head,
                        self._make_head(encoder_dim_a, num_classes, use_mlp, mlptype))
                setattr(self, "mlp_v%d" % head,
                        self._make_head(encoder_dim, num_classes, use_mlp, mlptype))

    def forward(self, img, spec, whichhead=0):
        """Encode a video batch ``img`` and an audio batch ``spec``.

        ``whichhead`` is accepted for API compatibility and is not used.

        Returns:
            ``(video_out, audio_out)``: tensors for a single head, lists with one
            tensor per head when ``headcount > 1``, or the squeezed backbone
            features when ``return_features`` is set.
        """
        img_features = self.video_network(img).squeeze()
        aud_features = self.audio_network(spec).squeeze()
        if self.return_features:
            return img_features, aud_features
        # squeeze() drops the batch axis for a batch of one; restore it so the
        # heads and the dim=1 normalization see a (B, D) tensor.
        if aud_features.dim() == 1:
            aud_features = aud_features.unsqueeze(0)
        if img_features.dim() == 1:
            img_features = img_features.unsqueeze(0)

        if self.hc == 1:
            img_features = self.mlp_v(img_features)
            aud_features = self.mlp_a(aud_features)
            if self.norm_feat:
                img_features = F.normalize(img_features, p=2, dim=1)
                aud_features = F.normalize(aud_features, p=2, dim=1)
        elif self.hc > 1:
            img_features = [getattr(self, "mlp_v%d" % head)(img_features) for head in range(self.hc)]
            aud_features = [getattr(self, "mlp_a%d" % head)(aud_features) for head in range(self.hc)]
            if self.norm_feat:
                img_features = [F.normalize(imgf, p=2, dim=1) for imgf in img_features]
                aud_features = [F.normalize(audf, p=2, dim=1) for audf in aud_features]
        return img_features, aud_features

class AVCconcat(nn.Module):
    """Audio-visual model that concatenates video and audio features and feeds the
    fused vector to shared projection head(s).

    Args:
        vid_base_arch: video backbone name (see ``get_video_feature_extractor``).
        aud_base_arch: audio backbone name (see ``get_audio_feature_extractor``).
        pretrained: forwarded to both backbone builders.
        norm_feat: stored only; NOTE(review): ``forward`` does not apply it
            (unlike ``AVC``) — confirm whether that is intended.
        use_mlp: if False each head is an ``nn.Linear``; otherwise an MLP chosen
            by ``mlptype``.
        mlptype: 0 -> ``MLP``, 1 -> ``MLPv2`` (SyncBN variant), any other value ->
            ``MLP_residual``.
        headcount: number of heads. With 1 the head is ``mlp`` (aliased as
            ``mlp_v`` and ``mlp_a``); with more, ``mlp{i}`` (aliased as
            ``mlp_v{i}`` / ``mlp_a{i}``).
        num_classes: output width (int) of every head.
        return_features: stored only; NOTE(review): ``forward`` ignores it.
        cattype: 1 fuses ``[video, audio]``; any other value fuses
            ``[video, audio, audio - video]``, which requires equal widths.
    """
    def __init__(self, vid_base_arch='r2plus1d_18', aud_base_arch='resnet18',
                 pretrained=False, norm_feat=False, use_mlp=False,
                 mlptype=0, headcount=1, num_classes=256, return_features=False, cattype=1):
        super(AVCconcat, self).__init__()
        self.video_network = VideoBaseNetwork(
            vid_base_arch,
            pretrained=pretrained
        )
        self.audio_network = AudioBaseNetwork(
            aud_base_arch,
            pretrained=pretrained
        )
        self.cattype = cattype
        # Assumed backbone output widths, same rule as AVC (512 for the default
        # R(2+1)D / ResNet-18 pair).
        encoder_dim = 512 if vid_base_arch in ['r2plus1d_18', 'r2plus1d_34', 'r2plus1d_50'] else 2048
        encoder_dim_a = 2048 if aud_base_arch in ['resnet50'] else 512
        if cattype == 1:
            emb_dim = encoder_dim + encoder_dim_a
        else:
            emb_dim = encoder_dim + 2 * encoder_dim_a
        self.use_mlp = use_mlp
        self.hc = headcount
        self.norm_feat = norm_feat
        self.return_features = return_features
        self._build_heads(emb_dim, use_mlp, mlptype, num_classes)

    @staticmethod
    def _make_head(in_dim, num_classes, use_mlp, mlptype):
        """Return one projection head mapping ``in_dim`` features to ``num_classes``."""
        if not use_mlp:
            return nn.Linear(in_dim, num_classes)
        if mlptype == 0:
            return MLP(in_dim, num_classes)
        if mlptype == 1:
            return MLPv2(in_dim, num_classes)
        return MLP_residual(in_dim, num_classes)

    def _build_heads(self, emb_dim, use_mlp, mlptype, num_classes):
        """Create the shared head(s) and register the per-modality aliases."""
        if self.hc < 1:
            return
        if not use_mlp:
            print("Using Linear Layer")
        elif mlptype == 0:
            print("Using Regular MLP")
        elif mlptype == 1:
            print("Using MLP to be combined with SyncBN")
        else:
            print("Using Residual MLP")

        if self.hc == 1:
            self.mlp = self._make_head(emb_dim, num_classes, use_mlp, mlptype)
            # Aliases of the single shared head (same parameters).
            self.mlp_v = self.mlp
            self.mlp_a = self.mlp
        else:
            for head in range(self.hc):
                shared = self._make_head(emb_dim, num_classes, use_mlp, mlptype)
                setattr(self, "mlp%d" % head, shared)
                setattr(self, "mlp_v%d" % head, shared)
                setattr(self, "mlp_a%d" % head, shared)

    def forward(self, img, spec):
        """Fuse video batch ``img`` and audio batch ``spec`` and apply the head(s).

        Returns:
            The fused head output twice (``(out, out)``): a tensor for one head,
            a list with one tensor per head when ``headcount > 1``.
        """
        img_features = self.video_network(img).squeeze()
        aud_features = self.audio_network(spec).squeeze()
        # squeeze() drops the batch axis for a batch of one; restore it so the
        # concatenation along dim 1 is well defined.
        if img_features.dim() == 1:
            img_features = img_features.unsqueeze(0)
        if aud_features.dim() == 1:
            aud_features = aud_features.unsqueeze(0)
        # The fused tensor reuses the img_features name so the unfused video
        # features can be released early.
        if self.cattype == 1:
            img_features = torch.cat([img_features, aud_features], dim=1)
        else:
            img_features = torch.cat([img_features, aud_features, aud_features - img_features], dim=1)
        if self.hc == 1:
            img_features = self.mlp(img_features)
        elif self.hc > 1:
            img_features = [getattr(self, "mlp%d" % head)(img_features) for head in range(self.hc)]
        return img_features, img_features


class MLP(nn.Module):
    """Projection head: a linear classifier or a one-hidden-layer MLP.

    Adapted from AMDIM:
    https://github.com/Philip-Bachman/amdim-public/blob/e01b424ec8d985a906e16580048c114d77ced94c/model.py

    Args:
        n_input: width of the flattened input features.
        n_classes: width of the output.
        n_hidden: hidden width; ``None`` selects a plain dropout + linear classifier.
        p: dropout probability applied before each linear layer.
    """
    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        layers = [Flatten(), nn.Dropout(p=p)]
        if n_hidden is None:
            layers.append(nn.Linear(n_input, n_classes, bias=True))
        else:
            layers += [
                # BatchNorm follows, so this linear layer has no bias.
                nn.Linear(n_input, n_hidden, bias=False),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            ]
        self.block_forward = nn.Sequential(*layers)

    def forward(self, x):
        return self.block_forward(x)


class MLPv2(nn.Module):
    """Variant of the MLP head whose hidden BatchNorm sees a 3-D ``(N, C, 1)`` tensor.

    The hidden activations are unsqueezed before ``nn.BatchNorm1d`` and flattened
    back afterwards. NOTE(review): this looks intended for SyncBN conversion,
    which has historically rejected 2-D input — confirm.

    Adapted from AMDIM:
    https://github.com/Philip-Bachman/amdim-public/blob/e01b424ec8d985a906e16580048c114d77ced94c/model.py

    Args:
        n_input: width of the flattened input features.
        n_classes: width of the output.
        n_hidden: hidden width; ``None`` selects a plain dropout + linear classifier.
        p: dropout probability applied before each linear layer.
    """
    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        layers = [Flatten(), nn.Dropout(p=p)]
        if n_hidden is None:
            layers.append(nn.Linear(n_input, n_classes, bias=True))
        else:
            layers += [
                nn.Linear(n_input, n_hidden, bias=False),
                Unsqueeze(),               # (N, C) -> (N, C, 1) for the BatchNorm
                nn.BatchNorm1d(n_hidden),
                Flatten(),                 # back to (N, C)
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            ]
        self.block_forward = nn.Sequential(*layers)

    def forward(self, x):
        return self.block_forward(x)


class MLP_residual(nn.Module):
    """Two-layer MLP with a residual connection: ``block(x) + shortcut(x)``.

    The shortcut is the identity when ``in_planes == planes``; otherwise it is a
    bias-free linear projection followed by BatchNorm.
    """
    def __init__(self, in_planes, planes):
        super().__init__()
        use_bias = True
        self.block = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(in_planes, planes, bias=use_bias),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(planes, planes, bias=use_bias),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
        )
        if in_planes == planes:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Linear(in_planes, planes, bias=False),
                nn.BatchNorm1d(planes),
            )

    def forward(self, x):
        out = self.block(x)
        return out + self.shortcut(x)


class VideoWrapper(torch.nn.Module):
    """Video-only model built from an existing video backbone and projection head(s).

    Args:
        vid_net: module mapping a clip batch to features; its output is squeezed.
        mlp_v: projection head(s). For ``hc == 1`` a single module, or ``None`` to
            create ``nn.Linear(512, num_clusters)``. For any other ``hc`` an
            indexable collection of ``hc`` modules, or ``None`` to create ``hc``
            such linear layers.
        norm_feat: if True, L2-normalize head outputs along dim 1.
        num_clusters: output width of heads created when ``mlp_v`` is ``None``.
        hc: number of heads.
    """
    def __init__(
        self, vid_net, mlp_v, norm_feat=False, num_clusters=256, hc=1
    ):
        super().__init__()
        self.video_network = vid_net
        self.audio_network = torch.nn.Identity()  # placeholder; never called
        self.hc = hc
        if self.hc == 1:
            self.mlp_v = nn.Linear(512, num_clusters) if mlp_v is None else mlp_v
        else:
            for idx in range(self.hc):
                head = nn.Linear(512, num_clusters) if mlp_v is None else mlp_v[idx]
                setattr(self, "mlp_v%d" % idx, head)
        self.return_features = False
        self.norm_feat = norm_feat

    def forward(self, img, spec=None):
        """Encode ``img``; ``spec`` is ignored. The second return value is always ``None``."""
        feats = self.video_network(img).squeeze()
        if self.return_features:
            return feats, None
        # squeeze() drops the batch axis for a batch of one; restore it.
        if feats.dim() == 1:
            feats = feats.unsqueeze(0)

        if self.hc == 1:
            out = self.mlp_v(feats)
            if self.norm_feat:
                out = F.normalize(out, p=2, dim=1)
            return out, None
        if self.hc > 1:
            out = [getattr(self, "mlp_v%d" % idx)(feats) for idx in range(self.hc)]
            if self.norm_feat:
                out = [F.normalize(o, p=2, dim=1) for o in out]
            return out, None
        return feats, None


if __name__ == '__main__':
    # Smoke test: one random clip / spectrogram pair through a single-head AVC.
    # num_classes is the integer output width of each head.
    model = AVC(vid_base_arch='r2plus1d_18', aud_base_arch='vgg_audio', pretrained=False, num_classes=3000)
    model.eval()
    img = torch.rand(1, 3, 25, 112, 112)
    aud = torch.rand(1, 1, 40, 99)
    with torch.no_grad():
        out_v, out_a = model(img, aud)  # AVC returns (video, audio)
    print(out_v.shape)
    print(out_a.shape)
