#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jul  9 14:41:15 2020

@author: zw
"""


import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.modules.batchnorm import _BatchNorm

import numpy as np
import math

#from .resnet_model import BasicBlock
from .resnet import *
from .resnet import BasicBlock
from torch.utils import model_zoo
import torchvision
import os


############################################################################### 

class Encoding_rawres50(nn.Module):
    """Torchvision ResNet-50 encoder returning the four stage feature maps.

    Args:
        trinum: number of ``BasicBlock(2048, 2048)`` modules stacked after
            ``layer4`` to refine the deepest features.
        pretrain: forwarded to ``torchvision.models.resnet50(pretrained=...)``.

    ``forward`` returns ``[x1, x2, x3, x4]``, the outputs of ``layer1`` to
    ``layer4``. Only ``x4`` additionally passes through the ``resb`` stack.
    """

    def __init__(self, trinum=1, pretrain=True):
        super(Encoding_rawres50, self).__init__()
        self.trinum = trinum
        self.resnet = models.resnet50(pretrained=pretrain)
        self.resb = nn.Sequential(*[BasicBlock(2048, 2048) for _ in range(trinum)])

    def forward(self, x):
        # ResNet stem.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x1 = self.resnet.layer1(x)
        x2 = self.resnet.layer2(x1)
        x3 = self.resnet.layer3(x2)
        x4 = self.resnet.layer4(x3)

        # Apply the refinement stack exactly once, as `Encoding` does. The
        # stack used to be called twice in a row, which silently doubled its
        # depth with shared weights.
        # NOTE(review): checkpoints trained with the double application will
        # produce different outputs with this forward.
        x4 = self.resb(x4)

        return [x1, x2, x3, x4]

class Encoding(nn.Module):
    """Encoder built on the project-local ``resnet50`` backbone.

    Args:
        trinum: number of ``BasicBlock(2048, 2048)`` modules applied to the
            deepest stage output.

    ``forward`` returns the four stage outputs ``[x1, x2, x3, x4]``. Only the
    last one goes through the ``resb`` refinement stack.
    """

    def __init__(self, trinum=1):
        super(Encoding, self).__init__()
        self.trinum = trinum
        self.resnet = resnet50()
        self.resb = nn.Sequential(*[BasicBlock(2048, 2048) for _ in range(trinum)])

    def forward(self, x):
        net = self.resnet

        # Stem: conv1 always runs; conv2/conv3 only for a deep-base backbone.
        feat = net.relu1(net.bn1(net.conv1(x)))
        if net.deep_base:
            feat = net.relu2(net.bn2(net.conv2(feat)))
            feat = net.relu3(net.bn3(net.conv3(feat)))
        feat = net.maxpool(feat)

        # Expected channels per stage: 256 / 512 / 1024 / 2048.
        stages = []
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            feat = stage(feat)
            stages.append(feat)

        stages[-1] = self.resb(stages[-1])
        return stages

class UpConvBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, concat a skip feature, conv-BN-ReLU.

    Args:
        inp: channel count of the concatenation ``[upsample(x), x_skip]``.
        out: channel count produced by the 3x3 convolution.
    """

    def __init__(self, inp, out):
        super(UpConvBlock, self).__init__()
        self.Up = nn.Conv2d(inp, out, kernel_size=3, padding=1)
        self.Up_bn = nn.BatchNorm2d(out)
        self.Up_relu = nn.ReLU(inplace=True)
        self.upscore = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, x_skip):
        # x_skip must have twice the spatial size of x for the concat to work.
        merged = torch.cat([self.upscore(x), x_skip], dim=1)
        return self.Up_relu(self.Up_bn(self.Up(merged)))

class Decoding(nn.Module):
    """Top-down decoder that walks back up the encoder pyramid.

    Args:
        ups: number of ``UpConvBlock`` stages. With the fixed 4-entry channel
            table this must be at most 3.

    ``forward`` takes the encoder list ``[x1, x2, x3, x4]`` (shallow to deep).
    It returns ``[x4, d_1, ..., d_ups]``: the deepest input itself, followed by
    one output per stage.
    """

    def __init__(self, ups=3):
        super(Decoding, self).__init__()
        self.ups = ups
        widths = [2048, 1024, 512, 256]
        # Stage k merges the running feature (widths[k]) with a skip
        # (widths[k + 1]) and projects to widths[k + 1] channels.
        self.layer = nn.Sequential(*[
            UpConvBlock(widths[k] + widths[k + 1], widths[k + 1])
            for k in range(ups)
        ])

    def forward(self, Enc_list):
        feat = Enc_list[-1]
        outputs = [feat]
        # Skip features from second-deepest to shallowest (e.g. x3, x2, x1).
        skips = list(Enc_list)[-2::-1]
        for k in range(self.ups):
            feat = self.layer[k](feat, skips[k])
            outputs.append(feat)
        return outputs

class OutConvBlock(nn.Module):
    """Prediction head: 3x3 conv to ``outch`` channels, then bilinear upsampling.

    Args:
        inp: input channel count.
        upscale: bilinear upsampling factor applied after the convolution.
        outch: number of predicted channels.

    No activation is applied here; callers apply one (e.g. ``torch.sigmoid``).
    """

    def __init__(self, inp, upscale, outch):
        super(OutConvBlock, self).__init__()
        self.outch = outch
        self.outconv = nn.Conv2d(inp, outch, kernel_size=3, padding=1)
        self.upscore = nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True)

    def forward(self, x):
        logits = self.outconv(x)
        return self.upscore(logits)

class Out(nn.Module):
    """Side-output heads for a decoder pyramid.

    Args:
        outnum: number of heads. Head ``k`` consumes ``inp_channel[k]``
            channels and upsamples by ``scale[k]``.
        outch: channels predicted by every head.

    ``forward`` applies head ``k`` to ``out_list[k]`` for ``k < outnum``. It
    returns the predictions in reverse order (last head first).
    """

    def __init__(self, outnum, outch=1):
        super(Out, self).__init__()
        self.outnum = outnum
        scale = [32, 16, 8, 4]
        inp_channel = [2048, 1024, 512, 256]
        self.layer = nn.Sequential(*[
            OutConvBlock(inp_channel[k], scale[k], outch) for k in range(outnum)
        ])

    def forward(self, out_list):
        preds = [self.layer[k](out_list[k]) for k in range(self.outnum)]
        return preds[::-1]

class FG_fuse(nn.Module):
    """Figure-ground fusion via channel affinities with two auxiliary cues.

    ``x``, ``conx`` and ``order`` are each projected to ``c = inp // 16``
    channels. For each cue, a channel-affinity matrix ``A = x @ cue^T``
    (``[N, c, c]``, summed over all pixels) is built. It re-mixes the projected
    ``x`` channels at every pixel as ``A^T @ x``. The projected ``x`` and both
    re-mixed maps are concatenated and mapped back to ``inp`` channels.

    Args:
        inp: channel count of all three inputs and of the output. The three
            inputs must share the same spatial size.
    """

    def __init__(self, inp):
        super(FG_fuse, self).__init__()
        self.x_tri = nn.Conv2d(inp, inp//16, 3, padding=1)
        self.conx_tri = nn.Conv2d(inp, inp//16, 3, padding=1)
        self.order_tri = nn.Conv2d(inp, inp//16, 3, padding=1)
        self.outconv = nn.Conv2d((inp//16)*3, inp, 3, padding=1)

    def forward(self, x, conx, order):
        x = self.x_tri(x)
        conx = self.conx_tri(conx)
        order = self.order_tri(order)

        N, c, h, w = x.size()
        x_flat = x.reshape(N, c, h * w)                        # [N, c, hw]
        conx_t = conx.reshape(N, c, h * w).transpose(1, 2)     # [N, hw, c]
        order_t = order.reshape(N, c, h * w).transpose(1, 2)   # [N, hw, c]

        aff_conx = torch.matmul(x_flat, conx_t)     # [N, c, c]
        aff_order = torch.matmul(x_flat, order_t)   # [N, c, c]

        # A^T @ x keeps the channel-major [N, c, hw] layout, so reshaping to
        # [N, c, h, w] is valid. The previous code computed x^T @ A
        # ([N, hw, c]) and viewed that buffer as [N, c, h, w], which scrambled
        # channels and pixels whenever c > 1.
        cross_conx = torch.matmul(aff_conx.transpose(1, 2), x_flat).reshape(N, c, h, w)
        cross_order = torch.matmul(aff_order.transpose(1, 2), x_flat).reshape(N, c, h, w)

        fuse = torch.cat((x, cross_conx, cross_order), 1)
        return self.outconv(fuse)

class FG_Bridge(nn.Module):
    """Gates pyramid levels with ``FG_fuse`` driven by two cue pyramids.

    Args:
        outnum: number of pyramid levels. Level ``k`` has ``inp_channel[k]``
            channels (at most 4 levels).
        fusenum: levels with index ``k <= fusenum`` are gated as
            ``f * (1 + FG_fuse(f, conx, order))``. Later levels pass through
            unchanged.
            NOTE(review): ``<=`` means ``fusenum + 1`` levels are fused. With
            the default (4), every level is fused.
    """

    def __init__(self, outnum, fusenum=4):
        super(FG_Bridge, self).__init__()
        self.outnum = outnum
        self.fusenum = fusenum
        inp_channel = [2048, 1024, 512, 256]
        self.layer = nn.Sequential(*[FG_fuse(inp_channel[k]) for k in range(outnum)])

    def forward(self, out_list, conx_list, order_list):
        gated = []
        for k in range(self.outnum):
            feat = out_list[k]
            if k > self.fusenum:
                gated.append(feat)
                continue
            gate = self.layer[k](feat, conx_list[k], order_list[k])
            gated.append(feat * (1 + gate))
        return gated

class BaseFG(nn.Module):
    """Encoder with order/convex auxiliary decoders and an FG-fused main decoder.

    Args:
        ups: number of decoder up-sampling stages. The side heads index
            ``[3]`` and expect 256 channels, so this effectively must be 3.
        pretrain: forwarded to ``Encoding_rawres50`` to choose whether the
            torchvision ResNet-50 loads pretrained weights.

    ``forward`` returns ``[OMx, OOx, OCx]``:
        * ``OMx``: ``ups + 1`` sigmoid maps from the main branch, shallowest
          decoder level first.
        * ``OOx``: one-element list with the 2-channel sigmoid "order" map.
        * ``OCx``: one-element list with the 2-channel sigmoid "convex" map.
    """

    def __init__(self, ups=3, pretrain=True):
        super(BaseFG, self).__init__()
        # Honour the caller's flag; it used to be hard-coded to True.
        self.encoding = Encoding_rawres50(pretrain=pretrain)

        self.decoding_main = Decoding(ups=ups)
        self.out_main = Out(outnum=ups+1)
        # NOTE(review): final_fuse is not used by forward(); presumably kept so
        # existing checkpoints still load.
        self.final_fuse = nn.Conv2d(4, 1, 1, padding=0)

        self.decoding_order = Decoding(ups=ups)
        self.decoding_con = Decoding(ups=ups)

        self.fg_bridge = FG_Bridge(outnum=ups+1)

        # Order head.
        self.outorder = nn.Conv2d(256, 32, 3, padding=1)
        self.outorder_bn = nn.BatchNorm2d(32)
        self.outorder_relu = nn.ReLU(inplace=True)
        self.outorderlast = nn.Conv2d(32, 2, 1, padding=0)

        # Convex head.
        self.outconx = nn.Conv2d(256, 32, 3, padding=1)
        self.outconx_bn = nn.BatchNorm2d(32)
        self.outconx_relu = nn.ReLU(inplace=True)
        self.outconxlast = nn.Conv2d(32, 2, 1, padding=0)

        self.outup = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

    def _side_head(self, feat, conv, bn, relu, last):
        """Conv-BN-ReLU, 1x1 projection, then 4x bilinear upsampling."""
        return self.outup(last(relu(bn(conv(feat)))))

    def forward(self, x):
        '''
        Encoder feature sizes, per the original annotation (presumably for a
        352x352 input):
        '1': [256 , 88, 88]
        '2': [512 , 44, 44]
        '3': [1024, 22, 22]
        '4': [2048, 11, 11]
        '''
        # Encoder, shallow to deep: [x1, x2, x3, x4].
        Ex = self.encoding(x)

        # Auxiliary decoders, deep to shallow: [x4, d3, d2, d1].
        DOx = self.decoding_order(Ex)
        DCx = self.decoding_con(Ex)

        # FG_Bridge works deep-to-shallow to line up with the decoder lists.
        Ex = list(reversed(self.fg_bridge(list(reversed(Ex)), DCx, DOx)))

        # Main branch on the fused encoder features.
        DMx = self.decoding_main(Ex)
        OMx = [torch.sigmoid(temp) for temp in self.out_main(DMx)]

        # Side outputs from the shallowest auxiliary decoder level.
        OOx = [torch.sigmoid(self._side_head(
            DOx[3], self.outorder, self.outorder_bn, self.outorder_relu, self.outorderlast))]
        OCx = [torch.sigmoid(self._side_head(
            DCx[3], self.outconx, self.outconx_bn, self.outconx_relu, self.outconxlast))]

        return [OMx, OOx, OCx]
