import os
import torch
import torch.nn as nn
from torchvision import models
from termcolor import cprint
import torch.nn.functional as F
from utils import set_trainable

class VGG19_Finetune_ConvBase(nn.Module):
    """Shared VGG19 convolutional trunk, split into stages CONV1 ... CONV5.

    The layers come from ``models.vgg19(pretrained=True).features``.
    Stages ``conv1`` and ``conv2`` are passed to
    ``set_trainable(..., requires_grad=False)`` (presumably freezing them);
    ``conv3``-``conv5`` are left as loaded. The last layer of the feature
    stack is excluded from ``conv5`` ("conv5_NoMaxPooling").
    """

    def __init__(self):
        super(VGG19_Finetune_ConvBase, self).__init__()
        layers = list(models.vgg19(pretrained=True).features.children())

        # Low-level stages, excluded from fine-tuning.
        self.conv1 = nn.Sequential(*layers[0:5])
        self.conv2 = nn.Sequential(*layers[5:10])
        for frozen_stage in (self.conv1, self.conv2):
            set_trainable(frozen_stage, requires_grad=False)

        # Fine-tuned stages; conv5 stops before the final layer of the stack.
        self.conv3 = nn.Sequential(*layers[10:19])
        self.conv4 = nn.Sequential(*layers[19:28])
        self.conv5 = nn.Sequential(*layers[28:-1])

    def forward(self, x):
        """Run ``x`` through conv1 -> conv5 in order and return the conv5 output."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        return x


class VGG19_Finetune_STEP1and2_2Group(nn.Module):
    """VGG19 fine-tuning model with a plain classifier (STEP1) and a
    two-group MA-CNN style part-attention head (STEP2).

    Args:
        opt: options object; only ``opt.train_ncls`` (number of output
            classes for every classifier head) is read here.

    ``forward(x, FLAG)`` dispatches on ``FLAG``:
        * ``'STEP1'``        -> class logits from the VGG19 classifier.
        * ``'STEP2_Phase1'`` -> the two channel-weight vectors ``(sig_1, sig_2)``.
        * ``'STEP2_Phase2'`` -> part logits, masks and intermediates
          (see ``forward_STEP2_Phase2``).
    """

    def __init__(self, opt):
        super(VGG19_Finetune_STEP1and2_2Group, self).__init__()

        self.margin = 0.02  # not read inside this class; callers may use it
        self.convBase = VGG19_Finetune_ConvBase()

        # ---- STEP1: VGG19 classification, opt.train_ncls ways ----
        self.MaxPooling = nn.MaxPool2d(kernel_size=2, stride=2)
        org_classifier = models.vgg19(pretrained=True).classifier
        # Pretrained VGG19 classifier without its last layer; FC_Top replaces it.
        self.classifier = nn.Sequential(
            *list(org_classifier.children())[:-1]
        )
        self.FC_Top = nn.Linear(4096, opt.train_ncls)
        self._initialize_weights(self.FC_Top)

        # ---- STEP2: MA-CNN channel grouping ----
        # The 28 / 28 * 28 / 512 constants fix the expected conv5_4 map at
        # 512 channels x 28 x 28.
        self.avgpool = nn.AvgPool2d(28, stride=1)
        # Applied to a 3-D (n, 512, HW) tensor it averages over the 512 channels.
        self.avgpool_mask = nn.AvgPool2d((512, 1), stride=1)
        self.tanh = nn.Tanh()

        self.part_FC_Top_1 = nn.Linear(28 * 28, opt.train_ncls)
        self.part_FC_Top_2 = nn.Linear(28 * 28, opt.train_ncls)

        self.sig_mask_1 = nn.Sequential(
            nn.Linear(512, 512),
            nn.Tanh(),
            nn.Linear(512, 512),
            nn.Sigmoid()
        )
        self.sig_mask_2 = nn.Sequential(
            nn.Linear(512, 512),
            nn.Tanh(),
            nn.Linear(512, 512),
            nn.Sigmoid()
        )

        self._initialize_weights(self.sig_mask_1)
        self._initialize_weights(self.sig_mask_2)
        self._initialize_weights(self.part_FC_Top_1)
        self._initialize_weights(self.part_FC_Top_2)

    def _initialize_weights(self, sub_net):
        """Init every nn.Linear in ``sub_net`` with N(0, 0.01) weights and zero bias."""
        for m in sub_net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, FLAG):
        """Dispatch to the STEP1 / STEP2 forward selected by ``FLAG``.

        Raises:
            ValueError: if ``FLAG`` is not one of the supported step names.
        """
        if FLAG == 'STEP1':
            return self.forward_STEP1(x)
        elif FLAG == 'STEP2_Phase1':
            return self.forward_STEP2_Phase1(x)
        elif FLAG == 'STEP2_Phase2':
            return self.forward_STEP2_Phase2(x)
        else:
            raise ValueError('Wrong FLAG...')

    # ---- STEP1: VGG19 classification ----
    def forward_STEP1(self, x):
        """Conv trunk -> 2x2 max-pool -> flatten -> VGG19 classifier -> FC_Top logits."""
        x = self.convBase(x)
        x = self.MaxPooling(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        x = self.FC_Top(x)
        return x

    def extract_CONV5_4_feature(self, x):
        """Return the raw output of the shared conv trunk."""
        x = self.convBase(x)
        return x

    # ---- STEP2 helpers (shared by forward_STEP2_* and get_Mask_test) ----
    def _sig_vectors(self, conv5_4):
        """Global-average-pool ``conv5_4`` and run both sig_mask branches on it.

        Returns ``(sig_1, sig_2)``, each n x 512.
        """
        feat_vect = self.avgpool(conv5_4)                  # (n, 512, 1, 1) for a 28x28 map
        feat_vect = feat_vect.view(feat_vect.size(0), -1)  # (n, 512)
        return self.sig_mask_1(feat_vect), self.sig_mask_2(feat_vect)

    def _channel_mean(self, feats):
        """Average an (n, C, HW) tensor over its channel axis -> (n, HW).

        ``squeeze(1)`` (not a bare ``squeeze()``) keeps the batch axis when n == 1.
        """
        return self.avgpool_mask(feats).squeeze(1)

    def _weighted_mask(self, conv5_4_reshaped, sig):
        """Weight each channel of (n, C, HW) features by ``sig`` (n, C), then
        average over channels -> (n, HW)."""
        # Broadcasting over the spatial axis replaces repeat((1, 1, 784)).
        return self._channel_mean(conv5_4_reshaped * sig.unsqueeze(2))

    @staticmethod
    def _max_normalize(mask):
        """Divide each row of an (n, HW) mask by that row's maximum.

        A row whose maximum is exactly 0 (e.g. an all-zero mask) is divided by
        1 instead, so it stays 0 rather than becoming NaN.
        """
        max_value, _ = torch.max(mask, dim=1, keepdim=True)
        safe_max = torch.where(max_value == 0, torch.ones_like(max_value), max_value)
        return mask / safe_max

    # ---- STEP2: MA-CNN ----
    def forward_STEP2_Phase1(self, x):
        """Return the two channel-weight vectors ``(sig_1, sig_2)``, each n x 512."""
        return self._sig_vectors(self.convBase(x))

    def forward_STEP2_Phase2(self, x):
        """Compute the two part masks and part-classifier logits.

        Returns:
            tuple of 9 tensors, in order:
            ``part_cls_1 + part_cls_2`` and ``part_cls_2`` (n x ncls);
            ``Mask_1``, ``Mask_2`` (n x HW, after tanh);
            ``Mask_1_norm``, ``Mask_2_norm`` (masks divided by their per-row max);
            ``sig_1_org``, ``sig_2_org`` (n x 512 channel weights);
            ``torch.mean(conv5_4, dim=1)`` (per-position channel mean).

        NOTE(review): the second element is ``part_cls_2``, not ``part_cls_1``;
        confirm with the training loop that this is intended.
        """
        conv5_4 = self.convBase(x)  # (n, 512, 28, 28) expected
        conv5_4_reshaped = conv5_4.view(conv5_4.shape[0], conv5_4.shape[1], -1)  # (n, 512, HW)

        # Channel weights (dropout on them is currently disabled).
        sig_1_org, sig_2_org = self._sig_vectors(conv5_4)

        # Spatial masks, (n, HW).
        Mask_1 = self.tanh(self._weighted_mask(conv5_4_reshaped, sig_1_org))
        Mask_2 = self.tanh(self._weighted_mask(conv5_4_reshaped, sig_2_org))

        Mask_1_norm = self._max_normalize(Mask_1)
        Mask_2_norm = self._max_normalize(Mask_2)

        # Part features, (n, HW). The original code used a channel *sum*, so
        # the channel average is multiplied back by 512.
        attention = self._channel_mean(conv5_4_reshaped) * 512
        part_cls_1 = self.part_FC_Top_1(attention * Mask_1_norm)
        part_cls_2 = self.part_FC_Top_2(attention * Mask_2_norm)

        return part_cls_1 + part_cls_2, part_cls_2, \
               Mask_1, Mask_2, \
               Mask_1_norm, Mask_2_norm, \
               sig_1_org, sig_2_org, \
               torch.mean(conv5_4, dim=1)

    def get_Mask_test(self, x):
        """Return the two spatial masks ``(Mask_1, Mask_2)``, each softmax-normalised
        over the HW positions (n x HW)."""
        conv5_4 = self.convBase(x)
        conv5_4_reshaped = conv5_4.view(conv5_4.shape[0], conv5_4.shape[1], -1)

        sig_1, sig_2 = self._sig_vectors(conv5_4)

        # F.softmax instead of self.softmax, which __init__ never defines.
        Mask_1 = F.softmax(self._weighted_mask(conv5_4_reshaped, sig_1), dim=1)
        Mask_2 = F.softmax(self._weighted_mask(conv5_4_reshaped, sig_2), dim=1)

        return Mask_1, Mask_2



class VGG19_Finetune_STEP1and2_2Group_New(nn.Module):
    """VGG19 model whose classifier sees conv5_4 features re-weighted by two
    learned single-channel spatial attention maps.

    ``forward(x)``:
    1. computes two maps with 1x1 convolutions;
    2. min-max normalises each to [0, 1];
    3. multiplies conv5_4 by their element-wise maximum;
    4. max-pools 4x4;
    5. classifies with the pretrained VGG19 classifier plus ``FC_Top``.

    Args:
        opt: options object; only ``opt.train_ncls`` (number of classes) is read here.
    """

    def __init__(self, opt):
        super(VGG19_Finetune_STEP1and2_2Group_New, self).__init__()

        self.margin = 0.02  # not read inside this class; callers may use it
        self.convBase = VGG19_Finetune_ConvBase()

        # ---- STEP1: VGG19 classification, opt.train_ncls ways ----
        self.MaxPooling = nn.MaxPool2d(kernel_size=2, stride=2)
        org_classifier = models.vgg19(pretrained=True).classifier
        # Pretrained VGG19 classifier without its last layer; FC_Top replaces it.
        self.classifier = nn.Sequential(
            *list(org_classifier.children())[:-1]
        )
        self.FC_Top = nn.Linear(4096,  opt.train_ncls)
        self._initialize_weights(self.FC_Top)

        # ---- STEP2: attention maps ----
        self.MaxPooling4x4 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Each 1x1 conv collapses the 512 conv5_4 channels into one attention map.
        self.conv1x1_1 = nn.Conv2d(512, 1, kernel_size=1, padding=0)
        self.conv1x1_2 = nn.Conv2d(512, 1, kernel_size=1, padding=0)

        self._initialize_weights(self.conv1x1_1)
        self._initialize_weights(self.conv1x1_2)

    def _initialize_weights(self, sub_net):
        """Initialise the Conv2d / BatchNorm2d / Linear layers inside ``sub_net``.

        - Conv2d: Xavier-uniform weight, zero bias.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: N(0, 0.01) weight, zero bias.

        Parameters that are ``None`` are skipped.
        """
        for m in sub_net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    # ---- STEP1: VGG19 classification ----
    def forward_STEP1(self, x):
        """Conv trunk -> 2x2 max-pool -> flatten -> VGG19 classifier -> FC_Top logits."""
        x = self.convBase(x)
        x = self.MaxPooling(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        x = self.FC_Top(x)
        return x

    def extract_CONV5_4_feature(self, x):
        """Return the raw output of the shared conv trunk."""
        x = self.convBase(x)
        return x

    # ---- STEP2: attention ----
    def normalize_atten_maps(self, atten_maps):
        """Min-max normalise each map over its last two (spatial) dims to [0, 1].

        A constant map (zero range, e.g. a ReLU output that is all zero) maps to
        all zeros instead of NaN.
        """
        atten_shape = atten_maps.size()
        flat = atten_maps.reshape(atten_shape[0:-2] + (-1,))
        batch_mins, _ = torch.min(flat, dim=-1, keepdim=True)
        batch_maxs, _ = torch.max(flat, dim=-1, keepdim=True)
        value_range = batch_maxs - batch_mins
        # Dividing by 1 where the range is 0 turns 0/0 = NaN into 0.
        value_range = torch.where(value_range == 0,
                                  torch.ones_like(value_range), value_range)
        atten_normed = (flat - batch_mins) / value_range
        return atten_normed.reshape(atten_shape)

    def forward(self, x):
        """Classify ``x`` using attention-weighted conv5_4 features.

        Returns:
            ``(logits, Mask_1, Mask_2, Mask_1_norm, Mask_2_norm, channel_mean)``.
            The masks are (n, 1, H, W); ``*_norm`` are min-max normalised;
            ``channel_mean`` is ``torch.mean(conv5_4, dim=1)``.
        """
        conv5_4 = self.convBase(x)  # (n, 512, 28, 28) expected
        # Raw attention maps, scaled by 1/512 before tanh.
        Mask_1 = self.relu(self.conv1x1_1(conv5_4)) / 512
        Mask_1_norm = self.normalize_atten_maps(self.tanh(Mask_1))

        # A disabled variant fed Mask_2 from an "erased" conv5_4 through
        # conv1x1_1; Mask_2 now uses its own conv1x1_2 on the full map.
        Mask_2 = self.relu(self.conv1x1_2(conv5_4)) / 512
        Mask_2_norm = self.normalize_atten_maps(self.tanh(Mask_2))

        # Element-wise max of the two maps, broadcast over all channels.
        Masked_conv5_4 = conv5_4 * torch.max(Mask_1_norm, Mask_2_norm)
        # (n, 512, 7, 7) for a 28x28 input map.
        Masked_conv5_5 = self.MaxPooling4x4(Masked_conv5_4)
        x = Masked_conv5_5.view(Masked_conv5_5.size(0), -1)
        x = self.classifier(x)
        logits = self.FC_Top(x)

        return logits, Mask_1, Mask_2, \
               Mask_1_norm, Mask_2_norm, torch.mean(conv5_4, dim=1)

class ZeroOneClipper(object):
    """Callable for ``module.apply(...)`` that clamps a module's ``weight`` into [0, 1].

    Modules with no ``weight`` attribute, or whose ``weight`` is not a tensor
    (e.g. ``None`` for a norm layer built without affine parameters), are
    left untouched instead of raising.
    """

    def __call__(self, module):
        weight = getattr(module, 'weight', None)
        if isinstance(weight, torch.Tensor):
            # In place, so the Parameter object (and any optimizer state) is kept.
            with torch.no_grad():
                weight.clamp_(min=0.0, max=1.0)


#
# class VGG19_Finetune_AwA40_STEP1and2_2Group(nn.Module):
#     def __init__(self, opt):
#         super(VGG19_Finetune_AwA40_STEP1and2_2Group, self).__init__()
#
#         self.margin = 0.02
#         self.convBase = VGG19_Finetune_ConvBase()  # datasets share the same CNN structure
#
#         """ STEP1 modules VGG19 classification 40-ways
#         """
#         self.MaxPooling = nn.MaxPool2d(kernel_size=2, stride=2)
#         org_classifier = models.vgg19(pretrained=True).classifier
#         self.classifier = nn.Sequential(
#             *list(org_classifier.children())[:-1]
#         )
#         self.FC_Top = nn.Linear(4096, 40)
#         self._initialize_weights(self.FC_Top)
#
#         """ STEP2 modules MA_CNN
#         """
#         self.avgpool = nn.AvgPool2d(28, stride=1)
#         self.avgpool_mask = nn.AvgPool2d((512, 1), stride=1)
#         self.tanh = nn.Tanh()
#         # self.softmax = nn.Softmax(dim=1)
#
#         self.part_FC_Top_1 = nn.Linear(28 * 28, 40)
#         self.part_FC_Top_2 = nn.Linear(28 * 28, 40)
#
#         self.sig_mask_1 = nn.Sequential(
#             nn.Linear(512, 512),
#             nn.Tanh(),
#             nn.Linear(512, 512),
#             nn.Sigmoid()
#         )
#         self.sig_mask_2 = nn.Sequential(
#             nn.Linear(512, 512),
#             nn.Tanh(),
#             nn.Linear(512, 512),
#             nn.Sigmoid()
#         )
#
#         self._initialize_weights(self.sig_mask_1)
#         self._initialize_weights(self.sig_mask_2)
#         self._initialize_weights(self.part_FC_Top_1)
#         self._initialize_weights(self.part_FC_Top_2)
#
#
#     def _initialize_weights(self, sub_net):
#         for m in sub_net.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.normal_(m.weight, 0, 0.01)
#                 nn.init.constant_(m.bias, 0)
#
#     """ STEP1 modules VGG19 classification 40-ways
#     """
#     def forward_STEP1(self, x):
#         x = self.convBase(x)
#         x = self.MaxPooling(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         x = self.FC_Top(x)
#         return x
#
#     def extract_CONV5_4_feature(self, x):
#         x = self.convBase(x)
#         return x
#
#     """ STEP2 modules MA_CNN
#     """
#     def forward_initialize_sig_mask(self, x):
#         x = self.convBase(x)
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         sig_1 = self.sig_mask_1(x)
#         sig_2 = self.sig_mask_2(x)
#         return sig_1, sig_2
#
#     def forward(self, x):
#         batch_size = x.shape[0]
#         # n * 512 * 28 * 28
#         conv5_4 = self.convBase(x)
#         # n * 512 * 784
#         conv5_4_reshaped = conv5_4.view(conv5_4.shape[0], conv5_4.shape[1], -1)
#
#         """ get sig mask # n * 512
#         """
#         feat_vect = self.avgpool(conv5_4)  # 1*1*512
#         feat_vect = feat_vect.view(feat_vect.size(0), -1)  # 512
#         sig_1_org = self.sig_mask_1(feat_vect)
#         sig_2_org = self.sig_mask_2(feat_vect)
#
#         # sig_1 = F.dropout(sig_1_org, p=0.5, training=self.training)
#         # sig_2 = F.dropout(sig_2_org, p=0.5, training=self.training)
#         sig_1 = sig_1_org
#         sig_2 = sig_2_org
#
#         """ get MASK  # n * 28 * 28
#         """
#         Mask_1 = conv5_4_reshaped * sig_1.unsqueeze(2).repeat((1, 1, 784))
#         Mask_1 = self.avgpool_mask(Mask_1).squeeze()
#         Mask_1 = Mask_1
#         Mask_1 = self.tanh(Mask_1)  #.view(Mask_1.shape[0], 28, 28)
#
#
#         Mask_2 = conv5_4_reshaped * sig_2.unsqueeze(2).repeat((1, 1, 784))
#         Mask_2 = Mask_2
#         Mask_2 = self.avgpool_mask(Mask_2).squeeze()
#         Mask_2 = self.tanh(Mask_2)  #.view(Mask_2.shape[0], 28, 28)
#
#         # normalized Mask_1, Mask_2
#         max_value, _ = torch.max(Mask_1, dim=1, keepdim=True)
#         max_value_matrix = max_value.repeat(1, Mask_1.shape[1])
#         Mask_1_norm = Mask_1 / max_value_matrix
#         max_value, _ = torch.max(Mask_2, dim=1, keepdim=True)
#         max_value_matrix = max_value.repeat(1, Mask_2.shape[1])
#         Mask_2_norm = Mask_2 / max_value_matrix
#
#         """ get part_feat  # n * 784, cls_top
#         """
#         # the original code used a channel sum rather than an average, hence * 512
#         attention = self.avgpool_mask(conv5_4_reshaped).squeeze() * 512
#         part_feat_1 = attention * Mask_1_norm
#         part_feat_2 = attention * Mask_2_norm
#
#         part_cls_1 = self.part_FC_Top_1(part_feat_1)
#         part_cls_2 = self.part_FC_Top_2(part_feat_2)
#
#         return part_cls_1+part_cls_2, part_cls_2, \
#                Mask_1, Mask_2,\
#                Mask_1_norm, Mask_2_norm,\
#                sig_1_org, sig_2_org, \
#                torch.mean(conv5_4, dim=1)
#
#     def get_Mask_test(self, x):
#         # n * 512 * 28 * 28
#         # n * 512 * 28 * 28
#         conv5_4 = self.convBase(x)
#         # n * 512 * 784
#         conv5_4_reshaped = conv5_4.view(conv5_4.shape[0], conv5_4.shape[1], -1)
#
#         """ get sig mask # n * 512
#         """
#         feat_vect = self.avgpool(conv5_4)  # 1*1*512
#         feat_vect = feat_vect.view(feat_vect.size(0), -1)  # 512
#         sig_1 = self.sig_mask_1(feat_vect)
#         sig_2 = self.sig_mask_2(feat_vect)
#
#         """ get MASK  # n * 28 * 28
#         """
#         Mask_1 = conv5_4_reshaped * sig_1.unsqueeze(2).repeat((1, 1, 784))
#         Mask_1 = self.avgpool_mask(Mask_1).squeeze()
#         Mask_1 = self.softmax(Mask_1)  #.view(Mask_1.shape[0], 28, 28)
#
#         Mask_2 = conv5_4_reshaped * sig_2.unsqueeze(2).repeat((1, 1, 784))
#         Mask_2 = self.avgpool_mask(Mask_2).squeeze()
#         Mask_2 = self.softmax(Mask_2)  #.view(Mask_2.shape[0], 28, 28)
#
#         return Mask_1, Mask_2


# class VGG19_Finetune_FLO82_STEP1and2_2Group(nn.Module):
#     def __init__(self, opt):
#         super(VGG19_Finetune_FLO82_STEP1and2_2Group, self).__init__()
#
#         self.margin = 0.02
#         self.convBase = VGG19_Finetune_ConvBase()
#
#         """ STEP1 modules VGG19 classification 82-ways
#         """
#         self.MaxPooling = nn.MaxPool2d(kernel_size=2, stride=2)
#         org_classifier = models.vgg19(pretrained=True).classifier
#         self.classifier = nn.Sequential(
#             *list(org_classifier.children())[:-1]
#         )
#         self.FC_Top = nn.Linear(4096, 82)
#         self._initialize_weights(self.FC_Top)
#
#         """ STEP2 modules MA_CNN
#         """
#         self.avgpool = nn.AvgPool2d(28, stride=1)
#         self.avgpool_mask = nn.AvgPool2d((512, 1), stride=1)
#         self.tanh = nn.Tanh()
#         # self.softmax = nn.Softmax(dim=1)
#
#         self.part_FC_Top_1 = nn.Linear(28 * 28, 82)
#         self.part_FC_Top_2 = nn.Linear(28 * 28, 82)
#
#         self.sig_mask_1 = nn.Sequential(
#             nn.Linear(512, 512),
#             nn.Tanh(),
#             nn.Linear(512, 512),
#             nn.Sigmoid()
#         )
#         self.sig_mask_2 = nn.Sequential(
#             nn.Linear(512, 512),
#             nn.Tanh(),
#             nn.Linear(512, 512),
#             nn.Sigmoid()
#         )
#
#         self._initialize_weights(self.sig_mask_1)
#         self._initialize_weights(self.sig_mask_2)
#         self._initialize_weights(self.part_FC_Top_1)
#         self._initialize_weights(self.part_FC_Top_2)
#
#
#     def _initialize_weights(self, sub_net):
#         for m in sub_net.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.normal_(m.weight, 0, 0.01)
#                 nn.init.constant_(m.bias, 0)
#
#     """ STEP1 modules VGG19 classification 82-ways
#     """
#     def forward_STEP1(self, x):
#         x = self.convBase(x)
#         x = self.MaxPooling(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         x = self.FC_Top(x)
#         return x
#
#     def extract_CONV5_4_feature(self, x):
#         x = self.convBase(x)
#         return x
#
#     """ STEP2 modules MA_CNN
#     """
#     def forward_initialize_sig_mask(self, x):
#         x = self.convBase(x)
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         sig_1 = self.sig_mask_1(x)
#         sig_2 = self.sig_mask_2(x)
#         return sig_1, sig_2
#
#     def forward(self, x):
#         batch_size = x.shape[0]
#         # n * 512 * 28 * 28
#         conv5_4 = self.convBase(x)
#         # n * 512 * 784
#         conv5_4_reshaped = conv5_4.view(conv5_4.shape[0], conv5_4.shape[1], -1)
#
#         """ get sig mask # n * 512
#         """
#         feat_vect = self.avgpool(conv5_4)  # 1*1*512
#         feat_vect = feat_vect.view(feat_vect.size(0), -1)  # 512
#         sig_1_org = self.sig_mask_1(feat_vect)
#         sig_2_org = self.sig_mask_2(feat_vect)
#
#         # sig_1 = F.dropout(sig_1_org, p=0.5, training=self.training)
#         # sig_2 = F.dropout(sig_2_org, p=0.5, training=self.training)
#         sig_1 = sig_1_org
#         sig_2 = sig_2_org
#
#         """ get MASK  # n * 28 * 28
#         """
#         Mask_1 = conv5_4_reshaped * sig_1.unsqueeze(2).repeat((1, 1, 784))
#         Mask_1 = self.avgpool_mask(Mask_1).squeeze()
#         Mask_1 = Mask_1
#         Mask_1 = self.tanh(Mask_1)  #.view(Mask_1.shape[0], 28, 28)
#
#
#         Mask_2 = conv5_4_reshaped * sig_2.unsqueeze(2).repeat((1, 1, 784))
#         Mask_2 = Mask_2
#         Mask_2 = self.avgpool_mask(Mask_2).squeeze()
#         Mask_2 = self.tanh(Mask_2)  #.view(Mask_2.shape[0], 28, 28)
#
#         # normalized Mask_1, Mask_2
#         max_value, _ = torch.max(Mask_1, dim=1, keepdim=True)
#         max_value_matrix = max_value.repeat(1, Mask_1.shape[1])
#         Mask_1_norm = Mask_1 / max_value_matrix
#         max_value, _ = torch.max(Mask_2, dim=1, keepdim=True)
#         max_value_matrix = max_value.repeat(1, Mask_2.shape[1])
#         Mask_2_norm = Mask_2 / max_value_matrix
#
#         """ get part_feat  # n * 784, cls_top
#         """
#         # the original code used a channel sum rather than an average, hence * 512
#         attention = self.avgpool_mask(conv5_4_reshaped).squeeze() * 512
#         part_feat_1 = attention * Mask_1_norm
#         part_feat_2 = attention * Mask_2_norm
#
#         part_cls_1 = self.part_FC_Top_1(part_feat_1)
#         part_cls_2 = self.part_FC_Top_2(part_feat_2)
#
#         return part_cls_1+part_cls_2, part_cls_2, \
#                Mask_1, Mask_2,\
#                Mask_1_norm, Mask_2_norm,\
#                sig_1_org, sig_2_org, \
#                torch.mean(conv5_4, dim=1)
#
#     def get_Mask_test(self, x):
#         # n * 512 * 28 * 28
#         # n * 512 * 28 * 28
#         conv5_4 = self.convBase(x)
#         # n * 512 * 784
#         conv5_4_reshaped = conv5_4.view(conv5_4.shape[0], conv5_4.shape[1], -1)
#
#         """ get sig mask # n * 512
#         """
#         feat_vect = self.avgpool(conv5_4)  # 1*1*512
#         feat_vect = feat_vect.view(feat_vect.size(0), -1)  # 512
#         sig_1 = self.sig_mask_1(feat_vect)
#         sig_2 = self.sig_mask_2(feat_vect)
#
#         """ get MASK  # n * 28 * 28
#         """
#         Mask_1 = conv5_4_reshaped * sig_1.unsqueeze(2).repeat((1, 1, 784))
#         Mask_1 = self.avgpool_mask(Mask_1).squeeze()
#         Mask_1 = self.softmax(Mask_1)  #.view(Mask_1.shape[0], 28, 28)
#
#         Mask_2 = conv5_4_reshaped * sig_2.unsqueeze(2).repeat((1, 1, 784))
#         Mask_2 = self.avgpool_mask(Mask_2).squeeze()
#         Mask_2 = self.softmax(Mask_2)  #.view(Mask_2.shape[0], 28, 28)
#
#         return Mask_1, Mask_2

# def log_print(text, color=None, on_color=None, attrs=None):
#     if cprint is not None:
#         cprint(text, color=color, on_color=on_color, attrs=attrs)
#     else:
#         print(text)


if __name__ == "__main__":
    # Smoke check: build the shared conv trunk and print its module summary.
    conv_base = VGG19_Finetune_ConvBase()
    print(conv_base)
