import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from lib.pspnet import PSPNet
from lib.pspnet_new import PSPNet as PSPNet_2
from lib import gcn3d
from lib.utils import normalize_to_box, sample_farthest_points, load_obj
import lib.pytorch_utils as pt_utils

import pdb

# Factory table: backbone name -> zero-argument constructor for the PSPNet
# variant imported from lib.pspnet_new. A network is only built when the
# corresponding factory is called.
psp_models = dict(
    resnet18=lambda: PSPNet_2(
        sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
    resnet34=lambda: PSPNet_2(
        sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
    resnet50=lambda: PSPNet_2(
        sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
)

class DeformNet(nn.Module):
    """Predict a point-to-prior assignment matrix and per-point prior deformations.

    Both heads produce one output per category (``n_cat``); ``forward`` selects
    the slice belonging to each sample's ``cat_id``.

    Args:
        n_cat: number of categories the heads predict for.
        nv_prior: number of prior points. Must equal ``prior.size(1)`` in
            ``forward`` because the assignment output is reshaped with it.
        imp: if True, the deformation head is an MLP applied to every prior
            point (with its xyz appended) instead of 1x1 convolutions.
        depth_input: if False, only the first two coordinates of ``points``
            are fed to the geometry branch.
    """

    def __init__(self, n_cat=6, nv_prior=1024, imp=False, depth_input=True):
        super(DeformNet, self).__init__()
        self.n_cat = n_cat
        self.depth_input = depth_input
        in_dim = 3 if self.depth_input else 2
        self.imp = imp
        # NOTE(review): self.psp is not used in forward(); presumably kept so
        # existing checkpoints (state_dict keys) still load - confirm before removing.
        self.psp = PSPNet(bins=(1, 2, 3, 6), backend='resnet18')
        self.instance_color = nn.Sequential(
            nn.Conv1d(32, 64, 1),
            nn.ReLU(),
        )
        self.instance_geometry = nn.Sequential(
            nn.Conv1d(in_dim, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.instance_global = nn.Sequential(
            nn.Conv1d(128, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        self.category_local = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.category_global = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        self.assignment = nn.Sequential(
            nn.Conv1d(2176, 512, 1),
            nn.ReLU(),
            nn.Conv1d(512, 256, 1),
            nn.ReLU(),
            nn.Conv1d(256, n_cat*nv_prior, 1),
        )
        if self.imp:
            # NOTE(review): pos_enc is constructed but forward() does not use it
            # (the positional-embedding input to the deformation MLP is disabled).
            self.pos_enc = PositionalEncoder(input_dim=3, max_freq_log2=9,
                                    N_freqs=10)

            # Input per prior point: 2112 deformation features + its xyz (3).
            self.deformation = nn.Sequential(
                nn.Linear(2112+3, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, n_cat*3),
            )
        else:
            self.deformation = nn.Sequential(
                nn.Conv1d(2112, 512, 1),
                nn.ReLU(),
                nn.Conv1d(512, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, n_cat*3, 1),
            )
        # Initialize weights to be small so initial deformations aren't so big
        self.deformation[4].weight.data.normal_(0, 0.0001)

    def forward(self, points, out_img, choose, cat_id, prior):
        """Run the assignment and deformation heads.

        Args:
            points: bs x n_pts x 3 observed points.
            out_img: bs x 32 x H x W image features (32 channels are required by
                ``instance_color``).
            choose: bs x n_pts long indices into the flattened H*W feature map.
            cat_id: bs long category ids (forced to 0 when ``n_cat == 1``).
            prior: bs x nv x 3 shape prior, with nv == ``nv_prior``.

        Returns:
            assign_mat: bs x n_pts x nv assignment scores (no softmax applied here).
            deltas: bs x nv x 3 deformation of the prior.
            points: bs x 64 x n_pts per-point geometry features.
            assign_feat: bs x 2176 x n_pts features fed to the assignment head.
        """
        if self.n_cat == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)
        bs, n_pts = points.size()[:2]
        nv = prior.size()[1]
        # instance-specific features
        points = points.permute(0, 2, 1)
        if self.depth_input:
            points = self.instance_geometry(points)
        else:
            points = self.instance_geometry(points[:, :2])
        di = out_img.size()[1]
        emb = out_img.view(bs, di, -1)
        choose = choose.unsqueeze(1).repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()
        emb = self.instance_color(emb)
        inst_local = torch.cat((points, emb), dim=1)     # bs x 128 x n_pts
        inst_global = self.instance_global(inst_local)    # bs x 1024 x 1
        # category-specific features
        cat_prior = prior.permute(0, 2, 1)
        cat_local = self.category_local(cat_prior)    # bs x 64 x nv
        cat_global = self.category_global(cat_local)  # bs x 1024 x 1
        # assignment matrix
        assign_feat = torch.cat((inst_local, inst_global.repeat(1, 1, n_pts), cat_global.repeat(1, 1, n_pts)), dim=1)     # bs x 2176 x n_pts
        assign_mat = self.assignment(assign_feat)
        assign_mat = assign_mat.view(-1, nv, n_pts).contiguous()   # bs, nc*nv, n_pts -> bs*nc, nv, n_pts
        # Row b*n_cat + c of the flattened per-category outputs is category c of sample b.
        index = cat_id + torch.arange(bs, dtype=torch.long, device=cat_id.device) * self.n_cat
        assign_mat = torch.index_select(assign_mat, 0, index)   # bs x nv x n_pts
        assign_mat = assign_mat.permute(0, 2, 1).contiguous()    # bs x n_pts x nv
        # deformation field
        deform_feat = torch.cat((cat_local, cat_global.repeat(1, 1, nv), inst_global.repeat(1, 1, nv)), dim=1)       # bs x 2112 x nv
        if self.imp:
            deform_feat = torch.cat([deform_feat, cat_prior], dim=1)     # bs x 2115 x nv
            deform_feat = deform_feat.permute(0, 2, 1).contiguous()
            deform_feat = deform_feat.reshape(-1, deform_feat.shape[-1])  # bs*nv x 2115
            deltas = self.deformation(deform_feat)                         # bs*nv x nc*3
            deltas = deltas.reshape(bs, -1, 3 * self.n_cat)                # bs x nv x nc*3
            # permute() returns a non-contiguous tensor; it must be made contiguous
            # before view(), which otherwise raises whenever n_cat > 1.
            deltas = deltas.permute(0, 2, 1).contiguous()                  # bs x nc*3 x nv
            deltas = deltas.view(-1, 3, nv)                                # bs*nc x 3 x nv
        else:
            deltas = self.deformation(deform_feat)
            deltas = deltas.view(-1, 3, nv).contiguous()   # bs, nc*3, nv -> bs*nc, 3, nv
        deltas = torch.index_select(deltas, 0, index)   # bs x 3 x nv
        deltas = deltas.permute(0, 2, 1).contiguous()   # bs x nv x 3

        return assign_mat, deltas, points, assign_feat

class PointPoseHeadNet(nn.Module):
    """Pose regression head on per-point features fused with gathered image features.

    For every category it predicts a rotation (4 values, normalized), a
    translation (3 values) and a scale (1 value, passed through ReLU), and
    returns the prediction for each sample's ``cat_id``.

    Args:
        num_points: points per sample; used by the pooling layers and reshapes,
            so inputs must have exactly this many points.
        num_obj: number of categories predicted.
        use_fc: average-pool over points and regress with linear layers.
        max_point: max-pool over points after the second conv stage. This takes
            precedence over ``use_fc`` in ``forward``. NOTE(review): the conv4_*
            layers it uses only exist when ``use_fc`` is False, so
            ``max_point=True`` together with ``use_fc=True`` fails in forward().
        use_fuse: unused.
    """

    def __init__(self, num_points, num_obj, use_fc=False, max_point=False, use_fuse=False):
        super(PointPoseHeadNet, self).__init__()

        # NOTE(review): self.psp is not used in forward(); presumably kept so
        # existing checkpoints (state_dict keys) still load - confirm before removing.
        self.psp = PSPNet(bins=(1, 2, 3, 6), backend='resnet18')
        self.num_points = num_points
        self.num_obj = num_obj
        self.use_fc = use_fc
        self.max_point = max_point

        self.conv2 = torch.nn.Conv1d(64, 128, 1)

        self.e_conv1 = torch.nn.Conv1d(32, 64, 1)
        self.e_conv2 = torch.nn.Conv1d(64, 128, 1)

        self.conv5 = torch.nn.Conv1d(256, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 1024, 1)

        self.ap1 = torch.nn.AvgPool1d(num_points)

        # 1408 = 128 (pointfeat_1) + 256 (pointfeat_2) + 1024 (global)
        self.conv1_r = torch.nn.Conv1d(1408, 640, 1)
        self.conv1_t = torch.nn.Conv1d(1408, 640, 1)
        self.conv1_c = torch.nn.Conv1d(1408, 640, 1)

        self.conv2_r = torch.nn.Conv1d(640, 256, 1)
        self.conv2_t = torch.nn.Conv1d(640, 256, 1)
        self.conv2_c = torch.nn.Conv1d(640, 256, 1)

        self.conv3_r = torch.nn.Conv1d(256, 128, 1)
        self.conv3_t = torch.nn.Conv1d(256, 128, 1)
        self.conv3_c = torch.nn.Conv1d(256, 128, 1)

        self.ap = torch.nn.AvgPool1d(num_points)

        if self.use_fc:
            self.fc_r = torch.nn.Linear(128, num_obj*4)  # quaternion
            self.fc_t = torch.nn.Linear(128, num_obj*3)  # translation
            self.fc_c = torch.nn.Linear(128, num_obj*1)  # scale
        else:
            self.conv4_r = torch.nn.Conv1d(128, num_obj*4, 1)  # quaternion
            self.conv4_t = torch.nn.Conv1d(128, num_obj*3, 1)  # translation
            self.conv4_c = torch.nn.Conv1d(128, num_obj*1, 1)  # scale

    def forward(self, psp_feat, x, choose, cat_id):
        """Regress the pose for each sample's category.

        Args:
            psp_feat: bs x 32 x H x W image features (32 channels required by e_conv1).
            x: bs x 64 x num_points per-point features (64 channels required by conv2).
            choose: bs x num_points long indices into the flattened H*W map.
            cat_id: bs long category ids (forced to 0 when ``num_obj == 1``).

        Returns:
            (out_cx, out_tx, out_rx): scale bs x 1, translation bs x 3,
            rotation bs x 4.
        """
        # Single-category models only have slot 0; matches the sibling heads.
        if self.num_obj == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)

        bs = psp_feat.size()[0]
        di = psp_feat.size()[1]
        emb = psp_feat.view(bs, di, -1)
        choose = choose.unsqueeze(1).repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()

        emb = F.relu(self.e_conv1(emb))
        pointfeat_1 = torch.cat((x, emb), dim=1)

        x = F.relu(self.conv2(x))
        emb = F.relu(self.e_conv2(emb))
        pointfeat_2 = torch.cat((x, emb), dim=1)

        x = F.relu(self.conv5(pointfeat_2))
        x = F.relu(self.conv6(x))

        ap_x = self.ap1(x)
        ap_x = ap_x.view(-1, 1024, 1).repeat(1, 1, self.num_points)

        # 128 + 256 + 1024
        pose_feat = torch.cat([pointfeat_1, pointfeat_2, ap_x], 1)

        rx = F.relu(self.conv1_r(pose_feat))
        tx = F.relu(self.conv1_t(pose_feat))
        cx = F.relu(self.conv1_c(pose_feat))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))
        cx = F.relu(self.conv2_c(cx))

        # Batch indices for picking each sample's category. They are created on
        # the feature device instead of via .cuda(), so CPU execution also works.
        indices = torch.arange(bs, device=psp_feat.device)

        if self.max_point:

            rx = torch.max(rx, dim=-1, keepdim=True)[0]
            tx = torch.max(tx, dim=-1, keepdim=True)[0]
            cx = torch.max(cx, dim=-1, keepdim=True)[0]
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1)

            out_rx = rx[indices, cat_id].contiguous()
            out_tx = tx[indices, cat_id].contiguous()
            out_cx = cx[indices, cat_id].contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

        elif self.use_fc:
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.ap(rx).squeeze(-1)
            tx = self.ap(tx).squeeze(-1)
            cx = self.ap(cx).squeeze(-1)

            rx = self.fc_r(rx).view(bs, self.num_obj, 4)
            tx = self.fc_t(tx).view(bs, self.num_obj, 3)
            cx = self.fc_c(cx).view(bs, self.num_obj, 1)
            out_rx = rx[indices, cat_id]
            out_tx = tx[indices, cat_id]
            out_cx = cx[indices, cat_id]
            out_rx = F.normalize(out_rx)
            out_tx = torch.sigmoid(out_tx)
            out_cx = F.relu(out_cx)

        else:
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1, self.num_points)

            out_rx = rx[indices, cat_id]
            out_tx = tx[indices, cat_id]
            out_cx = cx[indices, cat_id]

            # bs x C x n_pts -> bs x n_pts x C
            out_rx = out_rx.contiguous().transpose(2, 1).contiguous()
            out_tx = out_tx.contiguous().transpose(2, 1).contiguous()
            out_cx = out_cx.contiguous().transpose(2, 1).contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

            # Average the per-point predictions.
            out_rx = out_rx.mean(dim=1)
            out_tx = out_tx.mean(dim=1)
            out_cx = out_cx.mean(dim=1)

        return out_cx, out_tx, out_rx


class PointLocHeadNet(nn.Module):
    """Pose regression head fusing point features, image features and NOCS-like coords.

    For every category it predicts a rotation (4 values, normalized), a
    translation (3 values) and a scale (1 value, passed through ReLU), and
    returns the prediction for each sample's ``cat_id``.

    Args:
        num_points: points per sample; inputs must have exactly this many.
        num_obj: number of categories predicted.
        with_rgb: stored but not used by this class.
        geo_dim: unused.
        max_point: max-pool over points after the second conv stage instead of
            averaging per-point predictions.
    """

    def __init__(self, num_points, num_obj, with_rgb=True, geo_dim=2, max_point=False):
        super(PointLocHeadNet, self).__init__()

        self.num_points = num_points
        self.num_obj = num_obj
        self.with_rgb = with_rgb
        self.max_point = max_point

        self.conv2 = torch.nn.Conv1d(64, 128, 1)

        # Embedding of the 3-channel coordinate input.
        self.nocs_conv1 = torch.nn.Conv1d(3, 64, 1)
        self.nocs_conv2 = torch.nn.Conv1d(64, 128, 1)

        self.e_conv1 = torch.nn.Conv1d(32, 64, 1)
        self.e_conv2 = torch.nn.Conv1d(64, 128, 1)

        self.conv5 = torch.nn.Conv1d(384, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 1024, 1)

        # 1600 = 192 (pointfeat_1) + 384 (pointfeat_2) + 1024 (global)
        self.conv1_r = torch.nn.Conv1d(1600, 512, 1)
        self.conv1_t = torch.nn.Conv1d(1600, 512, 1)
        self.conv1_c = torch.nn.Conv1d(1600, 512, 1)

        self.conv2_r = torch.nn.Conv1d(512, 256, 1)
        self.conv2_t = torch.nn.Conv1d(512, 256, 1)
        self.conv2_c = torch.nn.Conv1d(512, 256, 1)

        self.conv3_r = torch.nn.Conv1d(256, 128, 1)
        self.conv3_t = torch.nn.Conv1d(256, 128, 1)
        self.conv3_c = torch.nn.Conv1d(256, 128, 1)

        self.conv4_r = torch.nn.Conv1d(128, num_obj*4, 1)  # quaternion
        self.conv4_t = torch.nn.Conv1d(128, num_obj*3, 1)  # translation
        self.conv4_c = torch.nn.Conv1d(128, num_obj*1, 1)  # scale

        self.ap1 = torch.nn.AvgPool1d(num_points)

    def forward(self, psp_feat, x, nocs, choose, cat_id):
        """Regress the pose for each sample's category.

        Args:
            psp_feat: bs x 32 x H x W image features (32 channels required by e_conv1).
            x: bs x 64 x num_points per-point features (64 channels required by conv2).
            nocs: bs x 3 x num_points per-point coordinates.
            choose: bs x num_points long indices into the flattened H*W map.
            cat_id: bs long category ids (forced to 0 when ``num_obj == 1``).

        Returns:
            (out_cx, out_tx, out_rx): scale bs x 1, translation bs x 3,
            rotation bs x 4.
        """
        if self.num_obj == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)

        bs = psp_feat.size()[0]
        di = psp_feat.size()[1]

        emb = psp_feat.view(bs, di, -1)
        choose = choose.unsqueeze(1).repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()

        nocs = F.relu(self.nocs_conv1(nocs))
        emb = F.relu(self.e_conv1(emb))  # 64
        pointfeat_1 = torch.cat((x, emb, nocs), dim=1)

        x = F.relu(self.conv2(x))
        emb = F.relu(self.e_conv2(emb))
        nocs = F.relu(self.nocs_conv2(nocs))
        pointfeat_2 = torch.cat((x, emb, nocs), dim=1)

        x = F.relu(self.conv5(pointfeat_2))
        x = F.relu(self.conv6(x))

        ap_x = self.ap1(x)
        ap_x = ap_x.view(-1, 1024, 1).repeat(1, 1, self.num_points)

        # 192 + 384 + 1024
        pose_feat = torch.cat([pointfeat_1, pointfeat_2, ap_x], 1)

        rx = F.relu(self.conv1_r(pose_feat))
        tx = F.relu(self.conv1_t(pose_feat))
        cx = F.relu(self.conv1_c(pose_feat))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))
        cx = F.relu(self.conv2_c(cx))

        # Batch indices for picking each sample's category. They are created on
        # the feature device instead of via .cuda(), so CPU execution also works.
        indices = torch.arange(bs, device=psp_feat.device)

        if self.max_point:

            rx = torch.max(rx, dim=-1, keepdim=True)[0]
            tx = torch.max(tx, dim=-1, keepdim=True)[0]
            cx = torch.max(cx, dim=-1, keepdim=True)[0]
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1)

            out_rx = rx[indices, cat_id].contiguous()
            out_tx = tx[indices, cat_id].contiguous()
            out_cx = cx[indices, cat_id].contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

        else:
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1, self.num_points)

            out_rx = rx[indices, cat_id]
            out_tx = tx[indices, cat_id]
            out_cx = cx[indices, cat_id]

            # bs x C x n_pts -> bs x n_pts x C
            out_rx = out_rx.contiguous().transpose(2, 1).contiguous()
            out_tx = out_tx.contiguous().transpose(2, 1).contiguous()
            out_cx = out_cx.contiguous().transpose(2, 1).contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

            # Average the per-point predictions.
            out_rx = out_rx.mean(dim=1)
            out_tx = out_tx.mean(dim=1)
            out_cx = out_cx.mean(dim=1)
        return out_cx, out_tx, out_rx

class PoseHeadNet(nn.Module):
    """Image-level pose head: stride-2 convs followed by an MLP.

    The trainable path regresses a positive scale (bs x 1), a translation
    (bs x 3) and a unit-norm 4-vector rotation (bs x 4).

    Args:
        in_channels: channels of the input feature map.
        num_layers: number of stride-2 conv layers.
        num_filters: channels of every conv layer and width of the MLP.
        kernel_size: conv kernel size. Padding is ``(kernel_size - 1) // 2``,
            which gives 1 for size 3 and 0 for size 2 (the previously supported
            sizes); other sizes previously raised UnboundLocalError.
        output_dim: unused.
        freeze: if True, forward() returns detached MLP features computed under
            ``torch.no_grad`` instead of the pose tuple.
        img_size: the conv output is flattened assuming a spatial size of
            ``img_size // 32`` x ``img_size // 32``.
        with_bias_end: unused (custom weight init is disabled).
    """

    def __init__(self, in_channels=512, num_layers=2, num_filters=256, kernel_size=3, output_dim=3, freeze=False, img_size=256,
                 with_bias_end=True):
        super(PoseHeadNet, self).__init__()

        self.freeze = freeze

        padding = (kernel_size - 1) // 2

        self.num_filter = num_filters

        self.features = nn.ModuleList()
        for i in range(num_layers):
            _in_channels = in_channels if i == 0 else num_filters
            self.features.append(nn.Conv2d(_in_channels, num_filters, kernel_size=kernel_size, stride=2, padding=padding, bias=False))
            self.features.append(nn.ReLU(inplace=True))

        self.linears = nn.ModuleList()
        self.feat_size = img_size // 32
        self.linears.append(nn.Linear(num_filters * self.feat_size * self.feat_size, num_filters))
        self.linears.append(nn.ReLU(inplace=True))
        self.linears.append(nn.Linear(num_filters, num_filters))
        self.linears.append(nn.ReLU(inplace=True))

        self.scale = nn.Linear(num_filters, 1)
        self.trans = nn.Linear(num_filters, 3)
        self.rot = nn.Linear(num_filters, 4)

    def _embed(self, x):
        """Apply the conv stack, flatten, and apply the MLP (bs x num_filters)."""
        for layer in self.features:
            x = layer(x)
        x = x.view(-1, self.num_filter * self.feat_size * self.feat_size)
        for layer in self.linears:
            x = layer(x)
        return x

    def forward(self, x):
        """Return ``(scale, trans, rot)``, or detached features when ``freeze``."""
        if self.freeze:
            # NOTE(review): the frozen path returns bs x num_filters features, not
            # the (scale, trans, rot) tuple of the trainable path - confirm callers.
            with torch.no_grad():
                return self._embed(x).detach()
        feat = self._embed(x)
        scale = F.relu(self.scale(feat)) + 1E-12  # strictly positive
        trans = self.trans(feat)
        rot = F.normalize(self.rot(feat), dim=-1)  # unit-norm 4-vector
        return scale, trans, rot




class PositionalEncoder(nn.Module):
    """Encode coordinates with periodic functions at several frequencies.

    The output concatenates, along the last dim, the raw input (optional)
    followed by ``fn(input * freq)`` for each frequency band in order and, for
    each band, each function in ``periodic_fns``.
    """

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                       log_sampling=True, include_input=True,
                       periodic_fns=(torch.sin, )):
        '''
        :param input_dim: dimension of input to be embedded
        :param max_freq_log2: log2 of max freq; min freq is 1 (2**0)
        :param N_freqs: number of frequency bands
        :param log_sampling: if True, bands are 2**linspace(0, max_freq_log2, N_freqs)
            (linearly spaced in log2-space); otherwise they are linearly spaced
            between 1 and 2**max_freq_log2
        :param include_input: if True, raw input is included in the embedding
        :param periodic_fns: periodic functions used to embed input
        '''
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # Raw input (optional) plus one input_dim-wide block per (band, fn) pair.
        self.out_dim = (input_dim if include_input else 0) \
            + input_dim * N_freqs * len(periodic_fns)

        if log_sampling:
            bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
        # Plain Python floats, so multiplying an input by them works on any device.
        self.freq_bands = bands.tolist()

    def forward(self, input):
        '''
        :param input: tensor of shape [..., self.input_dim]
        :return: tensor of shape [..., self.out_dim]
        '''
        assert (input.shape[-1] == self.input_dim)

        pieces = [input] if self.include_input else []
        pieces.extend(p_fn(input * freq)
                      for freq in self.freq_bands
                      for p_fn in self.periodic_fns)
        out = torch.cat(pieces, dim=-1)

        assert (out.shape[-1] == self.out_dim)
        return out


class DenseFusion(nn.Module):
    """Dense per-point fusion of 64-channel RGB and point-cloud embeddings.

    The output stacks (along dim 1) the two raw inputs, their 256-channel
    projections, and a 1024-channel global feature tiled across all points:
    128 + 512 + 1024 = 1664 channels.
    """

    def __init__(self, num_points):
        super(DenseFusion, self).__init__()
        # Per-modality 64 -> 256 projections.
        self.conv2_rgb = torch.nn.Conv1d(64, 256, 1)
        self.conv2_cld = torch.nn.Conv1d(64, 256, 1)

        # Global branch, applied to the concatenated 128-channel inputs.
        self.conv3 = torch.nn.Conv1d(128, 512, 1)
        self.conv4 = torch.nn.Conv1d(512, 1024, 1)

        # Averages over a window of num_points positions.
        self.ap1 = torch.nn.AvgPool1d(num_points)

    def forward(self, rgb_emb, cld_emb):
        _, _, n_pts = cld_emb.size()
        low = torch.cat((rgb_emb, cld_emb), dim=1)
        projected = [
            F.relu(conv(feat))
            for conv, feat in ((self.conv2_rgb, rgb_emb), (self.conv2_cld, cld_emb))
        ]
        mid = torch.cat(projected, dim=1)

        glob = F.relu(self.conv4(F.relu(self.conv3(low))))
        glob = self.ap1(glob).view(-1, 1024, 1).repeat(1, 1, n_pts)
        return torch.cat([low, mid, glob], 1)  # 128 + 512 + 1024 = 1664

class PoseNet(nn.Module):
    """Pipeline: PSPNet image encoder, DeformNet, and a configurable pose head.

    The pose head is chosen from ``opts``: PointPoseHeadNet when
    ``use_point_reg``, else PointLocHeadNet when ``use_nocs_map``, else the
    image-level PoseHeadNet.
    """

    def __init__(self, opts):
        super(PoseNet, self).__init__()
        self.opts = opts
        self.use_point_reg = opts.use_point_reg
        self.use_nocs_map = opts.use_nocs_map
        self.num_obj = num_obj = 6 if opts.select_class == 'all' else 1
        self.encoder = PSPNet(bins=(1, 2, 3, 6), backend='resnet18')
        self.deform_head = DeformNet(n_cat=num_obj, imp=opts.implict, depth_input=not opts.use_rgb)
        if self.use_point_reg:
            head = PointPoseHeadNet(opts.n_pts, num_obj, use_fc=opts.use_fc,
                                    max_point=opts.max_point)
        elif self.use_nocs_map:
            head = PointLocHeadNet(opts.n_pts, num_obj)
        else:
            head = PoseHeadNet(img_size=opts.img_size)
        self.pose_head = head

    def forward(self, points, img, choose, cat_id, prior):
        """Return a dict with 'pose', 'assign_mat' and 'deltas'."""
        if self.num_obj == 1:
            # A single-category model only has output slot 0.
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)
        img_feat, psp_feat = self.encoder(img)
        assign_mat, deltas, point_feat, _ = self.deform_head(points, psp_feat, choose, cat_id, prior)
        if self.use_point_reg:
            pose = self.pose_head(psp_feat, point_feat, choose, cat_id)
        elif self.use_nocs_map:
            # Per-point coordinates on the deformed prior, weighted by the
            # softmaxed assignment; detached so no gradient flows back through them.
            weights = F.softmax(assign_mat, dim=2)
            coords = torch.bmm(weights, prior + deltas)  # bs x n_pts x 3
            coords = coords.transpose(2, 1).contiguous().detach()
            pose = self.pose_head(psp_feat, point_feat.detach(), coords, choose, cat_id)
        else:
            pose = self.pose_head(img_feat)
        return {'pose': pose, 'assign_mat': assign_mat, 'deltas': deltas}


class PoseHeadNetV2(nn.Module):
    """Pose regression head operating directly on precomputed per-point features.

    For every category it predicts a rotation (4 values, normalized), a
    translation (3 values) and a scale (1 value, passed through ReLU), and
    returns the prediction for each sample's ``cat_id``.

    Args:
        num_points: points per sample; inputs must have exactly this many.
        num_obj: number of categories predicted.
        in_dim: channels of ``point_feat``.
        with_nocs: if True, a 3-channel ``nocs`` input is embedded to 128
            channels and concatenated to ``point_feat``.
        max_point: max-pool over points after the second conv stage instead of
            averaging per-point predictions.
    """

    def __init__(self, num_points, num_obj, in_dim=1152, with_nocs=False, max_point=False):
        super(PoseHeadNetV2, self).__init__()

        self.num_points = num_points
        self.num_obj = num_obj
        self.with_nocs = with_nocs
        self.max_point = max_point
        if self.with_nocs:
            self.nocs_conv = nn.Sequential(
                                nn.Conv1d(3, 64, 1),
                                nn.ReLU(),
                                nn.Conv1d(64, 64, 1),
                                nn.ReLU(),
                                nn.Conv1d(64, 128, 1),
                                nn.ReLU()
                        )

            self.conv1_r = torch.nn.Conv1d(in_dim+128, 512, 1)
            self.conv1_t = torch.nn.Conv1d(in_dim+128, 512, 1)
            self.conv1_c = torch.nn.Conv1d(in_dim+128, 512, 1)
        else:
            self.conv1_r = torch.nn.Conv1d(in_dim, 512, 1)
            self.conv1_t = torch.nn.Conv1d(in_dim, 512, 1)
            self.conv1_c = torch.nn.Conv1d(in_dim, 512, 1)

        self.conv2_r = torch.nn.Conv1d(512, 256, 1)
        self.conv2_t = torch.nn.Conv1d(512, 256, 1)
        self.conv2_c = torch.nn.Conv1d(512, 256, 1)

        self.conv3_r = torch.nn.Conv1d(256, 128, 1)
        self.conv3_t = torch.nn.Conv1d(256, 128, 1)
        self.conv3_c = torch.nn.Conv1d(256, 128, 1)

        self.conv4_r = torch.nn.Conv1d(128, num_obj*4, 1)  # quaternion
        self.conv4_t = torch.nn.Conv1d(128, num_obj*3, 1)  # translation
        self.conv4_c = torch.nn.Conv1d(128, num_obj*1, 1)  # scale

        self.ap1 = torch.nn.AvgPool1d(num_points)

    def forward(self, psp_feat, point_feat, choose, cat_id, nocs=None):
        """Regress the pose for each sample's category.

        Args:
            psp_feat: unused by this head.
            point_feat: bs x in_dim x num_points per-point features.
            choose: unused by this head.
            cat_id: bs long category ids (forced to 0 when ``num_obj == 1``).
            nocs: bs x 3 x num_points coordinates; required when ``with_nocs``.

        Returns:
            (out_cx, out_tx, out_rx): scale bs x 1, translation bs x 3,
            rotation bs x 4.
        """
        if self.num_obj == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)

        bs = point_feat.size()[0]

        if self.with_nocs:
            assert nocs is not None
            nocs = self.nocs_conv(nocs)
            pose_feat = torch.cat([point_feat, nocs], dim=1)
        else:
            # The convolutions below do not modify their input, so no copy is needed.
            pose_feat = point_feat

        rx = F.relu(self.conv1_r(pose_feat))
        tx = F.relu(self.conv1_t(pose_feat))
        cx = F.relu(self.conv1_c(pose_feat))

        # Batch indices for picking each sample's category. They are created on
        # the feature device instead of via .cuda(), so CPU execution also works.
        indices = torch.arange(bs, device=point_feat.device)

        if self.max_point:
            # NOTE(review): unlike the per-point branch, no ReLU follows conv2_*
            # here before the max-pool - confirm this is intended.
            rx = self.conv2_r(rx)
            tx = self.conv2_t(tx)
            cx = self.conv2_c(cx)

            rx = torch.max(rx, dim=-1, keepdim=True)[0]
            tx = torch.max(tx, dim=-1, keepdim=True)[0]
            cx = torch.max(cx, dim=-1, keepdim=True)[0]
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1)

            out_rx = rx[indices, cat_id].contiguous()
            out_tx = tx[indices, cat_id].contiguous()
            out_cx = cx[indices, cat_id].contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

        else:
            rx = F.relu(self.conv2_r(rx))
            tx = F.relu(self.conv2_t(tx))
            cx = F.relu(self.conv2_c(cx))

            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1, self.num_points)

            out_rx = rx[indices, cat_id]
            out_tx = tx[indices, cat_id]
            out_cx = cx[indices, cat_id]

            # bs x C x n_pts -> bs x n_pts x C
            out_rx = out_rx.contiguous().transpose(2, 1).contiguous()
            out_tx = out_tx.contiguous().transpose(2, 1).contiguous()
            out_cx = out_cx.contiguous().transpose(2, 1).contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

            # Average the per-point predictions.
            out_rx = out_rx.mean(dim=1)
            out_tx = out_tx.mean(dim=1)
            out_cx = out_cx.mean(dim=1)
        return out_cx, out_tx, out_rx

class DeformNetV2(nn.Module):
    """Assignment / deformation network that also exposes dense per-point features.

    Unlike a pooled-only design, the instance branch keeps a 1024-channel
    per-point feature, which is returned with the local features for a
    downstream pose head.

    Args:
        n_cat: number of categories the assignment head (and, when ``imp`` is
            False, the deformation head) predicts for.
        nv_prior: number of prior points. Must equal ``prior.size(1)`` in
            ``forward`` because the assignment output is reshaped with it.
        imp: if True, the deformation head is a category-agnostic MLP applied
            to every prior point (with its xyz appended).
        depth_input: if False, only the first two coordinates of ``points``
            are fed to the geometry branch.
    """

    def __init__(self, n_cat=6, nv_prior=1024, imp=False, depth_input=True):
        super(DeformNetV2, self).__init__()
        self.n_cat = n_cat
        self.depth_input = depth_input
        in_dim = 3 if self.depth_input else 2
        self.imp = imp
        self.instance_color = nn.Sequential(
            nn.Conv1d(32, 64, 1),
            nn.ReLU(),
        )
        self.instance_geometry = nn.Sequential(
            nn.Conv1d(in_dim, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.instance_global = nn.Sequential(
            nn.Conv1d(128, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU()
        )

        self.ap = nn.AdaptiveAvgPool1d(1)

        self.category_local = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.category_global = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        if self.imp:
            # NOTE(review): pos_enc is constructed but forward() does not use it
            # (the positional-embedding input to the deformation MLP is disabled).
            self.pos_enc = PositionalEncoder(input_dim=3, max_freq_log2=9,
                                    N_freqs=10)

            # Input per prior point: 2112 deformation features + its xyz (3).
            self.deformation = nn.Sequential(
                nn.Linear(2112+3, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 3),
            )
        else:
            self.deformation = nn.Sequential(
                nn.Conv1d(2112, 512, 1),
                nn.ReLU(),
                nn.Conv1d(512, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, n_cat*3, 1),
            )
        self.assignment = nn.Sequential(
            nn.Conv1d(2176, 512, 1),
            nn.ReLU(),
            nn.Conv1d(512, 256, 1),
            nn.ReLU(),
            nn.Conv1d(256, n_cat*nv_prior, 1),
        )
        # Initialize weights to be small so initial deformations aren't so big
        self.deformation[4].weight.data.normal_(0, 0.0001)

    def forward(self, points, out_img, choose, cat_id, prior):
        """Run the assignment and deformation heads.

        Args:
            points: bs x n_pts x 3 observed points.
            out_img: bs x 32 x H x W image features (32 channels are required by
                ``instance_color``).
            choose: bs x n_pts long indices into the flattened H*W feature map.
            cat_id: bs long category ids (forced to 0 when ``n_cat == 1``).
            prior: bs x nv x 3 shape prior, with nv == ``nv_prior``.

        Returns:
            assign_mat: bs x n_pts x nv assignment scores (no softmax applied here).
            deltas: bs x nv x 3 deformation of the prior.
            point_feat: bs x 1152 x n_pts per-point features
                (128 local + 1024 from ``instance_global``).
        """
        if self.n_cat == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)
        bs, n_pts = points.size()[:2]
        nv = prior.size()[1]
        # instance-specific features
        points = points.permute(0, 2, 1)
        if self.depth_input:
            points_emb = self.instance_geometry(points)
        else:
            points_emb = self.instance_geometry(points[:, :2])
        di = out_img.size()[1]
        emb = out_img.view(bs, di, -1)
        choose = choose.unsqueeze(1).repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()
        emb = self.instance_color(emb)
        inst_local = torch.cat((points_emb, emb), dim=1)     # bs x 128 x n_pts
        inst_global_p = self.instance_global(inst_local)    # bs x 1024 x n_pts
        point_feat = torch.cat([inst_local, inst_global_p], dim=1)  # bs x 1152 x n_pts
        inst_global = self.ap(inst_global_p)  # bs x 1024 x 1
        # category-specific features
        cat_prior = prior.permute(0, 2, 1)
        cat_local = self.category_local(cat_prior)    # bs x 64 x nv
        cat_global = self.category_global(cat_local)  # bs x 1024 x 1
        # assignment features
        assign_feat = torch.cat((inst_local, inst_global.repeat(1, 1, n_pts), cat_global.repeat(1, 1, n_pts)), dim=1)     # bs x 2176 x n_pts
        # deformation field
        deform_feat = torch.cat((cat_local, cat_global.repeat(1, 1, nv), inst_global.repeat(1, 1, nv)), dim=1)       # bs x 2112 x nv
        # Row b*n_cat + c of the flattened per-category outputs is category c of sample b.
        index = cat_id + torch.arange(bs, dtype=torch.long, device=cat_id.device) * self.n_cat
        if self.imp:
            deform_feat = torch.cat([deform_feat, cat_prior], dim=1)     # bs x 2115 x nv
            deform_feat = deform_feat.permute(0, 2, 1).contiguous()
            deform_feat = deform_feat.reshape(-1, deform_feat.shape[-1])  # bs*nv x 2115
            deltas = self.deformation(deform_feat)                         # bs*nv x 3
            deltas = deltas.reshape(bs, -1, 3).contiguous()                # bs x nv x 3
        else:
            deltas = self.deformation(deform_feat)
            deltas = deltas.view(-1, 3, nv).contiguous()   # bs, nc*3, nv -> bs*nc, 3, nv
            deltas = torch.index_select(deltas, 0, index)   # bs x 3 x nv
            deltas = deltas.permute(0, 2, 1).contiguous()   # bs x nv x 3
        assign_mat = self.assignment(assign_feat)
        assign_mat = assign_mat.view(-1, nv, n_pts).contiguous()   # bs, nc*nv, n_pts -> bs*nc, nv, n_pts
        assign_mat = torch.index_select(assign_mat, 0, index)   # bs x nv x n_pts
        assign_mat = assign_mat.permute(0, 2, 1).contiguous()    # bs x n_pts x nv
        return assign_mat, deltas, point_feat


class PoseNetV2(nn.Module):
    """Pose network: PSPNet encoder -> DeformNetV2 -> PoseHeadNetV2.

    ``opts`` must provide ``use_nocs_map``, ``select_class``, ``implict``,
    ``use_rgb``, ``n_pts`` and ``max_point``.  When ``select_class == 'all'``
    the heads are built for 6 categories, otherwise for a single one.
    """

    def __init__(self, opts):
        super(PoseNetV2, self).__init__()
        self.opts = opts
        self.use_nocs_map = opts.use_nocs_map
        n_obj = 6 if opts.select_class == 'all' else 1
        self.num_obj = n_obj
        self.encoder = PSPNet(bins=(1, 2, 3, 6), backend='resnet18')
        self.deform_head = DeformNetV2(n_cat=n_obj, imp=opts.implict,
                                       depth_input=not opts.use_rgb)
        self.pose_head = PoseHeadNetV2(opts.n_pts, n_obj,
                                       with_nocs=opts.use_nocs_map,
                                       max_point=opts.max_point)

    def forward(self, points, img, choose, cat_id, prior):
        """Run encoder, deformation head and pose head.

        Returns:
            dict with ``'pose'`` (pose head output), ``'assign_mat'`` (raw
            assignment scores from the deformation head, softmaxed over the
            last dim here when building NOCS coords) and ``'deltas'``
            (deformation of ``prior``).
        """
        if self.num_obj == 1:
            # single-category model: every sample uses head slot 0
            cat_id = torch.zeros_like(cat_id)
        _, psp_feat = self.encoder(img)
        assign_mat, deltas, point_feat = self.deform_head(points, psp_feat, choose, cat_id, prior)
        if not self.use_nocs_map:
            pose = self.pose_head(point_feat, choose, cat_id)
        else:
            # each observed point becomes a softmax-weighted mix of the deformed
            # prior vertices (bs x n_pts x 3)
            weights = F.softmax(assign_mat, dim=2)
            coords = torch.bmm(weights, prior + deltas)
            # channels-first and detached: the pose loss does not flow into coords
            coords = coords.permute(0, 2, 1).contiguous().detach()
            pose = self.pose_head(point_feat, choose, cat_id, coords)
        return {'pose': pose, 'assign_mat': assign_mat, 'deltas': deltas}


class PointNetfeat(nn.Module):
    """PointNet-style global encoder.

    Three shared 1x1 convolutions (3 -> 64 -> 128 -> nlatent), a max-pool over
    points and a ReLU'd linear layer.  Input is ``B x 3 x N``; output is
    ``B x nlatent x 1``.  ``npoint`` is stored but not used in the computation.
    """

    def __init__(self, npoint = 2500, nlatent = 512):
        super(PointNetfeat, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, nlatent, 1)
        self.lin = nn.Linear(nlatent, nlatent)

        self.npoint = npoint
        self.nlatent = nlatent

    def forward(self, x):
        hidden = F.relu(self.conv2(F.relu(self.conv1(x))))
        # symmetric max-pool over the point axis -> one vector per cloud
        pooled = self.conv3(hidden).max(dim=2).values
        latent = F.relu(self.lin(pooled.view(-1, self.nlatent)))
        return latent.unsqueeze(-1)


class PoseHeadNetV3(nn.Module):
    """Pose regression head driven purely by a 6-channel per-point input.

    Three parallel branches regress a quaternion (4), a translation (3) and a
    scale (1) for each of ``num_obj`` categories; the slot of ``cat_id`` is
    selected per sample.  With ``max_point`` the features are max-pooled over
    points before the last two layers; otherwise per-point predictions are
    averaged.
    """

    def __init__(self, num_points, num_obj, with_nocs=False, max_point=False):
        super(PoseHeadNetV3, self).__init__()

        self.num_points = num_points
        self.num_obj = num_obj
        self.with_nocs = with_nocs
        self.max_point = max_point
        self.nocs_conv = nn.Sequential(
                            nn.Conv1d(6, 64, 1),
                            nn.ReLU(),
                            nn.Conv1d(64, 64, 1),
                            nn.ReLU(),
                            nn.Conv1d(64, 128, 1),
                            nn.ReLU()
                    )

        self.nocs_global = nn.Sequential(
            nn.Conv1d(128, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU()
        )

        # 1152 = 128 (nocs_conv) + 1024 (nocs_global)
        self.conv1_r = torch.nn.Conv1d(1152, 512, 1)
        self.conv1_t = torch.nn.Conv1d(1152, 512, 1)
        self.conv1_c = torch.nn.Conv1d(1152, 512, 1)

        self.conv2_r = torch.nn.Conv1d(512, 256, 1)
        self.conv2_t = torch.nn.Conv1d(512, 256, 1)
        self.conv2_c = torch.nn.Conv1d(512, 256, 1)

        self.conv3_r = torch.nn.Conv1d(256, 128, 1)
        self.conv3_t = torch.nn.Conv1d(256, 128, 1)
        self.conv3_c = torch.nn.Conv1d(256, 128, 1)

        self.conv4_r = torch.nn.Conv1d(128, num_obj*4, 1)  # quaternion
        self.conv4_t = torch.nn.Conv1d(128, num_obj*3, 1)  # translation
        self.conv4_c = torch.nn.Conv1d(128, num_obj*1, 1)  # scale

        # NOTE(review): not used by forward; kept for interface compatibility.
        self.ap1 = torch.nn.AvgPool1d(num_points)

    def forward(self, psp_feat, point_feat, choose, cat_id, nocs):
        """Regress scale, translation and rotation.

        Args:
            psp_feat: image feature map; only its batch size (dim 0) is read.
            point_feat: unused; kept for interface compatibility.
            choose: unused; kept for interface compatibility.
            cat_id: (bs,) long tensor selecting the per-category output slot.
            nocs: bs x 6 x n_pts per-point input to ``nocs_conv``.  Any n_pts
                is accepted (previously the non-max branch required
                n_pts == ``num_points``).

        Returns:
            (out_cx, out_tx, out_rx) of shapes (bs, 1), (bs, 3), (bs, 4);
            ``out_cx`` is ReLU'd and ``out_rx`` is L2-normalised.
        """
        if self.num_obj == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)

        bs = psp_feat.size()[0]
        n_pts = nocs.size(-1)

        nocs_feat_local = self.nocs_conv(nocs)                 # bs x 128 x n_pts
        nocs_feat_global = self.nocs_global(nocs_feat_local)   # bs x 1024 x n_pts
        pose_feat = torch.cat([nocs_feat_local, nocs_feat_global], dim=1)  # bs x 1152 x n_pts

        rx = F.relu(self.conv1_r(pose_feat))
        tx = F.relu(self.conv1_t(pose_feat))
        cx = F.relu(self.conv1_c(pose_feat))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))
        cx = F.relu(self.conv2_c(cx))

        # Batch index on the predictions' device (was a hard-coded .cuda(),
        # which broke CPU execution).
        indices = torch.arange(bs, device=rx.device)

        if self.max_point:
            # pool over points first, then finish the MLP on one vector per sample
            rx = torch.max(rx, dim=-1, keepdim=True)[0]
            tx = torch.max(tx, dim=-1, keepdim=True)[0]
            cx = torch.max(cx, dim=-1, keepdim=True)[0]
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1)

            out_rx = rx[indices, cat_id].contiguous()
            out_tx = tx[indices, cat_id].contiguous()
            out_cx = cx[indices, cat_id].contiguous()

            out_rx = F.normalize(out_rx, dim=-1)
            out_cx = F.relu(out_cx)

        else:
            rx = F.relu(self.conv3_r(rx))
            tx = F.relu(self.conv3_t(tx))
            cx = F.relu(self.conv3_c(cx))

            rx = self.conv4_r(rx).view(bs, self.num_obj, 4, n_pts)
            tx = self.conv4_t(tx).view(bs, self.num_obj, 3, n_pts)
            cx = self.conv4_c(cx).view(bs, self.num_obj, 1, n_pts)

            # select the category slot, then move points to dim 1: bs x n_pts x k
            out_rx = rx[indices, cat_id].transpose(2, 1).contiguous()
            out_tx = tx[indices, cat_id].transpose(2, 1).contiguous()
            out_cx = cx[indices, cat_id].transpose(2, 1).contiguous()

            out_cx = F.relu(out_cx)

            # average per-point predictions; renormalise the mean quaternion
            out_rx = F.normalize(out_rx.mean(dim=1), dim=-1)
            out_tx = out_tx.mean(dim=1)
            out_cx = out_cx.mean(dim=1)
        return out_cx, out_tx, out_rx

class DeformNetV3(nn.Module):
    """Predict per-point NOCS coordinates and a deformation of a shape prior.

    Instance features come from the observed points (``instance_geometry``)
    and from image features sampled at ``choose`` (``instance_color``);
    category features come from the prior (``category_local`` /
    ``category_global``).

    Args:
        n_cat: number of categories; the non-implicit heads emit ``n_cat*3``
            channels and the ``cat_id`` slice is selected per sample.
        nv_prior: not used by this class.
        imp: use per-point ``nn.Linear`` heads that additionally receive the
            raw xyz (prior vertices / observed points) instead of
            per-category 1x1-conv heads.
        use_fold: replace ``nocs_prediction`` by ``FoldingNet(2176+3)``.
            NOTE(review): forward routes to FoldingNet only in the ``imp``
            branch; with ``imp=False`` the 2176-channel ``assign_feat`` would
            not match FoldingNet's 2179-channel input.
        depth_input: if False only the first two point coordinates are used.
        max_point: stored but not used by this class.
    """

    def __init__(self, n_cat=6, nv_prior=1024, imp=False, use_fold=False, depth_input=True, max_point=False):
        super(DeformNetV3, self).__init__()
        self.n_cat = n_cat
        self.depth_input = depth_input
        self.max_point = max_point
        self.use_fold = use_fold
        in_dim = 3 if self.depth_input else 2
        self.imp = imp
        self.instance_color = nn.Sequential(
            nn.Conv1d(32, 64, 1),
            nn.ReLU(),
        )
        self.instance_geometry = nn.Sequential(
            nn.Conv1d(in_dim, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.instance_global = nn.Sequential(
            nn.Conv1d(128, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU()
        )

        self.ap = nn.AdaptiveAvgPool1d(1)

        self.category_local = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.category_global = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU(),
        )

        # deform_feat: 64 + 1024 + 1024 = 2112 channels per prior vertex
        # assign_feat: 128 + 1024 + 1024 = 2176 channels per observed point
        if self.imp:
            self.pos_enc = PositionalEncoder(input_dim=3, max_freq_log2=9,
                                    N_freqs=10)

            self.deformation = nn.Sequential(
                nn.Linear(2112+3, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 3),
            )
            self.nocs_prediction = nn.Sequential(
                nn.Linear(2176+3, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 3),
            )
        else:
            self.deformation = nn.Sequential(
                nn.Conv1d(2112, 512, 1),
                nn.ReLU(),
                nn.Conv1d(512, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, n_cat*3, 1),
            )
            self.nocs_prediction = nn.Sequential(
                nn.Conv1d(2176, 512, 1),
                nn.ReLU(),
                nn.Conv1d(512, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, n_cat*3, 1),
            )
        if self.use_fold:
            self.nocs_prediction = FoldingNet(2176+3)
        # Initialize weights to be small so initial deformations aren't so big
        self.deformation[4].weight.data.normal_(0, 0.0001)

    def forward(self, points, out_img, choose, cat_id, prior):
        """
        Args:
            points: bs x n_pts x 3
            out_img: image feature map, bs x 32 x H x W (``instance_color``
                expects 32 channels)
            choose: bs x n_pts long indices into the flattened H*W map
            cat_id: bs
            prior: bs x nv x 3

        Returns:
            nocs: bs x n_pts x 3 predicted NOCS coordinates (with ``use_fold``
                this is FoldingNet's output, bs x grid_size**2 x 3)
            deltas: bs x nv x 3 deformation of the prior
            point_feat: bs x 1152 x n_pts per-point instance features
        """
        if self.n_cat == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)
        bs, n_pts = points.size()[:2]
        nv = prior.size()[1]
        # instance-specific features
        points = points.permute(0, 2, 1)
        if self.depth_input:
            p_emb = self.instance_geometry(points)
        else:
            p_emb = self.instance_geometry(points[:, :2])
        di = out_img.size()[1]
        emb = out_img.view(bs, di, -1)
        choose = choose.unsqueeze(1).repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()
        emb = self.instance_color(emb)
        inst_local = torch.cat((p_emb, emb), dim=1)     # bs x 128 x n_pts
        inst_global_p = self.instance_global(inst_local)    # bs x 1024 x n_pts
        inst_global = self.ap(inst_global_p) # bs x 1024 x 1
        point_feat = torch.cat([inst_local, inst_global_p], dim=1)  # bs x 1152 x n_pts

        # category-specific features
        cat_prior = prior.permute(0, 2, 1)
        cat_local = self.category_local(cat_prior)    # bs x 64 x nv
        cat_global_p = self.category_global(cat_local)  # bs x 1024 x nv
        cat_global = self.ap(cat_global_p) # bs x 1024 x 1
        # assignment / NOCS features
        assign_feat = torch.cat((inst_local, inst_global.repeat(1, 1, n_pts), cat_global.repeat(1, 1, n_pts)), dim=1)     # bs x 2176 x n_pts
        # deformation field
        deform_feat = torch.cat((cat_local, cat_global.repeat(1, 1, nv), inst_global.repeat(1, 1, nv)), dim=1)       # bs x 2112 x nv
        # row of each sample's category in the (bs*n_cat) flattened head output
        index = cat_id + torch.arange(bs, dtype=torch.long, device=cat_id.device)* self.n_cat
        if self.imp:
            # per-vertex MLP on [features, xyz]
            deform_feat = torch.cat([deform_feat, cat_prior], dim=1)
            deform_feat = deform_feat.permute(0, 2, 1).contiguous()
            deform_feat = deform_feat.reshape(-1, deform_feat.shape[-1])
            deltas = self.deformation(deform_feat)
            deltas = deltas.reshape(bs, -1, 3).contiguous() # bs, nv, 3
            assign_feat = torch.cat([assign_feat, points], dim=1)
            if self.use_fold:
                nocs = self.nocs_prediction(assign_feat)
            else:
                assign_feat = assign_feat.permute(0, 2, 1).contiguous()
                assign_feat = assign_feat.reshape(-1, assign_feat.shape[-1])
                nocs = self.nocs_prediction(assign_feat)
                nocs = nocs.reshape(bs, -1, 3).contiguous() # bs, npt, 3
        else:
            deltas = self.deformation(deform_feat)
            deltas = deltas.view(-1, 3, nv).contiguous()   # bs, nc*3, nv -> bs*nc, 3, nv
            deltas = torch.index_select(deltas, 0, index)   # bs x 3 x nv
            deltas = deltas.permute(0, 2, 1).contiguous()   # bs x nv x 3
            nocs = self.nocs_prediction(assign_feat)
            nocs = nocs.view(-1, 3, n_pts).contiguous()   # bs, nc*3, n_pts -> bs*nc, 3, n_pts
            nocs = torch.index_select(nocs, 0, index)   # bs x 3 x n_pts
            nocs = nocs.permute(0, 2, 1).contiguous()    # bs x n_pts x 3

        return nocs, deltas, point_feat

    def pos_encoding_sin_wave(self, coor):
        """Sinusoidal (NeRF-style) positional encoding.

        ``coor`` is min-max normalised to [-1, 1] using the min/max over the
        whole tensor (not per batch element). Each value x is then mapped to
        sin(pi * 2**k * x) and cos(pi * 2**k * x) for k in [0, 64).

        Args:
            coor: B x C x N tensor.

        Returns:
            B x (C * 128) x N tensor.
        """
        # ref to https://arxiv.org/pdf/2003.08934v2.pdf
        D = 64  # number of frequencies
        lo, hi = coor.min(), coor.max()
        # a constant input used to give 0/0 = NaN; clamp the span instead
        span = (hi - lo).clamp_min(torch.finfo(torch.float32).eps)
        normal_coor = 2 * ((coor - lo) / span) - 1

        # frequencies pi * 2**k, created on coor's device (was hard-coded .cuda())
        freqs = torch.arange(D, dtype=torch.float, device=coor.device)
        freqs = np.pi * (2**freqs)

        freqs = freqs.view(*[1]*len(normal_coor.shape), -1) # 1 x 1 x 1 x D
        normal_coor = normal_coor.unsqueeze(-1) # B x C x N x 1
        k = normal_coor * freqs # B x C x N x D
        x = torch.cat([torch.sin(k), torch.cos(k)], -1) # B x C x N x 2D
        pos = x.transpose(-1, -2).reshape(coor.shape[0], -1, coor.shape[-1]) # B x 2DC x N
        return pos


class FoldingNet(nn.Module):
    """Two-stage FoldingNet decoder conditioned on a max-pooled global feature.

    Args:
        encoder_channel: channel count of the per-point input to ``forward``.
        num_pred: requested number of output points.  A 2-D grid of
            ``grid_size x grid_size`` seeds is folded, with
            ``grid_size = round(sqrt(num_pred))``.  The actual output count is
            therefore ``grid_size ** 2``, which equals ``num_pred`` only for
            perfect squares.
    """

    def __init__(self, encoder_channel=1024, num_pred=1024):
        super(FoldingNet, self).__init__()
        self.num_pred = num_pred
        # width of the projected global feature (output of self.conv), not the
        # input width passed as ``encoder_channel``
        self.encoder_channel = 1024
        self.grid_size = int(pow(self.num_pred,0.5) + 0.5)

        self.conv = nn.Sequential(
            nn.Conv1d(encoder_channel, 1024, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(1024, 1024, 1)
        )

        self.folding1 = nn.Sequential(
            nn.Conv1d(1024 + 2, 512, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1),
        )

        self.folding2 = nn.Sequential(
            nn.Conv1d(1024 + 3, 512, 1),
            nn.ReLU(),
            nn.Conv1d(512, 512, 1),
            nn.ReLU(),
            nn.Conv1d(512, 3, 1),
        )

        a = torch.linspace(-0.5, 0.5, steps=self.grid_size, dtype=torch.float).view(1, self.grid_size).expand(self.grid_size, self.grid_size).reshape(1, -1)
        b = torch.linspace(-0.5, 0.5, steps=self.grid_size, dtype=torch.float).view(self.grid_size, 1).expand(self.grid_size, self.grid_size).reshape(1, -1)
        # 1 x 2 x grid_size**2 grid seed.  Registered as a non-persistent buffer
        # so it follows module.to()/.cuda() and construction no longer requires
        # CUDA (it used to call .cuda() here).  persistent=False keeps the
        # state_dict keys unchanged, so existing checkpoints still load.
        self.register_buffer(
            'folding_seed',
            torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2),
            persistent=False,
        )

    def forward(self, point_feat):
        """Map a B x encoder_channel x N feature to B x grid_size**2 x 3 points."""
        feature = self.conv(point_feat)
        feature_global = torch.max(feature,dim=2,keepdim=False)[0] # B 1024
        # folding decoder
        fd1, fd2 = self.decoder(feature_global) # B N 3
        return fd2 # FoldingNet producing final result directly

    def decoder(self,x):
        """Fold the seed grid twice conditioned on ``x`` (B x 1024).

        Returns the intermediate and final foldings, each B x grid_size**2 x 3.
        """
        num_sample = self.grid_size * self.grid_size
        bs = x.size(0)
        features = x.view(bs, self.encoder_channel, 1).expand(bs, self.encoder_channel, num_sample)
        seed = self.folding_seed.view(1, 2, num_sample).expand(bs, 2, num_sample).to(x.device)

        x = torch.cat([seed, features], dim=1)
        fd1 = self.folding1(x)
        x = torch.cat([fd1, features], dim=1)
        fd2 = self.folding2(x)

        return fd1.transpose(2,1).contiguous() , fd2.transpose(2,1).contiguous()


class PoseNetV3(nn.Module):
    """Pose network: PSPNet encoder -> NOCS/deformation head -> PoseHeadNetV2.

    ``opts.version == 'v4'`` selects the GCN3D-based DeformNetV4 (whose
    1792-channel per-point features feed the pose head); any other version
    uses DeformNetV3.  With ``select_class == 'all'`` the heads cover 6
    categories, otherwise one.
    """

    def __init__(self, opts):
        super(PoseNetV3, self).__init__()
        self.opts = opts
        self.use_nocs_map = opts.use_nocs_map
        n_obj = 6 if opts.select_class == 'all' else 1
        self.num_obj = n_obj
        self.with_recon = opts.with_recon
        self.encoder = PSPNet(bins=(1, 2, 3, 6), backend='resnet18')
        if opts.version != 'v4':
            self.deform_head = DeformNetV3(n_cat=n_obj, imp=opts.implict, use_fold=opts.use_fold,
                                           depth_input=not opts.use_rgb, max_point=opts.max_point)
            self.pose_head = PoseHeadNetV2(opts.n_pts, n_obj, with_nocs=opts.use_nocs_map,
                                           max_point=opts.max_point)
        else:
            self.deform_head = DeformNetV4(n_cat=n_obj, imp=opts.implict,
                                           depth_input=not opts.use_rgb, max_point=opts.max_point)
            self.pose_head = PoseHeadNetV2(opts.n_pts, n_obj, in_dim=1792, with_nocs=opts.use_nocs_map,
                                           max_point=opts.max_point)

    def forward(self, points, img, choose, cat_id, prior):
        """Run the pipeline and return a dict of outputs.

        Keys: ``'pose'`` (pose head output), ``'assign_mat'`` (predicted NOCS
        from the deformation head), ``'deltas'`` (prior deformation),
        ``'feat_map'`` (encoder features flattened to bs x C x H*W) and
        ``'feat_pix'`` (those features gathered at ``choose``).
        """
        if self.num_obj == 1:
            # single-category model: every sample uses head slot 0
            cat_id = torch.zeros_like(cat_id)
        bs = img.shape[0]
        _, psp_feat = self.encoder(img)
        nocs, deltas, point_feat = self.deform_head(points, psp_feat, choose, cat_id, prior)

        head_inputs = [psp_feat, point_feat, choose, cat_id]
        if self.use_nocs_map:
            # detached, channels-first copy of the predicted NOCS
            head_inputs.append(nocs.detach().clone().permute(0, 2, 1).contiguous())
        pose = self.pose_head(*head_inputs)

        n_ch = psp_feat.size(1)
        feat_map = psp_feat.view(bs, n_ch, -1)
        pix_index = choose.unsqueeze(1).repeat(1, n_ch, 1)
        feat_pix = torch.gather(feat_map, 2, pix_index).contiguous()

        return {
            'pose': pose,
            'assign_mat': nocs,
            'deltas': deltas,
            'feat_map': feat_map,
            'feat_pix': feat_pix,
        }

class DeformNetV4(nn.Module):
    """Predict per-point NOCS and a prior deformation using a GCN3D backbone.

    Instance features come from ``GCN3D_segR`` applied to the observed points
    together with image features gathered at ``choose``. Global features are
    max-pooled over points/vertices.

    NOTE(review): ``instance_color``, ``ap``, ``nv_prior``, ``depth_input``
    (beyond computing ``in_dim``) and ``max_point`` are not used by
    ``forward``. Presumably the modules are kept so existing checkpoints still
    load — confirm before removing.
    """
    def __init__(self, n_cat=6, nv_prior=1024, imp=False, depth_input=True,  max_point=False):
        super(DeformNetV4, self).__init__()
        self.n_cat = n_cat
        self.depth_input = depth_input
        self.max_point = max_point
        in_dim = 3 if self.depth_input else 2
        self.imp = imp
        self.instance_color = nn.Sequential(
            nn.Conv1d(32, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.ReLU(),
        )
        self.instance_geometry = GCN3D_segR()
            
        self.ap = nn.AdaptiveAvgPool1d(1)

        self.category_local = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.category_global = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.ReLU(),
        )
        # 2816 = 1792 (GCN3D fused instance feature) + 1024 (category feature);
        # the implicit (imp) heads also receive the 3 raw coordinates.
        if self.imp:
            self.deformation = nn.Sequential(
                nn.Linear(2816+3, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 3),
            )
            self.nocs_prediction = nn.Sequential(
                nn.Linear(2816+3, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 3),
            )
        else:
            # per-category heads: n_cat*3 channels, sliced by cat_id in forward
            self.deformation = nn.Sequential(
                nn.Conv1d(2816, 512, 1),
                nn.ReLU(),
                nn.Conv1d(512, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, n_cat*3, 1),
            )
            self.nocs_prediction = nn.Sequential(
                nn.Conv1d(2816, 512, 1),
                nn.ReLU(),
                nn.Conv1d(512, 256, 1),
                nn.ReLU(),
                nn.Conv1d(256, n_cat*3, 1),
            )

        # Initialize weights to be small so initial deformations aren't so big
        self.deformation[4].weight.data.normal_(0, 0.0001)

    def forward(self, points, out_img, choose, cat_id, prior):
        """
        Args:
            points: bs x n_pts x 3 (passed to GCN3D_segR as vertices)
            out_img: image feature map, bs x 32 x H x W (GCN3D_segR's
                rgb_conv_1 expects 32 channels)
            choose: bs x n_pts long indices into the flattened H*W map
            cat_id: bs
            prior: bs x nv x 3

        Returns:
            nocs: bs x n_pts x 3 predicted NOCS coordinates
            deltas: bs x nv x 3 deformation of the prior
            inst_global_p: bs x 1792 x n_pts fused per-point features from
                GCN3D_segR

        """
        if self.n_cat == 1:
            cat_id = torch.zeros_like(cat_id, device=cat_id.device)
        bs, n_pts = points.size()[:2]
        nv = prior.size()[1]
        # instance-specific features: gather image features at the chosen pixels
        di = out_img.size()[1]
        emb = out_img.view(bs, di, -1)
        choose = choose.unsqueeze(1).repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()
        inst_local, inst_global_p = self.instance_geometry(points, emb)  # bs x 1280 x n_pts, bs x 1792 x n_pts
        inst_global = torch.max(inst_global_p, dim=-1, keepdim=True)[0]  # bs x 1792 x 1
        # category-specific features
        cat_prior = prior.permute(0, 2, 1)
        cat_local = self.category_local(cat_prior)    # bs x 64 x nv
        cat_global_p = self.category_global(cat_local)  # bs x 1024 x nv
        cat_global = torch.max(cat_global_p, dim=-1, keepdim=True)[0]  # bs x 1024 x 1
        # assignment / NOCS features
        assign_feat = torch.cat((inst_global_p, cat_global.repeat(1, 1, n_pts)), dim=1)     # bs x 2816 x n_pts
        # deformation field
        deform_feat = torch.cat((cat_global_p, inst_global.repeat(1, 1, nv)), dim=1)       # bs x 2816 x nv
        # row of each sample's category in the (bs*n_cat) flattened head output
        index = cat_id + torch.arange(bs, dtype=torch.long, device=cat_id.device) * self.n_cat
        
        # cat_prior = self.pos_encoding_sin_wave(cat_prior)
        # points = self.pos_encoding_sin_wave(points)
        if self.imp:
            # per-vertex / per-point MLPs on [features, xyz]
            deform_feat = torch.cat([deform_feat, cat_prior], dim=1)
            deform_feat = deform_feat.permute(0, 2, 1).contiguous()
            deform_feat = deform_feat.reshape(-1, deform_feat.shape[-1])
            deltas = self.deformation(deform_feat)
            deltas = deltas.reshape(bs, -1, 3).contiguous() # bs, nv, 3
            assign_feat = torch.cat([assign_feat, points.permute(0,2,1)], dim=1)
            assign_feat = assign_feat.permute(0, 2, 1).contiguous()
            assign_feat = assign_feat.reshape(-1, assign_feat.shape[-1])
            nocs = self.nocs_prediction(assign_feat)
            nocs = nocs.reshape(bs, -1, 3).contiguous() # bs, npt, 3
        else:
            deltas = self.deformation(deform_feat)
            deltas = deltas.view(-1, 3, nv).contiguous()   # bs, nc*3, nv -> bs*nc, 3, nv
            deltas = torch.index_select(deltas, 0, index)   # bs x 3 x nv
            deltas = deltas.permute(0, 2, 1).contiguous()   # bs x nv x 3
            nocs = self.nocs_prediction(assign_feat)
            nocs = nocs.view(-1, 3, n_pts).contiguous()   # bs, nc*3, n_pts -> bs*nc, 3, n_pts
            nocs = torch.index_select(nocs, 0, index)   # bs x 3 x n_pts
            nocs = nocs.permute(0, 2, 1).contiguous()    # bs x n_pts x 3
        
        return nocs, deltas, inst_global_p

class GCN3D_segR(nn.Module):
    """3D-GCN point backbone fused with per-point RGB features.

    Two graph-convolution stages with 4x pooling each. Pooled features are
    mapped back to the full point set by nearest-neighbour lookup, and a
    max-pooled global feature is appended.

    Args:
        support_num: support count passed to every gcn3d conv layer.
        neighbor_num: k for the k-NN graph (capped after each pooling at
            1/8 of the remaining vertex count).

    NOTE(review): ``rgb_conv_2`` is constructed but not used by ``forward``.
    """

    def __init__(self, support_num= 7, neighbor_num= 10):
        super(GCN3D_segR, self).__init__()
        self.neighbor_num = neighbor_num

        self.conv_0 = gcn3d.Conv_surface(kernel_num= 64, support_num= support_num)
        self.conv_1 = gcn3d.Conv_layer(128, 128, support_num= support_num)
        self.pool_1 = gcn3d.Pool_layer(pooling_rate= 4, neighbor_num= 4)
        self.conv_2 = gcn3d.Conv_layer(128, 256, support_num= support_num)
        self.conv_3 = gcn3d.Conv_layer(256, 256, support_num= support_num)
        self.pool_2 = gcn3d.Pool_layer(pooling_rate= 4, neighbor_num= 4)
        self.conv_4 = gcn3d.Conv_layer(256, 512, support_num= support_num)

        self.rgb_conv_1 = nn.Sequential(
            nn.Conv1d(32, 64, 1),
            nn.ReLU(),
            nn.BatchNorm1d(64),
        )
        self.rgb_conv_2 = nn.Sequential(
            nn.Conv1d(128, 256, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 256, 1),
            nn.ReLU(),
            nn.BatchNorm1d(256)
        )

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(256)

    def forward(self,
                vertices: "tensor (bs, vertice_num, 3)",
                rgb_f):
        """Compute per-point fused geometry + RGB features.

        Args:
            vertices: bs x N x 3 point coordinates.
            rgb_f: bs x 32 x N per-point image features (``rgb_conv_1``
                expects 32 input channels).

        Returns:
            feat: bs x 1280 x N, the concatenation of fm_0..fm_4
                (128 + 128 + 256 + 256 + 512 channels).
            fm_fuse: bs x 1792 x N, ``feat`` plus the 512-channel global
                feature repeated at every point.
        """

        bs, vertice_num, _ = vertices.size()

        neighbor_index = gcn3d.get_neighbor_index(vertices, self.neighbor_num)
        fm_0 = F.relu(self.conv_0(neighbor_index, vertices), inplace= True)

        # append 64 RGB channels to the 64 surface-conv channels -> 128
        rgb_f = self.rgb_conv_1(rgb_f)
        fm_0 = torch.cat([fm_0, rgb_f.permute(0, 2, 1)], dim=-1)

        # BatchNorm1d normalises dim 1, so features are moved there and back
        fm_1 = F.relu(self.bn1(self.conv_1(neighbor_index, vertices, fm_0).transpose(1,2)).transpose(1,2), inplace= True)
        v_pool_1, fm_pool_1 = self.pool_1(vertices, fm_1)
        # cap k so the k-NN graph stays valid on the reduced vertex set
        neighbor_index = gcn3d.get_neighbor_index(v_pool_1,
                                                  min(self.neighbor_num, v_pool_1.shape[1] // 8))
        fm_2 = F.relu(self.bn2(self.conv_2(neighbor_index, v_pool_1, fm_pool_1).transpose(1,2)).transpose(1,2), inplace= True)
        fm_3 = F.relu(self.bn3(self.conv_3(neighbor_index, v_pool_1, fm_2).transpose(1,2)).transpose(1,2), inplace= True)
        v_pool_2, fm_pool_2 = self.pool_2(v_pool_1, fm_3)
        neighbor_index = gcn3d.get_neighbor_index(v_pool_2, min(self.neighbor_num,
                                                                     v_pool_2.shape[1] // 8))
        fm_4 = self.conv_4(neighbor_index, v_pool_2, fm_pool_2)
        f_global = fm_4.max(1)[0] #(bs, f)

        # bring pooled-level features back to all N input vertices
        nearest_pool_1 = gcn3d.get_nearest_index(vertices, v_pool_1)
        nearest_pool_2 = gcn3d.get_nearest_index(vertices, v_pool_2)
        fm_2 = gcn3d.indexing_neighbor(fm_2, nearest_pool_1).squeeze(2)
        fm_3 = gcn3d.indexing_neighbor(fm_3, nearest_pool_1).squeeze(2)
        fm_4 = gcn3d.indexing_neighbor(fm_4, nearest_pool_2).squeeze(2)
        f_global = f_global.unsqueeze(1).repeat(1, vertice_num, 1)

        feat = torch.cat([fm_0, fm_1, fm_2, fm_3, fm_4], dim= 2)
        fm_fuse = torch.cat([fm_0, fm_1, fm_2, fm_3, fm_4, f_global], dim= 2)
        return feat.permute(0, 2, 1), fm_fuse.permute(0, 2, 1)