#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jul  9 14:41:15 2020

@author: zw
"""

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.modules.batchnorm import _BatchNorm

import numpy as np
import math

from .resnet import *
from torch.utils import model_zoo
import torchvision
import os

############################################################################### 

class Encoding_rawres50(nn.Module):
    """ResNet-50 encoder that exposes the outputs of its four residual stages.

    Args:
        pretrain: forwarded to ``torchvision.models.resnet50(pretrained=...)``.

    NOTE(review): the wrapped torchvision model still owns ``avgpool`` and ``fc``;
    they are never used by ``forward`` but their parameters remain in the module.
    """

    def __init__(self, pretrain=True):
        super().__init__()
        self.resnet = models.resnet50(pretrained=pretrain)

    def forward(self, x):
        """Run the stem and ``layer1``..``layer4``; return ``[x1, x2, x3, x4]``."""
        net = self.resnet

        # Stem: conv -> bn -> relu -> maxpool (no feature is kept from here).
        for op in (net.conv1, net.bn1, net.relu, net.maxpool):
            x = op(x)

        # Each residual stage feeds the next; keep every stage's output.
        stage_outputs = []
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
            stage_outputs.append(x)
        return stage_outputs

class UpConvBlock(nn.Module):
    """Decoder step: upsample ``x``, concatenate a skip feature, then conv-BN-ReLU.

    Args:
        inp: channels of ``x`` plus channels of ``x_skip`` (input of the 3x3 conv).
        out: output channels.
    """

    def __init__(self, inp, out):
        super(UpConvBlock, self).__init__()

        self.Up = nn.Conv2d(inp, out, 3, padding=1)
        self.Up_bn = nn.BatchNorm2d(out)
        self.Up_relu = nn.ReLU(inplace=True)

        self.upscore = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, x_skip):
        """Fuse ``x`` (upsampled to ``x_skip``'s spatial size) with ``x_skip``.

        Returns a tensor with ``out`` channels and ``x_skip``'s height/width.
        """
        x_up = self.upscore(x)
        if x_up.shape[2:] != x_skip.shape[2:]:
            # A fixed 2x upsample only lines up when the skip is exactly twice as
            # large (11 -> 22). For odd skip sizes (12 -> 24 vs 23) torch.cat would
            # fail, so resample directly to the skip's size instead. With
            # align_corners=True this is the same mapping the 2x path uses.
            x_up = F.interpolate(x, size=tuple(x_skip.shape[2:]),
                                 mode='bilinear', align_corners=True)
        new_x = torch.cat((x_up, x_skip), 1)
        return self.Up_relu(self.Up_bn(self.Up(new_x)))

class Decoding(nn.Module):
    """Top-down decoder that walks the encoder pyramid from deepest to shallowest.

    Args:
        ups: number of ``UpConvBlock`` stages. The width table below has four
            entries, so at most 3 stages can be built.
    """

    def __init__(self, ups=3):
        super().__init__()
        self.ups = ups

        widths = [2048, 1024, 512, 256]
        # Stage k fuses the running feature (widths[k] ch) with a skip
        # (widths[k + 1] ch) and emits widths[k + 1] channels.
        self.layer = nn.Sequential(*[
            UpConvBlock(widths[k] + widths[k + 1], widths[k + 1])
            for k in range(ups)
        ])

    def forward(self, Enc_list):
        """Decode ``Enc_list`` (expected ordered shallow -> deep).

        Returns ``[deepest feature, stage-1 output, ..., stage-``ups`` output]``.
        """
        skips = list(reversed(Enc_list))[1:]  # every level except the deepest, deep -> shallow
        current = Enc_list[-1]
        decoded = [current]
        for stage in range(self.ups):
            current = self.layer[stage](current, skips[stage])
            decoded.append(current)
        return decoded

class OutConvBlock(nn.Module):
    """Prediction head: 3x3 conv to ``outch`` channels, then bilinear upsample by ``upscale``."""

    def __init__(self, inp, upscale, outch):
        super().__init__()
        self.outch = outch
        self.outconv = nn.Conv2d(inp, outch, 3, padding=1)
        self.upscore = nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the upsampled ``outch``-channel prediction for ``x``."""
        return self.upscore(self.outconv(x))

class Out(nn.Module):
    """Per-level prediction heads: one ``OutConvBlock`` per decoder output.

    Head ``i`` expects ``in_widths[i]`` input channels and upsamples by
    ``upscales[i]``. ``forward`` returns the predictions in reverse input order.
    """

    def __init__(self, outnum, outch=1):
        super().__init__()
        self.outnum = outnum
        upscales = [32, 16, 8, 4]
        in_widths = [2048, 1024, 512, 256]
        self.layer = nn.Sequential(*[
            OutConvBlock(in_widths[i], upscales[i], outch) for i in range(outnum)
        ])

    def forward(self, out_list):
        """Apply head ``i`` to ``out_list[i]`` and return the results reversed."""
        preds = [self.layer[i](out_list[i]) for i in range(self.outnum)]
        return preds[::-1]

class Unet(nn.Module):
    """U-Net style network on a ResNet-50 encoder with deep-supervision side outputs.

    Args:
        ups: number of decoder stages (``Decoding`` supports 0-3). The main
            branch yields ``ups + 1`` single-channel side outputs.
        pretrain: forwarded to ``Encoding_rawres50``.
        is_edge: if True, add a second decoder and a 2-channel edge head.
    """

    def __init__(self, ups=3, pretrain=True, is_edge=False):
        super(Unet, self).__init__()
        # Honor the caller's flag (it used to be hard-coded to True).
        self.encoding = Encoding_rawres50(pretrain=pretrain)

        self.decoding_main = Decoding(ups=ups)
        self.out_main = Out(outnum=ups+1)
        # NOTE(review): final_fuse is not used by forward(); it is kept so that
        # existing checkpoints' state_dict keys still match.
        self.final_fuse = nn.Conv2d(4, 1, 1, padding=0)

        # Channel width / total stride of the last decoder output for this
        # ``ups``; for the default ups=3 that is 256 channels at stride 4.
        last_channels = (2048, 1024, 512, 256)[ups]
        last_stride = (32, 16, 8, 4)[ups]

        self.outup = nn.Upsample(scale_factor=last_stride, mode='bilinear', align_corners=True)
        self.is_edge = is_edge
        if self.is_edge:

            self.decoding_edge = Decoding(ups=ups)

            # edge head: conv-BN-ReLU, then 1x1 conv to 2 channels
            self.outedge = nn.Conv2d(last_channels, 32, 3, padding=1)
            self.outedge_bn = nn.BatchNorm2d(32)
            self.outedge_relu = nn.ReLU(inplace=True)
            self.outedgelast = nn.Conv2d(32, 2, 1, padding=0)

    def forward(self, x):
        """
        Encoder feature shapes (C, H, W), e.g. for a 352x352 input:
        '1': [256 , 88, 88]
        '2': [512 , 44, 44]
        '3': [1024, 22, 22]
        '4': [2048, 11, 11]

        Returns a 3-element list: ``[OMx, OOx, OMx]`` when ``is_edge`` is set,
        otherwise ``[OMx, OMx, OMx]``. ``OMx`` is the list of sigmoid main-branch
        side outputs, in the reversed order produced by ``Out``. ``OOx`` is a
        one-element list holding the sigmoid 2-channel edge map.
        """
        # Encoding
        Ex = self.encoding(x)               # [x1, x2, x3, x4]

        # Main branch
        DMx = self.decoding_main(Ex)        # [d4, d3, d2, d1]
        OMx = [torch.sigmoid(temp) for temp in self.out_main(DMx)]

        if not self.is_edge:
            return [OMx, OMx, OMx]

        # Edge branch: only the last (smallest-stride) decoder output is used.
        DOx = self.decoding_edge(Ex)[-1]
        DOx = self.outedge_relu(self.outedge_bn(self.outedge(DOx)))
        OOx = [torch.sigmoid(self.outup(self.outedgelast(DOx)))]
        return [OMx, OOx, OMx]

class ASPP_Bottleneck(nn.Module):
    """Atrous Spatial Pyramid Pooling head for a Bottleneck-ResNet feature map.

    Five parallel branches each produce 256 channels:
    - a 1x1 conv;
    - three 3x3 convs with dilation 6, 12 and 18;
    - global average pooling followed by a 1x1 conv.

    The branches are concatenated (1280 channels), fused by a 1x1 conv and
    projected to ``num_classes`` channels at the input's spatial size.

    Args:
        num_classes: number of output channels.
        in_channels: channels of the incoming feature map. The default is
            4*512 = 2048, the width produced by the dilated Bottleneck layer5.

    NOTE: the pooled branch runs BatchNorm on a (N, 256, 1, 1) tensor, so in
    training mode the batch size must be greater than 1.
    """

    def __init__(self, num_classes, in_channels=4*512):
        super(ASPP_Bottleneck, self).__init__()

        self.conv_1x1_1 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.bn_conv_1x1_1 = nn.BatchNorm2d(256)

        self.conv_3x3_1 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=6, dilation=6)
        self.bn_conv_3x3_1 = nn.BatchNorm2d(256)

        self.conv_3x3_2 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=12, dilation=12)
        self.bn_conv_3x3_2 = nn.BatchNorm2d(256)

        self.conv_3x3_3 = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=18, dilation=18)
        self.bn_conv_3x3_3 = nn.BatchNorm2d(256)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.conv_1x1_2 = nn.Conv2d(in_channels, 256, kernel_size=1)
        self.bn_conv_1x1_2 = nn.BatchNorm2d(256)

        self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256)
        self.bn_conv_1x1_3 = nn.BatchNorm2d(256)

        self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, feature_map):
        """Map (batch, in_channels, H, W) features to (batch, num_classes, H, W) logits."""
        feature_map_h = feature_map.size()[2]
        feature_map_w = feature_map.size()[3]

        # Padding equals dilation in every 3x3 branch, so all branches keep (H, W).
        out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map)))    # (batch, 256, H, W)
        out_3x3_1 = F.relu(self.bn_conv_3x3_1(self.conv_3x3_1(feature_map)))  # (batch, 256, H, W)
        out_3x3_2 = F.relu(self.bn_conv_3x3_2(self.conv_3x3_2(feature_map)))  # (batch, 256, H, W)
        out_3x3_3 = F.relu(self.bn_conv_3x3_3(self.conv_3x3_3(feature_map)))  # (batch, 256, H, W)

        out_img = self.avg_pool(feature_map)                                   # (batch, in_channels, 1, 1)
        out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img)))         # (batch, 256, 1, 1)
        # F.upsample is deprecated; align_corners=False reproduces its default.
        out_img = F.interpolate(out_img, size=(feature_map_h, feature_map_w),
                                mode="bilinear", align_corners=False)          # (batch, 256, H, W)

        out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], 1)  # (batch, 1280, H, W)
        out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out)))                    # (batch, 256, H, W)
        out = self.conv_1x1_4(out)                                                # (batch, num_classes, H, W)

        return out

class DeepLabV3(nn.Module):
    """DeepLabV3 with a ResNet-101 (output stride 16) backbone and one sigmoid output channel.

    Args:
        is_edge: if True, a second ASPP head produces an additional edge map.
    """

    def __init__(self, is_edge=False):
        super(DeepLabV3, self).__init__()

        self.num_classes = 1
        self.is_edge = is_edge
        self.resnet = ResNet_Bottleneck_OS16(num_layers=101)
        # The Bottleneck backbone ends in 4*512 channels, so the Bottleneck ASPP
        # variant is required for both heads.
        self.aspp = ASPP_Bottleneck(num_classes=self.num_classes)
        if self.is_edge:
            self.aspp_edge = ASPP_Bottleneck(num_classes=self.num_classes)

    @staticmethod
    def _predict(head, feature_map, size):
        """Apply an ASPP ``head``, resize its logits to ``size`` and apply sigmoid."""
        logits = head(feature_map)
        # F.upsample is deprecated; align_corners=False reproduces its default.
        logits = F.interpolate(logits, size=size, mode="bilinear", align_corners=False)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Segment ``x`` of shape (batch, 3, h, w).

        Returns ``[mask]``, or ``[mask, edge]`` when ``is_edge`` is set. Each entry
        is a (batch, 1, h, w) sigmoid map.
        """
        size = (x.size()[2], x.size()[3])

        feature_map = self.resnet(x)  # (batch, 4*512, h/16, w/16)

        outputs = [self._predict(self.aspp, feature_map, size)]
        if self.is_edge:
            outputs.append(self._predict(self.aspp_edge, feature_map, size))
        return outputs

def make_layer(block, in_channels, channels, num_blocks, stride=1, dilation=1):
    """Stack ``num_blocks`` residual blocks into an ``nn.Sequential``.

    Only the first block uses ``stride``; the rest use stride 1. Every block
    after the first takes ``block.expansion * channels`` input channels. At
    least one block is always created, even if ``num_blocks`` < 1.
    """
    head = block(in_channels=in_channels, channels=channels, stride=stride, dilation=dilation)
    blocks = [head]
    widened = block.expansion * channels
    for _ in range(num_blocks - 1):
        blocks.append(block(in_channels=widened, channels=channels, stride=1, dilation=dilation))
    return nn.Sequential(*blocks)

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs (expansion 1).

    Args:
        in_channels: input channels.
        channels: channels of both convs and of the output.
        stride: stride of the first conv and of the projection shortcut.
        dilation: dilation (and matching padding) of both 3x3 convs.
    """
    expansion = 1

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super().__init__()

        out_channels = channels * self.expansion
        conv3x3 = dict(kernel_size=3, padding=dilation, dilation=dilation, bias=False)

        self.conv1 = nn.Conv2d(in_channels, channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(channels)

        self.conv2 = nn.Conv2d(channels, channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if not needs_projection:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + downsample(x))."""
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + self.downsample(x))

class Bottleneck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 convs with expansion 4.

    Args:
        in_channels: input channels.
        channels: width of the inner convs; the output has ``4 * channels``.
        stride: stride of the 3x3 conv and of the projection shortcut.
        dilation: dilation (and matching padding) of the 3x3 conv.
    """
    expansion = 4

    def __init__(self, in_channels, channels, stride=1, dilation=1):
        super().__init__()

        out_channels = channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        self.conv2 = nn.Conv2d(
            channels, channels,
            kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(channels)

        self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if not needs_projection:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return relu(bottleneck(x) + downsample(x)); only the 3x3 conv changes resolution."""
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + self.downsample(x))

class ResNet_Bottleneck_OS16(nn.Module):
    """Bottleneck ResNet backbone with output stride 16.

    Keeps torchvision's ResNet-{50,101,152} up to and including ``layer3``
    (4*256 channels). It then appends a dilated ``layer5`` of 3 Bottleneck
    blocks (stride 1, dilation 2) that outputs 4*512 channels without further
    downsampling.

    NOTE(review): the torchvision constructors are called without
    ``pretrained=``, so no pretrained weights are requested here.

    Args:
        num_layers: torchvision ResNet depth; one of 50, 101, 152.

    Raises:
        ValueError: if ``num_layers`` is not 50, 101 or 152.
    """

    def __init__(self, num_layers):
        super(ResNet_Bottleneck_OS16, self).__init__()

        builders = {50: models.resnet50, 101: models.resnet101, 152: models.resnet152}
        if num_layers not in builders:
            # ValueError is still an Exception, so existing `except Exception` keeps working.
            raise ValueError("num_layers must be in {50, 101, 152}!")
        resnet = builders[num_layers]()
        # children() is [conv1, bn1, relu, maxpool, layer1, layer2, layer3,
        # layer4, avgpool, fc]; dropping the last three keeps everything up to layer3.
        self.resnet = nn.Sequential(*list(resnet.children())[:-3])

        self.layer5 = make_layer(Bottleneck, in_channels=4*256, channels=512, num_blocks=3, stride=1, dilation=2)

    def forward(self, x):
        """Map ``x`` of shape (batch, 3, h, w) to (batch, 4*512, ~h/16, ~w/16) features."""
        c4 = self.resnet(x)  # (batch, 4*256, h/16, w/16); called c4 since 16 == 2^4
        return self.layer5(c4)