import torch
from torchvision.models.resnet import *
import torch.nn.functional as F
import torch.nn as nn
from matplotlib import pyplot as plt
from layers import *
import math
import time
from utils import *
import h5py
# import cupy_module.adacof as adacof

class Model(nn.Module):
    def __init__(self, config: dict):
        """Build the feature encoder and, for ``model_type == 'pred'``, the motion-graph predictor.

        Args:
            config: experiment configuration dict. Keys read directly here include
                'bilinear', 'prev_len', 'fut_len', 'long_term', 'batch', 'total_len',
                'mat_size', 'edge_num', 'n_channel', 'shuffle_scale', 'recon_type',
                'method' and 'model_type'; the 'pred' model additionally reads the
                graph / attention / composition options used below.

        Only ``recon_type == 'RRDB'`` is supported; any other value prints a
        message and calls ``exit()``.

        NOTE(review): graph_construct() reads ``self.matrix_mask`` and
        ``self.indices_lookup``, and forward() calls ``self.decoder``; none of
        these are assigned in this constructor — presumably set elsewhere; confirm.
        """
        super(Model, self).__init__()
        self.config = config
        self.bilinear = config['bilinear']
        self.prev_len = config['prev_len']
        self.fut_len = config['fut_len']
        self.available_pred_len = 1 if config['long_term'] else config['fut_len'] # currently only support pred 1 future frame at a time
        # Integer zero buffer of shape (batch, total_len, mat_h*mat_w, edge_num, 2),
        # handed to build_graph_edge* as ``coord_cuda`` in graph_construct().
        # NOTE(review): the device is hard-coded to 'cuda' here while other tensors
        # use config['device'] — confirm this is intended.
        self.coord_cuda = torch.zeros((config['batch'],config['total_len'],config['mat_size'][-1][0]*config['mat_size'][-1][1],self.config['edge_num'],2), dtype=int).to('cuda')
        self.n_channel = config['n_channel']
        if self.config['shuffle_scale'] > 1:
            # forward() unshuffles the input image and shuffles the reconstruction back.
            self.unshuffle = nn.PixelUnshuffle(self.config['shuffle_scale'])
            self.shuffle = nn.PixelShuffle(self.config['shuffle_scale'])

        # Feature encoder: only the RRDB variant is implemented.
        if config['recon_type'] == 'RRDB':
            self.encoder = RRDBEncoder(config)
            
        
        else:
            print(config['recon_type']+ ' is not implemented for '+config['method'])
            exit()

        

        if config['model_type'] == 'pred':
            # self.moduleAdaCoF = adacof.FunctionAdaCoF.apply
            if config['learnable_sim']:
                # One Conv3x3 + BatchNorm + LeakyReLU projection per scale in use.
                # forward() applies it to the encoder features before building the
                # similarity graph. scale_num walks over the last scale_in_use of
                # the len(downsample_scale)+1 encoder levels; the channel count
                # assumes base_channel * 2**level (TODO confirm against RRDBEncoder).
                self.learnable_sim_proj = []
                for i in range(self.config['scale_in_use']):
                    scale_num = len(self.config['downsample_scale'])+1-self.config['scale_in_use']+i
                    feat_len = self.config['base_channel']*(2**(scale_num))

                    # sim_feat_len = self.config['base_channel'] * (2**(i))
                    sim_feat_len = feat_len  # NOTE(review): unused
                    self.learnable_sim_proj.append(nn.Sequential(nn.Conv2d(feat_len,feat_len,kernel_size=3,padding=1), 
                        nn.BatchNorm2d(feat_len),
                        nn.LeakyReLU()))
                self.learnable_sim_proj = nn.ModuleList(self.learnable_sim_proj)
            
            # Local-window mask over the coarsest grid (w x h = mat_size[-1]).
            # After the reshape, entry (i*h + j, m*h + n) is True when cell (m, n)
            # lies in the window around cell (i, j).
            # NOTE(review): the upper bounds min(i+l, w) / min(j+l, h) are exclusive,
            # so the window covers [i-l, i+l-1] — asymmetric around (i, j); confirm
            # whether i+l+1 was intended. Also, graph_construct() in this file reads
            # self.matrix_mask rather than self.window_mask.
            if config['window_length'] < 0:
                self.window_mask = None
            else:
                w,h = config['mat_size'][-1]
                l = config['window_length']//2
                self.window_mask = torch.zeros(w,h,w,h).to(config['device'])
                for i in range(w):
                    for j in range(h):
                        for m in range(max(0,i-l),min(i+l,w)):
                            for n in range(max(0,j-l),min(j+l,h)):
                                self.window_mask[i][j][m][n] = 1.
                self.window_mask = self.window_mask.reshape(w*h,w*h).bool()
                
            
            #Motion indices: per-node (x, y, t) grid coordinates; the t channel is zeroed.
            # NOTE(review): the original comment says "range 0~1", but forward() divides
            # these coordinates by (mat_size-1) to normalise positions, which suggests
            # grid units — confirm.
            self.motion_indices = motion_node_indices(self.config).to(self.config['device']) #1,T,HW,3 (x,y,t), range 0~1
            self.motion_indices[...,2] = 0

                

            # if ('recon' in self.config['loss_list']) or ('weighted_recon' in self.config['loss_list']) or ('lap' in self.config['loss_list']):
                
                


            # node encoder: MLP from the flattened (edge_num x 3) node motion to
            # pred_base_channel features. Only created when pred_base_channel > 0.
            if self.config['pred_base_channel'] > 0:
                encoder_feat_len =self.config['pred_base_channel']
                self.node_encoder = nn.Sequential(nn.Linear(3 * self.config['edge_num'],encoder_feat_len),
                nn.GroupNorm(1,encoder_feat_len),
                nn.LeakyReLU(),
                nn.Linear(encoder_feat_len,encoder_feat_len),
                nn.GroupNorm(1,encoder_feat_len),
                nn.LeakyReLU(),
                nn.Linear(encoder_feat_len,encoder_feat_len),
                nn.GroupNorm(1,encoder_feat_len),
                nn.LeakyReLU(),
                )
            # tendency encoder: per-edge 3-vector -> tdc_len features; forward()
            # pools these over the edge dimension (max or mean, per config['tdc_pool']).
            self.tdc_len = config['tendency_len']
            if (self.config['tendency_len'] > 0):
                
                self.tdc_encoder = nn.Sequential(nn.Linear(3 ,self.tdc_len),
                nn.GroupNorm(1,self.tdc_len),
                nn.LeakyReLU(),
                nn.Linear(self.tdc_len,self.tdc_len),
                nn.GroupNorm(1,self.tdc_len),
                nn.LeakyReLU(),
                nn.Linear(self.tdc_len,self.tdc_len),
                nn.GroupNorm(1,self.tdc_len),
                nn.LeakyReLU(),
                )
            # position encoder: normalised 2-D node position -> pos_len features.
            self.pos_len = self.config['pos_len']
            if (self.config['pos_len'] > 0):
                
                self.pos_encoder = nn.Sequential(nn.Linear(2 ,self.pos_len),
                nn.GroupNorm(1,self.pos_len),
                nn.LeakyReLU(),
                nn.Linear(self.pos_len,self.pos_len),
                nn.GroupNorm(1,self.pos_len),
                nn.LeakyReLU()
                )
            

            # graph attention for motion prediction.
            # Per scale and per iteration: one spatial layer, a forward-temporal layer
            # (if 'forward' is in edge_list) and, if 'backward' is in edge_list, an extra
            # spatial layer plus a backward-temporal layer. forward() therefore indexes
            # the spatial list with i or 2*i / 2*i+1. A final spatial layer (a Conv3d
            # stack when spatial_conv, otherwise GraphAtt) is appended and called via [-1].
            spatial_att_list = []
            temporal_forward_att_list = []
            temporal_backward_att_list = []
            for i in range(self.config['scale_in_use']):
                spatial_att = []
                temporal_forward_att = []
                temporal_backward_att = []
                for j in range(config['pred_att_iter_num']):
                    if self.config['spatial_conv']:
                        spatial_att.append(SpatialAtt(config, edge_type = 'spatial'))
                    else:
                        spatial_att.append(GraphAtt(config, edge_type = 'spatial'))
                    if 'forward' in self.config['edge_list']:
                        temporal_forward_att.append(GraphAtt(config, edge_type = 'forward'))
                    if 'backward' in self.config['edge_list']:
                        if self.config['spatial_conv']:
                            spatial_att.append(SpatialAtt(config, edge_type = 'spatial'))
                        else:
                            spatial_att.append(GraphAtt(config, edge_type = 'spatial'))

                        temporal_backward_att.append(GraphAtt(config, edge_type = 'backward'))
                if self.config['spatial_conv']:
                    
                        # Node feature width = node encoder + tendency + position features.
                        decoder_len = self.config['pred_base_channel'] + self.tdc_len + self.pos_len
                        spatial_att.append(nn.Sequential(
                            nn.Conv3d(decoder_len,decoder_len, kernel_size=(3,3,3), stride=(1,1,1), padding='same'),
                            nn.BatchNorm3d(decoder_len),
                            nn.LeakyReLU(),
                            nn.Conv3d(decoder_len,decoder_len, kernel_size=(3,3,3), stride=(1,1,1), padding='same'),
                            nn.BatchNorm3d(decoder_len),
                            nn.LeakyReLU()
                        ))
                else:
                    spatial_att.append(GraphAtt(config, edge_type = 'spatial'))
                spatial_att_list.append(nn.ModuleList(spatial_att))
                temporal_forward_att_list.append(nn.ModuleList(temporal_forward_att))
                temporal_backward_att_list.append(nn.ModuleList(temporal_backward_att))

            self.spatial_att_list = nn.ModuleList(spatial_att_list)
            self.temporal_forward_att_list = nn.ModuleList(temporal_forward_att_list)
            self.temporal_backward_att_list = nn.ModuleList(temporal_backward_att_list)

            
            # motion decoder: node features -> out_edge_num * 3 values per node.
            decoder_len = self.config['pred_base_channel'] + self.tdc_len + self.pos_len
            self.motion_decoder = nn.Sequential(nn.Linear(decoder_len,decoder_len),
            nn.GroupNorm(1,decoder_len),
            nn.LeakyReLU(),
            nn.Linear(decoder_len,self.config['pred_base_channel'] ),
            nn.GroupNorm(1,self.config['pred_base_channel']),
            nn.LeakyReLU(),
            nn.Linear(self.config['pred_base_channel'] , config['out_edge_num'] * 3))
            self.sigmoid = nn.Sigmoid()

            #multiscale motion fuse: the kernel depth equals scale_in_use with no depth
            # padding, so the stacked-scale dimension collapses to 1.
            if self.config['motion_fuse']:
                self.motion_fuse = nn.Sequential(
                    nn.Conv3d(decoder_len,decoder_len, kernel_size=(self.config['scale_in_use'],3,3), stride=(1,1,1), padding=(0,1,1)),
                    nn.BatchNorm3d(decoder_len),
                    nn.LeakyReLU()
                )
            if self.config['motion_upsample']:
                self.motion_indices_upsample = motion_node_indices_upsample(self.config) #1,T,HW,3 (x,y,t), range 0~1
                # NOTE(review): n_channel and feat_len below are unused locals.
                n_channel = config['n_channel'] * (config['shuffle_scale']**2)
                feat_len = config['base_channel']
                self.motion_upsampler = MotionDecoder(config)
                # The upsampler replaces the MLP motion decoder in this mode.
                self.motion_decoder = None

        
            #Post GAT
            if self.config['post_gat']:
                self.post_gat = PostGAT(config)

            # graph attention for composition (only without multiflow composition).
            # One ModuleList per scale in use. The first list gets the largest feature
            # width (i runs from scale_in_use-1 down to 0). compose() indexes this list
            # by the running count of non-None features.
            if not self.config['multiflow_compose']:
                att = []
                img_feat_len = self.config['base_channel']*(2**(len(config['downsample_scale'])))
                for i in range(config['scale_in_use']-1,-1,-1): 
                    att_cur_scale = []
                    for j in range(config['compose_att_iter_num']):
                        
                        att_cur_scale.append(GraphAtt(config,img_feat = img_feat_len * (2**i),edge_type = 'compose'))
                    att.append(nn.ModuleList(att_cur_scale))
                self.compose_att = nn.ModuleList(att)
            

            # feature composition: for encoder level i, compose_scale[i] is the product
            # of downsample_scale[i:], i.e. the pixel-(un)shuffle factor that maps that
            # level onto the coarsest grid. The coarsest level itself gets factor 1 and
            # has no shuffle module.

            self.compose_shuffle = []
            self.compose_unshuffle = []
            self.compose_scale = []
            for i in range(len(config['downsample_scale'])):
                feat_shuffle_scale = 1
                for s in range(len(config['downsample_scale'])-1,i-1,-1):
                    feat_shuffle_scale *= config['downsample_scale'][s]
                
                self.compose_scale.append(feat_shuffle_scale)
                self.compose_shuffle.append(nn.PixelShuffle(feat_shuffle_scale))
                self.compose_unshuffle.append(nn.PixelUnshuffle(feat_shuffle_scale))

            self.compose_scale.append(1)
            self.compose_shuffle = nn.ModuleList(self.compose_shuffle)
            self.compose_unshuffle = nn.ModuleList(self.compose_unshuffle)
            # if self.config['rrdb_enhance_num'] > 0 and config['res_cat_img']:
            #     self.enhancer = ImageEnhancer(config=self.config,n_channels=self.config['n_channel'] , base_channel=self.config['base_channel'] //2 )

                # exit()
    def graph_construct(self,sim_feat,B,T):
        """Build the motion graph for one feature scale.

        Args:
            sim_feat: features of shape (B*T, C, h, w). They are reshaped to
                (B, T, C, h, w) before the similarity matrices are computed.
            B: batch size.
            T: number of frames per sequence.

        Returns:
            Tuple ``(edge_list, weight_list, gt_motion, node_init, weight_map)``:
              * edge_list / weight_list: dicts keyed by every name in
                ``config['edge_list']``. The 'spatial' entries are None when
                ``config['spatial_conv']`` is set, because that path uses
                convolutions instead of graph edges.
              * gt_motion / weight_map: the X / Y outputs of the 'gt_graph' edge
                builder, or None when 'gt_graph' is not in the edge list.
              * node_init: the X output of the 'forward' edge builder, or None
                when 'forward' is not in the edge list.
        """
        h = sim_feat.shape[2]
        w = sim_feat.shape[3]
        sim_feat = sim_feat.reshape(B, T, -1, h, w)

        mat_hw = self.config['mat_size'][-1][0] * self.config['mat_size'][-1][1]
        gt_motion = None
        weight_map = None
        # Only the 'forward' edge type produces node_init. Without this default the
        # return statement raised UnboundLocalError when 'forward' was absent.
        node_init = None

        def _uses_conv(mat_name):
            # The spatial graph is replaced by convolutions when spatial_conv is on.
            return mat_name == 'spatial' and self.config['spatial_conv']

        '''
        Calculate similarity matrices (skipping ones that are never consumed)
        '''
        mat_list = {}
        for mat_name in self.config['edge_list']:
            if _uses_conv(mat_name):
                continue
            if self.config['masked_matrix']:
                mat_list[mat_name] = build_similarity_matrix_masked(sim_feat.clone(),self.config,mat_type=mat_name,window_mask = self.matrix_mask.clone() )
            else:
                mat_list[mat_name] = build_similarity_matrix(sim_feat.clone(),self.config,mat_type=mat_name)

        '''
        Build motion graph
        '''
        edge_list = {}
        weight_list = {}
        for mat_name in self.config['edge_list']:
            if _uses_conv(mat_name):
                edge_list[mat_name] = None
                weight_list[mat_name] = None
                continue
            is_gt = (mat_name == 'gt_graph')
            if self.config['masked_topk']:
                edge,weight,X,Y = build_graph_edge_masked(mat_list[mat_name],self.config,gt=is_gt,coord_cuda=self.coord_cuda,
                    matrix_mask = self.matrix_mask.clone(),indices_lookup = self.indices_lookup.clone())
            else:
                edge,weight,X,Y = build_graph_edge(mat_list[mat_name],self.config,gt=is_gt,coord_cuda=self.coord_cuda)

            if self.config['edge_normalize']:
                weight = edge_normalize(edge,weight,B,mat_hw)

            if is_gt:
                gt_motion = X.clone() #normalize coords
                weight_map = Y.clone()
            elif mat_name == 'forward':
                node_init = X.clone() #normalize coords

            edge_list[mat_name] = edge.clone()
            weight_list[mat_name] = weight.clone()

        return edge_list,weight_list,gt_motion,node_init,weight_map

    def compose(self,graph_feat,edge,weight):
        """Predict future-frame features at every scale with compose graph attention.

        The spatial size (h, w) of the last entry of ``graph_feat`` is the reference
        grid. For each non-None entry ``graph_feat[i]`` of shape (cur_B*T, C, H_i, W_i):
          1. If it is not already h x w, it is resized (when needed) to
             (h*compose_scale[i], w*compose_scale[i]) and pixel-unshuffled by
             ``compose_unshuffle[i]`` onto the h x w grid.
          2. Its first ``prev_len`` frames are kept and ``fut_len`` all-zero frames
             are appended.
          3. ``compose_att_iter_num`` layers of ``compose_att[k]`` (k = running index
             over non-None entries) are applied with ``edge[i]`` / ``weight[i]``.
          4. The last ``fut_len`` frames are taken and, for rescaled entries,
             shuffled / resized back to (H_i, W_i).

        Returns:
            A list with len(config['downsample_scale'])+1 slots. A slot is None where
            graph_feat[i] is None; otherwise it holds a (cur_B*fut_len, ., H_i, W_i)
            tensor (the channel count matches C_i if the attention layers preserve
            channels — assumption, TODO confirm).
        """
        reference = graph_feat[-1]
        h = reference.shape[2]
        w = reference.shape[3]

        prev_len = self.config['prev_len']
        fut_len = self.config['fut_len']
        iterations = self.config['compose_att_iter_num']
        outputs = [None] * (len(self.config['downsample_scale']) + 1)

        att_scale = 0
        for idx, feat in enumerate(graph_feat):
            if feat is None:
                continue
            x = feat.clone()
            _, _, in_h, in_w = x.shape
            rescaled = (in_h, in_w) != (h, w)

            # Bring this scale onto the reference grid.
            if rescaled:
                factor = self.compose_scale[idx]
                shuffle_size = (h * factor, w * factor)
                if (in_h, in_w) != shuffle_size:
                    x = F.interpolate(x, shuffle_size)
                x = self.compose_unshuffle[idx](x.clone())

            chans = x.shape[1]

            # Keep observed frames and append zero placeholders for future frames.
            x = x.reshape(self.cur_B, -1, chans, h, w)
            placeholder = torch.zeros(self.cur_B, fut_len, chans, h, w).to(self.config['device'])
            x = torch.cat([x[:, :prev_len], placeholder], dim=1)

            layers = self.compose_att[att_scale]
            for step in range(iterations):
                x = layers[step](x, edge[idx], weight[idx])

            future = x[:, -fut_len:].reshape(self.cur_B * fut_len, chans, h, w).clone()

            # Undo the grid alignment for rescaled entries.
            if rescaled:
                future = self.compose_shuffle[idx](future.clone())
                if (in_h, in_w) != shuffle_size:
                    future = F.interpolate(future, (in_h, in_w)).clone()

            outputs[idx] = future
            att_scale += 1

        return outputs

    
    def multiflow_compose(self,graph_feat,flow_list,img=False):
        """Warp the observed frames of every scale with its predicted multi-flow.

        When ``img`` is False, the spatial size (h, w) of the last entry of
        ``graph_feat`` is the reference grid. Each non-None entry that is not
        already h x w is resized (when needed) and pixel-unshuffled onto that grid,
        exactly as in ``compose``. When ``img`` is True, every entry is processed at
        its own resolution and no shuffling happens.

        For each entry, the first ``prev_len`` frames are warped by
        ``multi_warp(feat, flow_list[i], last_only=config['last_only'])``. The result
        is shuffled / resized back to the entry's original size where it was
        rescaled. ``flow_list[i]`` must be 5-D; only its leading (batch) dimension
        is used here.

        Returns:
            A list with len(graph_feat) slots: None where graph_feat[i] is None,
            otherwise the warped features.
        """
        reference = graph_feat[-1]
        h = reference.shape[2]
        w = reference.shape[3]
        warped = [None] * len(graph_feat)

        for idx, feat in enumerate(graph_feat):
            if feat is None:
                continue
            x = feat.clone()
            flow = flow_list[idx].clone()

            if img:
                # Image-level composition works at each entry's native resolution.
                h = feat.shape[2]
                w = feat.shape[3]

            batch, _, _, _, _ = flow.shape
            _, _, in_h, in_w = x.shape
            rescaled = (in_h, in_w) != (h, w)

            if rescaled:
                factor = self.compose_scale[idx]
                shuffle_size = (h * factor, w * factor)
                if (in_h, in_w) != shuffle_size:
                    x = F.interpolate(x, shuffle_size)
                if not img:
                    x = self.compose_unshuffle[idx](x.clone())

            chans = x.shape[1]
            observed = x.reshape(batch, -1, chans, h, w)[:, :self.prev_len]

            result = multi_warp(observed, flow, last_only=self.config['last_only']).clone()

            if rescaled:
                if not img:
                    result = self.compose_shuffle[idx](result.clone())
                if (in_h, in_w) != shuffle_size:
                    result = F.interpolate(result, (in_h, in_w)).clone()

            warped[idx] = result

        return warped

    def long_term_forward(self, input_image):
        """Roll the model out autoregressively for ``self.fut_len`` steps.

        The rollout starts from the first ``self.prev_len`` frames of
        ``input_image`` (shape (B, T, H, W, C)). Each step:
          - calls ``self.forward(window, inference=True)``;
          - reshapes its ``'recon_img'`` to (B, -1, H, W, C) and records it;
          - drops the oldest frame of the window and appends the prediction.

        Returns:
            dict with key ``'recon_img'``: all predictions concatenated along dim 1.
        """
        batch, _, height, width, chans = input_image.shape
        window = input_image[:, :self.prev_len].clone()

        predictions = []
        for _ in range(self.fut_len):
            step_out = self.forward(window, inference=True)
            frame = step_out['recon_img'].reshape(batch, -1, height, width, chans)
            predictions.append(frame.clone())
            window = torch.cat((window[:, 1:], frame), dim=1)

        return {'recon_img': torch.cat(predictions, dim=1)}
            

    def forward(self, input_image,inference=False,visualization=False):

        ori_input = input_image.clone()
        output_list = {}
        
        if self.config['model_type'] == 'recon':
            B, H, W, C = input_image.shape
            input = input_image.permute(0, 3, 1, 2).clone()
            if self.config['shuffle_scale'] > 1:
                input = self.unshuffle(input)
                
            emb_feat_list = self.encoder(input)  # N, C, H, W

            recon_img = self.decoder(emb_feat_list)
            if self.config['shuffle_scale'] > 1:
                recon_img = self.shuffle(recon_img)

            if recon_img.shape[-2] != H or recon_img.shape[-1] != W:
                recon_img = F.interpolate(recon_img,(H,W))
            
        
            output_list['recon_img'] = recon_img.permute(0, 2, 3, 1)
            return output_list

        elif self.config['model_type'] == 'pred':

            '''
            spatial feature extraction
            '''
            # input_image = input_image.unsqueeze(0)

            t = time.time()
            start_time = t
            # input_image = input_image.unsqueeze(0)
            B, T, H, W, C = input_image.shape

            # T, H, W, C = input_image.shape
            # B = 1
            self.cur_B = B
            ori_input_image = input_image.clone()
            input_image = input_image.reshape(-1, H, W, C)  # B*T,H,W,C
            input_image = input_image.permute(0, 3, 1, 2)
            
            if self.config['shuffle_scale'] > 1:
                input_image = self.unshuffle(input_image)
            input_image_raw = input_image.clone()
            raw_img_wh = input_image_raw.shape[-2:]
            t = time.time()
            emb_feat_list = self.encoder(input_image) # N, C, H, W

            # print('encode times: ',time.time()-t,' s')
            t = time.time()

            

            #-----------#
            
            '''
            sim matrix calculation
            '''
            output_list['gt_motion'] = []
            output_list['pred_motion'] = []
            edge_list = []
            weight_list = []
            node_init_list = []
            non  = 0
            for i in range(len(self.config['downsample_scale'])+1):
                if emb_feat_list[i] is None:
                    output_list['gt_motion'].append(None)
                    edge_list.append(None)
                    weight_list.append(None)
                    node_init_list.append(None)
                    non += 1
                    continue
                cur_feat = emb_feat_list[i].clone()

                if self.config['learnable_sim']:
                    
                    sim_feat = self.learnable_sim_proj[i-non](cur_feat) #use lowest res features to construct graph 
                else:
                    sim_feat = cur_feat
                
                if i != len(self.config['downsample_scale']):
                    sim_feat = self.compose_unshuffle[i](sim_feat.clone())
            
            
                edge,weight,gt_motion,node_init,weight_map = self.graph_construct(sim_feat,B,T)
  
                if 'weighted_recon' in self.config['loss_list']:
                    repeat = weight_map.unsqueeze(1).repeat(1,4**(len(self.config['downsample_scale'])+1),1,1)
                    weight_map = self.compose_shuffle[0](repeat)
                    if self.config['shuffle_scale'] > 1:
                        weight_map = self.shuffle(weight_map)
                    
                    output_list['weight_map'] = weight_map.permute(0,2,3,1)
                # print(self.motion_indices.clone().shape)
                # exit()
                
                node_init[:,:,:,:,:2] -= (self.motion_indices.clone().repeat([B,1,1,1,1]))[:,:node_init.shape[1],:,:,:2] #record the offset
                node_init_list.append(torch.cat([node_init,torch.zeros_like(node_init[:,-1:])],dim=1)) #B,T,HW,K,3; Add maskd last frame info
 
                edge_list.append(edge)
                weight_list.append(weight)
                if 'gt_graph' in self.config['edge_list']:
                    output_list['gt_motion'].append(gt_motion.clone())

            
            # return edge_list
            if 'motion_kd' in self.config['loss_list']:
                output_list['motion_kd_teacher'] = []
                output_list['motion_kd_student'] = []
                for i in range(len(self.config['downsample_scale'])+1):
                    if not(output_list['gt_motion'][i] is None):
                        cur_motion = output_list['gt_motion'][i].clone()
                        b_,t_,hw_,k_,c_ = cur_motion.shape
                        cur_motion[:,:,:,:,:2] -= (self.motion_indices.clone().repeat([B,1,1,1,1]))[:,:cur_motion.shape[1],:,:,:2] #record the offset
                        
                        output_list['motion_kd_teacher'].append(self.node_encoder(cur_motion.reshape(b_,t_,hw_,k_*c_)))


            if visualization:
                output_list['edge_list'] = edge_list
                output_list['weight_list'] = weight_list
            
            '''
            Composition
            '''
            
            # if not inference:
            # if ('kd' in self.config['loss_list']) or ('gt_recon' in self.config['loss_list']):
            if ('gt_recon' in self.config['loss_list']):
                gt_edge_list = []
                gt_weight_list = []
                gt_flow_list = []
                start = 0
                for i in range(len(self.config['downsample_scale'])+1):
                    if emb_feat_list[i] is None:
                        gt_edge_list.append(None)
                        gt_weight_list.append(None)
                        gt_flow_list.append(None)
                        start += 1
                        continue
                    gt_motion = output_list['gt_motion'][i].clone()
                    if self.config['multiflow_compose']:
                        gt_motion[...,:2] -= (self.motion_indices.clone().repeat([B,1,1,1,1]))[...,:2]
                        # gt_motion[...,0] *= (self.config['mat_size'][-1][0]-1)
                        # gt_motion[...,1] *= (self.config['mat_size'][-1][1]-1)
                        gt_flow_list.append(gt_motion)
                    else:

                        gt_edge,gt_weight = transform_graph_edge(self.config,gt_motion[...,:2].clone(),gt_motion[...,2].clone(),gt=True)
                        
                        if self.config['last_only']:
                            gt_motion = gtedge_normalize_motion[:,-1]
                            gt_edge,gt_weight = gt_edge[gt_edge[:,1] == (self.config['prev_len']-1)],gt_weight[gt_edge[:,1] == (self.config['prev_len']-1)]

                        if self.config['edge_softmax']:
                            mat_hw = self.config['mat_size'][-1][0] *  self.config['mat_size'][-1][1]
                            gt_weight = edge_softmax(gt_edge,gt_weight,B,mat_hw)
                        elif self.config['edge_normalize']:
                            mat_hw = self.config['mat_size'][-1][0] *  self.config['mat_size'][-1][1]
                            gt_weight = edge_normalize(gt_edge,gt_weight,B,mat_hw)

                        gt_edge_list.append(gt_edge.clone())
                        gt_weight_list.append(gt_weight.clone())  
                if self.config['multiflow_compose']:
                    if self.config['motion_upsample']:
                        gt_feat_list = self.multiflow_compose([(emb_feat_list)],[gt_flow_list])
                    else:
                        gt_feat_list = self.multiflow_compose(emb_feat_list,gt_flow_list)
                else:
                    gt_feat_list = self.compose(emb_feat_list,gt_edge_list,gt_weight_list)

                if 'kd' in self.config['loss_list']:
                    output_list['teacher'] = []
                    output_list['student'] = []
                    for i in range(len(emb_feat_list)):
                        if not(emb_feat_list[i] is None):
                            B,C,H,W = gt_feat_list[i].shape
                            output_list['teacher'].append(emb_feat_list[i].reshape(B,-1,C,H,W)[:,-1:])
                            output_list['student'].append(gt_feat_list[i].reshape(B,-1,C,H,W))  


                if 'gt_recon' in self.config['loss_list']:
                
                    if self.config['post_gat']:
                        gt_feat_list = self.post_gat(gt_feat_list)
                    if self.config['motion_upsample']: 
                    #     gt_recon_img = self.decoder(gt_feat_list[0])
                    # else:
                        gt_recon_img = self.decoder(gt_feat_list)
                    
                    

                    if self.config['shuffle_scale'] > 1:
                        gt_recon_img = self.shuffle(gt_recon_img)
                    output_list['gt_recon_img'] = gt_recon_img.permute(0, 2, 3, 1).clone() # B, H,W,C

                    if not('recon' in self.config['loss_list']):
                        output_list['recon_img'] = output_list['gt_recon_img']
            else:
                # print('Current not available!')
                # exit()
                '''
                Motion prediction steps:
                1. Init node with indices & Node encoding
                2. Spatial - Temporal Interaction
                3. Motion Decodeing
                4. Weight Prediction
                '''

                start = 0
                pred_edge_list = []
                pred_weight_list = []
                pred_flow_list = []
                tendency_feat_list = []
                for j in range(len(emb_feat_list)):
                    if not(emb_feat_list[j] is None):
                        b_,t_,hw_,k_,c_ = node_init_list[j].shape
                        if self.config['pred_base_channel'] > 0:
                            init_node_feat = self.node_encoder(node_init_list[j].reshape(-1,k_*c_)).reshape(b_,t_,hw_,-1)
                        if (self.config['tendency_len'] > 0):
                            if self.config['tdc_pool'] == 'max':
                                tendency_feat = torch.max(self.tdc_encoder(node_init_list[j].reshape(-1,3)).reshape(b_,t_,hw_,k_,-1),dim=-2)[0]
                            elif self.config['tdc_pool'] == 'avg':
                                tendency_feat = torch.mean(self.tdc_encoder(node_init_list[j].reshape(-1,3)).reshape(b_,t_,hw_,k_,-1),dim=-2)
                            
                            tendency_feat_list.append(tendency_feat.clone())
                            if self.config['pred_base_channel'] > 0:
                                init_node_feat = torch.cat([init_node_feat,tendency_feat],dim=-1)
                            else:
                                init_node_feat = tendency_feat

                        if self.config['pos_len'] > 0:
                            normalize_pos = self.motion_indices.clone().repeat([B,1,1,1,1])[...,0,:2].reshape(-1,2)
                            normalize_pos[:,0] /= (self.config['mat_size'][-1][0]-1.)
                            normalize_pos[:,1] /= (self.config['mat_size'][-1][1]-1.)
                            pos_id = self.pos_encoder(normalize_pos).reshape(b_,t_,hw_,-1)

                            init_node_feat = torch.cat([pos_id,init_node_feat],dim=-1)
                        cur_node = init_node_feat.clone()
                        for i in range(self.config['pred_att_iter_num']):
                            if 'spatial' in self.config['edge_list']:
                                idx = i if ((self.config['last_only'])  or ('backward' not in self.config['edge_list']))else i*2
                                cur_node = self.spatial_att_list[j-start][idx](cur_node,edge_list[j]['spatial'],weight_list[j]['spatial'])
                            if 'forward' in self.config['edge_list']:
                                cur_node = self.temporal_forward_att_list[j-start][i](cur_node,edge_list[j]['forward'],weight_list[j]['forward'],position=self.motion_indices[...,0,:2].clone())
                            
                            if 'backward' in self.config['edge_list']:
                                if 'spatial' in self.config['edge_list']:
                                    cur_node = self.spatial_att_list[j-start][i*2+1](cur_node,edge_list[j]['spatial'],weight_list[j]['spatial'])
                                cur_node = self.temporal_backward_att_list[j-start][i](cur_node,edge_list[j]['backward'],weight_list[j]['backward'],position=self.motion_indices[...,0,:2].clone())
                        
                        # if self.config['last_only']:
                        #     cur_node = self.spatial_att_list[j-start][-1](cur_node,edge_list[j]['spatial'],weight_list[j]['spatial'])
                        # else:
                        w,h = self.config['mat_size'][-1]
                        cur_node = cur_node.reshape(b_,t_,h,w,-1).permute(0,4,1,2,3)
                        cur_node = self.spatial_att_list[j-start][-1](cur_node)
                        cur_node = cur_node.permute(0,2,3,4,1).reshape(b_,t_,hw_,-1)
                            
                        if 'motion_kd' in self.config['loss_list']:
                            output_list['motion_kd_student'].append(cur_node.clone())
                        if self.config['motion_fuse']:
                            pred_flow_list.append(cur_node.clone())
                        else:
                            pred_node_flow = self.motion_decoder(cur_node.reshape(b_*t_*hw_,-1)).reshape(b_,t_,hw_,k_,c_)
                            pred_node_flow[...,:,-1] = pred_node_flow[...,:,-1].exp() / (1. + torch.sum(pred_node_flow[...,:,-1].exp(),dim=-1).unsqueeze(-1))
                            pred_node_motion = pred_node_flow.clone()
                            
                            pred_node_motion[...,:2] = pred_node_motion[...,:2]+ (self.motion_indices.clone().repeat([B,1,1,1,1]))[...,:pred_node_flow.shape[-2],:2]
                        
                            output_list['pred_motion'].append(pred_node_motion.clone())
                            
                            if self.config['multiflow_compose']:
                                pred_node_flow[...,0] *= (self.config['mat_size'][-1][0]-1)
                                pred_node_flow[...,1] *= (self.config['mat_size'][-1][1]-1)
                                pred_flow_list.append(pred_node_flow.clone())
                            else:
                                pred_node_motion[...,:2] = torch.clamp(pred_node_motion[...,:2].clone(),min=0,max=1)
                                pred_edge,pred_weight = transform_graph_edge(self.config,pred_node_motion[...,:2].clone(),pred_node_motion[...,2].clone())
                                pred_edge_list.append(pred_edge)
                                pred_weight_list.append(pred_weight)
                        

                    else:
                        output_list['pred_motion'].append(None)
                        pred_edge_list.append(None)
                        pred_weight_list.append(None)
                        pred_flow_list.append(None)
                        start += 1

                output_list['tendency_feat_list'] = tendency_feat_list

                if self.config['motion_fuse']:
                    valid_flow = []
                    for i in range(len(pred_flow_list)):
                        if pred_flow_list[i] is None:
                            continue
                        else:
                            valid_flow.append(pred_flow_list[i].clone())
                    
                    multi_scale_motion = torch.stack(valid_flow,dim=2)
                    b_,t_,s_,hw_,c_ = multi_scale_motion.shape
                    h,w = self.config['mat_size'][-1]

                    multi_scale_motion = multi_scale_motion.reshape(b_*t_,s_,h,w,c_).permute(0,4,1,2,3)
                    fused_motion = self.motion_fuse(multi_scale_motion).squeeze(2)
                    
                    if self.config['motion_upsample']:
                        output_list['pred_motion'] = []
                        pred_flow_list = []
                        pred_node_flow_list = self.motion_upsampler(fused_motion)
                        length  = len(pred_node_flow_list)
                        for f_id in range(len(pred_node_flow_list)):
                            pred_node_flow = pred_node_flow_list[length-f_id-1] # scale in pred_node_flow_list is from small to large
                            pred_node_flow = pred_node_flow.reshape(b_,t_,self.config['out_edge_num'],3,-1).permute(0,1,4,2,3)
                            pred_node_flow[...,:,-1] = pred_node_flow[...,:,-1].exp() / (1. + torch.sum(pred_node_flow[...,:,-1].exp(),dim=-1).unsqueeze(-1))
                            pred_node_motion = pred_node_flow.clone()
                            if self.config['high_res_only']:
                                pred_node_motion[...,:2] = pred_node_motion[...,:2]+ (self.motion_indices_upsample[0].clone().repeat([B,1,1,1,1]))[...,:pred_node_flow.shape[-2],:2]                        
                            elif 'motion_loss' in self.config['loss_list']:
                                pred_node_motion[...,:2] = pred_node_motion[...,:2]+ (self.motion_indices_upsample[0].clone().repeat([B,1,1,1,1]))[...,:pred_node_flow.shape[-2],:2]                        
                            else:

                                pred_node_motion[...,:2] = pred_node_motion[...,:2]+ (self.motion_indices_upsample[0].clone().repeat([B,1,1,1,1]))[...,:pred_node_flow.shape[-2],:2]                        
                            pred_flow_list.append(pred_node_flow.clone())
                            output_list['pred_motion'].append(pred_node_motion.clone())

                        
                    else:
                        fused_motion = fused_motion.reshape(b_,t_,c_,hw_).permute(0,1,3,2)
                        pred_node_flow = self.motion_decoder(fused_motion.reshape(b_*t_*hw_,-1)).reshape(b_,t_,hw_,k_,-1)
                        pred_node_flow[...,:,-1] = pred_node_flow[...,:,-1].exp() / (1. + torch.sum(pred_node_flow[...,:,-1].exp(),dim=-1).unsqueeze(-1))
                        pred_node_motion = pred_node_flow.clone()
                        pred_node_motion[...,:2] = pred_node_motion[...,:2]+ (self.motion_indices.clone().repeat([B,1,1,1,1]))[...,:pred_node_flow.shape[-2],:2]
                        for i in range(len(pred_flow_list)):
                            output_list['pred_motion'] = []
                            if pred_flow_list[i] is None:
                                output_list['pred_motion'].append(None)
                                continue
                            else:
                                pred_flow_list[i] = pred_node_flow.clone()
                                output_list['pred_motion'].append(pred_node_motion.clone())
                    
                    
                
                if 'motion_loss' in self.config['loss_list']:
                    if self.config['multiflow_compose']:
                        shuffle_image_list = [input_image_raw.clone() for i in range(len(self.config['downsample_scale'])+1)]

                        warped_image_list = self.multiflow_compose(shuffle_image_list,pred_flow_list,img=True)
                        output_list['warped_img'] = []
                        for i in range(len(self.config['downsample_scale'])+1):
                            output_list['warped_img'].append(warped_image_list[i].clone())
                            if self.config['shuffle_scale'] > 1:
                                output_list['warped_img'][i] = self.shuffle(output_list['warped_img'][i])
                            
                            output_list['warped_img'][i] = output_list['warped_img'][i].permute(0,2,3,1)    
            
                else:
                    shuffle_image_list = []
                    shuffle_image_list.append(input_image_raw.clone())
                    warped_image_list = self.multiflow_compose(shuffle_image_list,[pred_flow_list[0]],img=True)
                    output_list['warped_img'] = []
                    output_list['warped_img'].append(warped_image_list[0].clone())
                    if self.config['shuffle_scale'] > 1:
                        output_list['warped_img'][0] = self.shuffle(output_list['warped_img'][0])
                    output_list['warped_img'][0] = output_list['warped_img'][0].permute(0,2,3,1)

                output_list['recon_img'] = output_list['warped_img'][0]
            return output_list








