import numpy as np
import os
from matplotlib import pyplot as plt
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from skimage.color import rgb2yuv
import torch.nn.functional as F
import math
import torchvision
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
import cv2
import lpips
import random
from torch_geometric.utils import softmax as group_softmax
from torch_geometric.utils import scatter as geo_scatter
import torchist
import softsplat
from pytorch_msssim import ms_ssim
from torch.nn.utils import spectral_norm
import torchvision.models as models

# Module-level default CUDA device object.
device = torch.device("cuda")
# Module-level placeholder, initialised to None.
grid = None

# NOTE(review): this rebinds the name `lpips` from the imported module to an
# AlexNet-based LPIPS model instance that is moved to the GPU at import time;
# after this line the `lpips` module itself is no longer reachable by name.
lpips = lpips.LPIPS(net='alex').cuda()
# Module-level list, initially empty.
metrics_to_save = []

class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution that normalises or de-normalises channels.

    With ``norm=True`` each channel is mapped to ``(x - data_range * mean) / std``;
    with ``norm=False`` to ``x * std + data_range * mean``.

    Args:
        data_mean: per-channel means (sequence of length C).
        data_std: per-channel standard deviations (sequence of length C).
        data_range: factor applied to the mean (e.g. 1 or 255).
        norm: True to normalise, False to undo the normalisation.
    """

    def __init__(self, data_mean, data_std, data_range=1, norm=True):
        c = len(data_mean)
        super(MeanShift, self).__init__(c, c, kernel_size=1)
        std = torch.Tensor(data_std)
        self.weight.data = torch.eye(c).view(c, c, 1, 1)
        if norm:
            self.weight.data.div_(std.view(c, 1, 1, 1))
            self.bias.data = -1 * data_range * torch.Tensor(data_mean)
            self.bias.data.div_(std)
        else:
            self.weight.data.mul_(std.view(c, 1, 1, 1))
            self.bias.data = data_range * torch.Tensor(data_mean)
        # Assigning ``self.requires_grad`` on a Module only creates a plain
        # attribute and freezes nothing; the attribute is kept for
        # compatibility, and the real parameters are frozen explicitly.
        self.requires_grad = False
        for param in self.parameters():
            param.requires_grad = False

def unravel_index(indices, shape, coord_cuda=None):
    r"""Convert flat indices into unraveled coordinates in a target shape.

    Adapted from https://github.com/pytorch/pytorch/issues/35674#issuecomment-739492875
    This is a ``torch`` implementation of ``numpy.unravel_index``.

    Args:
        indices: integer tensor of flat indices, shape ``(*, N)``. Values are
            wrapped modulo ``prod(shape)`` to prevent out-of-range coordinates.
        shape: the target shape ``(D,)`` (sequence of ints or 1-D tensor).
        coord_cuda: optional pre-allocated buffer. Its leading block of size
            ``indices.shape + (D,)`` is cloned and filled, so the result takes
            the buffer's device and dtype. When ``None``, a fresh ``long``
            tensor is allocated on ``indices.device``.

    Returns:
        Unraveled coordinates with shape ``(*, N, D)``.
    """
    # Follow the buffer's device when one is given, otherwise stay on the
    # device of ``indices``; never hard-code 'cuda'.
    device = coord_cuda.device if coord_cuda is not None else indices.device
    shape = torch.as_tensor(shape, device=device)
    indices = indices.to(device) % torch.prod(shape)  # prevent out-of-bounds indices
    coord_size = tuple(indices.size()) + tuple(shape.size())
    if coord_cuda is None:
        coord = torch.zeros(coord_size, dtype=torch.long, device=device)
    else:
        # Leading block of the buffer matching the output, for any index rank.
        coord = coord_cuda[tuple(slice(0, s) for s in coord_size)].clone()
    for i, dim in enumerate(reversed(shape)):
        coord[..., i] = indices % dim
        indices = torch.div(indices, dim, rounding_mode='trunc')

    return coord.flip(-1)

def ravel_index(indices,shape):
    r"""Convert 2-D coordinates into flat (raveled) indices for a target shape.

    Roughly the inverse of ``unravel_index`` for a two-dimensional shape,
    i.e. a ``torch`` counterpart of ``numpy.ravel_multi_index``.

    Args:
        indices: tensor of coordinates, shape ``(*, N, 2)``; the last axis
            holds ``(row, col)``.
        shape: the target shape ``(rows, cols)``; only ``shape[1]`` is used.

    Returns:
        ``long`` tensor of flat indices with shape ``(*, N)``.
    """
    flat = indices[..., 0] * shape[1] + indices[..., 1]
    # Adding 0.5 before truncating rounds non-negative float inputs to the
    # nearest integer rather than flooring values such as 6.9999.
    return (flat + 0.5).long()
    


def augmentation_pred(config,prev_frames_tensor,fut_frames_tensor):
    """Randomly flip/rotate past and future frame tensors in lock-step.

    The same random transform is applied to both tensors so they stay aligned.

    * ``config['flip_aug']``: with probability 0.5 flip along dim 2; when
      ``config['name']`` contains ``'ucf'`` also flip along dim 3 with
      probability 0.5.
    * ``config['rot_aug']``: with probability 0.5 rotate by ``k`` quarter
      turns in the plane of dims (2, 3); ``k`` is drawn from {1, 2, 3} for
      ``'ucf'`` datasets and is 2 otherwise.

    Returns:
        The (possibly transformed) ``(prev_frames_tensor, fut_frames_tensor)``.
    """
    pair = (prev_frames_tensor, fut_frames_tensor)

    if config['flip_aug']:
        if random.uniform(0, 1) < 0.5:
            pair = tuple(torch.flip(x, dims=[2]) for x in pair)
        # ``and`` short-circuits, so the second draw only happens for UCF.
        if 'ucf' in config['name'] and random.uniform(0, 1) < 0.5:
            pair = tuple(torch.flip(x, dims=[3]) for x in pair)

    if config['rot_aug'] and random.uniform(0, 1) < 0.5:
        quarter_turns = random.randint(1, 3) if 'ucf' in config['name'] else 2
        pair = tuple(torch.rot90(x, dims=(2, 3), k=quarter_turns) for x in pair)

    return pair[0], pair[1]

def augmentation_recon(config,frame_tensor):
    """Randomly flip/rotate a single frame tensor.

    The single-tensor counterpart of ``augmentation_pred``: the tensor has one
    fewer leading dimension, so the spatial axes are dims 1 and 2
    (presumably ``(C, H, W)`` — TODO confirm against callers).

    * ``config['flip_aug']``: with probability 0.5 flip along dim 1; when
      ``config['name']`` contains ``'ucf'`` also flip along dim 2 with
      probability 0.5.
    * ``config['rot_aug']``: with probability 0.5 rotate by ``k`` quarter
      turns in the plane of dims (1, 2); ``k`` is drawn from {1, 2, 3} for
      ``'ucf'`` datasets and is 2 otherwise.

    Returns:
        The (possibly transformed) ``frame_tensor``.
    """
    # All transforms act on ``frame_tensor``, which is also the return value.
    # The earlier version referenced undefined prev/fut tensors and could
    # never return.
    if config['flip_aug']:
        if random.uniform(0, 1) < 0.5:
            frame_tensor = torch.flip(frame_tensor, dims=[1])
        if config['name'].find('ucf') > -1:
            if random.uniform(0, 1) < 0.5:
                frame_tensor = torch.flip(frame_tensor, dims=[2])

    if config['rot_aug']:
        if random.uniform(0, 1) < 0.5:
            k = random.randint(1, 3) if config['name'].find('ucf') > -1 else 2
            frame_tensor = torch.rot90(frame_tensor, dims=(1, 2), k=k)

    return frame_tensor


def feat_compose(source_feat,sim_matrix):
    '''Compose future-frame features as similarity-weighted sums of source features.

    With ``m``/``n`` the flattened ``h*w`` positions:
    ``fut[b, :, m] = sum_n source[b, :, n] * sim[b, m, n]``.

    :param source_feat: previous feats at time t, (B,c,h,w)
    :param sim_matrix: composition guide of time t to future t' (B,h,w,h,w)
    :return: fut_feat: composed feats for time t' (B,c,h,w)
    '''
    batch, channels, height, width = source_feat.shape
    n_pos = height * width
    flat_src = source_feat.reshape(batch, channels, n_pos)
    guide_t = sim_matrix.reshape(batch, n_pos, n_pos).transpose(1, 2)
    composed = torch.bmm(flat_src, guide_t)
    return composed.reshape(batch, channels, height, width)

def get_feature_patch(feat,kernel):
    """Extract a ``kernel``-sized neighbourhood around every pixel.

    The map is zero-padded by ``kernel // 2`` on each side so every pixel gets
    a full patch. Odd kernel sizes are assumed; even sizes make the final
    ``view`` fail with a size mismatch.

    Args:
        feat: tensor of shape (B, T, C, H, W).
        kernel: (kernel_h, kernel_w) window size.

    Returns:
        Tensor of shape (B, T, C, H, W, kernel_h, kernel_w); entry
        ``[..., y, x, i, j]`` is the padded input at ``(y + i, x + j)``.
    """
    batch, frames, chans, height, width = feat.shape
    k_h, k_w = kernel
    half_h, half_w = k_h // 2, k_w // 2

    flat = feat.reshape(batch * frames, chans, height, width)
    padded = F.pad(flat, (half_w, half_w, half_h, half_h), mode='constant')
    columns = F.unfold(padded, kernel_size=kernel)
    patches = columns.view(batch, frames, chans, k_h, k_w, height, width)
    return patches.permute(0, 1, 2, 5, 6, 3, 4).contiguous()

def fold_feature_patch(feat,kernel):
    """Fold per-pixel patches back into a feature map, averaging overlaps.

    Each patch is scattered back onto its window of a zero-padded canvas.
    Overlapping contributions are averaged, and the padding is cropped away.

    Args:
        feat: patches of shape (B, h, w, kernel_h, kernel_w, c).
        kernel: (kernel_h, kernel_w); expected to equal dims 3 and 4 of
            ``feat`` and to be odd.

    Returns:
        Feature map of shape (B, c, h, w).
    """
    B,h,w,h_w,w_w,c = feat.shape
    pad_h = h_w//2
    pad_w = w_w//2
    padded_size = (h + pad_h * 2, w + pad_w * 2)
    feat = feat.permute(0,5,3,4,1,2).reshape(B,c*h_w*w_w,h*w)
    weight = torch.ones_like(feat)
    # Fold values and the overlap counts onto the SAME padded canvas so they
    # line up pixel for pixel.
    feature_map = F.fold(feat, kernel_size=kernel, output_size=padded_size)
    weight = F.fold(weight, kernel_size=kernel, output_size=padded_size)
    # Crop with explicit end indices: a ``-0`` stop would produce an empty
    # tensor when the padding is zero (1x1 kernels).
    feature_map = feature_map[:, :, pad_h:pad_h + h, pad_w:pad_w + w]
    weight = weight[:, :, pad_h:pad_h + h, pad_w:pad_w + w]
    return feature_map / weight

def motion_node_indices(config,B=1,T=None):
    """Build ``(x, y, t)`` coordinates for every node of a space-time graph.

    The grid size comes from ``config['mat_size'][-1]``, and each node's
    coordinate is repeated once per edge slot (``config['edge_num']``). The
    coordinates are raw integer grid positions stored as floats; they are
    not normalised.

    Args:
        config: dict with ``'prev_len'``, ``'mat_size'``, ``'edge_num'`` and
            ``'device'`` entries.
        B: batch size.
        T: number of frames; defaults to ``config['prev_len']``.

    Returns:
        Tensor (B, T, res_x * res_y, edge_num, 3) on ``config['device']``.
    """
    n_frames = config['prev_len'] if T is None else T
    res_x = config['mat_size'][-1][0]
    res_y = config['mat_size'][-1][1]

    t_axis = torch.linspace(0, n_frames - 1, steps=n_frames)
    x_axis = torch.linspace(0, res_x - 1, steps=res_x)
    y_axis = torch.linspace(0, res_y - 1, steps=res_y)
    t_grid, x_grid, y_grid = torch.meshgrid(t_axis, x_axis, y_axis, indexing='ij')

    coords = torch.stack([x_grid.float(), y_grid.float(), t_grid], dim=-1).to(config['device'])
    coords = coords.unsqueeze(0).repeat([B, 1, 1, 1, 1])
    coords = coords.reshape(B, n_frames, -1, 1, 3)
    return coords.repeat([1, 1, 1, config['edge_num'], 1])

def motion_node_indices_upsample(config,B=1,T=None):
    """Build ``(x, y, t)`` node coordinates for a pyramid of grid resolutions.

    The finest level is ``config['in_res']``, halved when
    ``config['shuffle_scale'] == 2``. Each further level halves it again,
    giving ``len(config['downsample_scale']) + 1`` levels in total. Each
    node's coordinate is repeated ``config['out_edge_num']`` times.

    Args:
        config: dict with ``'prev_len'``, ``'in_res'``, ``'shuffle_scale'``,
            ``'downsample_scale'``, ``'out_edge_num'`` and ``'device'``.
        B: batch size.
        T: number of frames; defaults to ``config['prev_len']``.

    Returns:
        List of tensors, finest first, each of shape
        (B, T, res_x_i * res_y_i, out_edge_num, 3) on ``config['device']``.
    """
    n_frames = config['prev_len'] if T is None else T
    t_axis = torch.linspace(0, n_frames - 1, steps=n_frames)

    halve = config['shuffle_scale'] == 2
    base_x = config['in_res'][0] // 2 if halve else config['in_res'][0]
    base_y = config['in_res'][1] // 2 if halve else config['in_res'][1]
    n_levels = len(config['downsample_scale']) + 1

    levels = []
    for level in range(n_levels):
        rx = base_x // (2 ** level)
        ry = base_y // (2 ** level)
        x_axis = torch.linspace(0, rx - 1, steps=rx)
        y_axis = torch.linspace(0, ry - 1, steps=ry)
        t_grid, x_grid, y_grid = torch.meshgrid(t_axis, x_axis, y_axis, indexing='ij')
        coords = torch.stack([x_grid.float(), y_grid.float(), t_grid], dim=-1).to(config['device'])
        coords = coords.unsqueeze(0).repeat([B, 1, 1, 1, 1]).reshape(B, n_frames, -1, 1, 3)
        levels.append(coords.repeat([1, 1, 1, config['out_edge_num'], 1]))
    return levels

def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (B, C, H, W).
        flow (Tensor): size (B, H, W, 2) of per-pixel ``(dx, dy)``
            displacements in pixels.
        interp_mode (str): 'nearest' or 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
        align_corners (bool): forwarded to ``F.grid_sample``. The default True
            matches the ``2 * p / (size - 1) - 1`` normalisation below, so zero
            flow is an exact identity warp.

    Returns:
        Tensor: warped image or feature map, size (B, C, H, W).
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
    base_grid = torch.stack((grid_x, grid_y), 2).float().type_as(x)  # (H, W, 2) as (x, y)
    vgrid = base_grid + flow
    # Scale absolute pixel positions to [-1, 1] (corner-aligned convention).
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode,
                         align_corners=align_corners)




def build_similarity_matrix(emb_feats,config,mat_type='forward',window_mask=False):
    '''Build non-negative cosine-similarity matrices between frame embeddings.

    Embeddings are L2-normalised along the channel axis. Dot products are
    taken between all pairs of spatial positions, and negative values are
    set to 0.

    :param emb_feats: a sequence of embeddings for every frame (B,T,c,h,w)
    :param config: dict providing 'prev_len' (and 'fut_len' for 'gt_graph')
    :param mat_type: which frames are compared
        'spatial'  - each of the first prev_len frames with itself
                     -> (B, prev_len, h*w, h*w)
        'forward'  - frame t (rows) vs frame t+1 (cols), within prev_len
                     -> (B, prev_len-1, h*w, h*w)
        'backward' - frame t+1 (rows) vs frame t (cols), within prev_len
                     -> (B, prev_len-1, h*w, h*w)
        'gt_graph' - every observed frame vs every future frame
                     -> (B, prev_len*fut_len, h*w, h*w)
    :param window_mask: unused; kept for interface compatibility
    :raises ValueError: if mat_type is none of the above
    :return: similarity matrix as described for mat_type
    '''
    if mat_type not in ('spatial', 'forward', 'backward', 'gt_graph'):
        raise ValueError(
            "mat_type must be 'spatial', 'forward', 'backward' or 'gt_graph', "
            "got {!r}".format(mat_type))

    B,T,c,h,w = emb_feats.shape
    hw = h * w
    feats = emb_feats.permute(0,1,3,4,2)  # (B,T,h,w,c)
    # L2-normalise per position so dot products become cosine similarities.
    normalize_feats = feats / (torch.norm(feats,dim=-1,keepdim=True)+1e-6)
    prev_len = config['prev_len']

    if mat_type == 'spatial':
        cur_frame = normalize_feats[:, :prev_len].reshape(-1, hw, c)  # (B*prev_len, h*w, c)
        similarity_matrix = torch.bmm(cur_frame, cur_frame.permute(0,2,1)).reshape(B, prev_len, hw, hw)
    elif mat_type == 'forward':
        prev_frame = normalize_feats[:, :prev_len-1].reshape(-1, hw, c)  # (B*(prev_len-1), h*w, c)
        next_frame = normalize_feats[:, 1:prev_len].reshape(-1, hw, c)
        similarity_matrix = torch.einsum('bij,bjk->bik', prev_frame, next_frame.permute(0,2,1)).reshape(B, -1, hw, hw)
    elif mat_type == 'backward':
        prev_frame = normalize_feats[:, :prev_len-1].reshape(-1, hw, c)
        next_frame = normalize_feats[:, 1:prev_len].reshape(-1, hw, c)
        similarity_matrix = torch.einsum('bij,bjk->bik', next_frame, prev_frame.permute(0,2,1)).reshape(B, prev_len-1, hw, hw)
    else:  # 'gt_graph'
        fut_len = config['fut_len']
        # Pair every observed frame with every future frame.
        exist_frame = normalize_feats[:, :prev_len].reshape(B, prev_len, 1, hw, c).repeat((1, 1, fut_len, 1, 1)).reshape(-1, hw, c)
        gt_frame = normalize_feats[:, prev_len:].reshape(B, 1, fut_len, hw, c).repeat((1, prev_len, 1, 1, 1)).reshape(-1, hw, c)
        similarity_matrix = torch.einsum('bij,bjk->bik', exist_frame, gt_frame.permute(0,2,1)).reshape(B, prev_len * fut_len, hw, hw)

    # Keep only non-negative similarities.
    similarity_matrix[similarity_matrix < 0] = 0.
    return similarity_matrix


def build_similarity_matrix_masked(emb_feats,config,mat_type='forward',window_mask=False):
    '''Masked variant of the cosine-similarity builder.

    Only 'forward', 'backward' and 'gt_graph' are implemented, and they match
    the unmasked computation. Negative similarities are set to 0.

    :param emb_feats: a sequence of embeddings for every frame (B,T,c,h,w)
    :param config: dict providing 'prev_len' (and 'fut_len' for 'gt_graph')
    :param mat_type: 'forward', 'backward' or 'gt_graph'
        'forward'  -> (B, prev_len-1, h*w, h*w), frame t rows vs t+1 cols
        'backward' -> (B, prev_len-1, h*w, h*w), frame t+1 rows vs t cols
        'gt_graph' -> (B, prev_len*fut_len, h*w, h*w)
    :param window_mask: intended for the unfinished 'spatial' mode; ignored
    :raises NotImplementedError: for mat_type='spatial'
    :raises ValueError: for any other unknown mat_type
    :return: similarity matrix as described for mat_type
    '''
    if mat_type == 'spatial':
        # This branch used to print a debug shape and call exit(), killing
        # the whole process; fail loudly but recoverably instead.
        raise NotImplementedError(
            "build_similarity_matrix_masked: mat_type='spatial' is not implemented")
    if mat_type not in ('forward', 'backward', 'gt_graph'):
        raise ValueError(
            "mat_type must be 'forward', 'backward' or 'gt_graph', got {!r}".format(mat_type))

    B,T,c,h,w = emb_feats.shape
    hw = h * w
    feats = emb_feats.permute(0,1,3,4,2)  # (B,T,h,w,c)
    normalize_feats = feats / (torch.norm(feats,dim=-1,keepdim=True)+1e-6)
    prev_len = config['prev_len']

    if mat_type == 'forward':
        prev_frame = normalize_feats[:, :prev_len-1].reshape(-1, hw, c)  # (B*(prev_len-1), h*w, c)
        next_frame = normalize_feats[:, 1:prev_len].reshape(-1, hw, c)
        similarity_matrix = torch.einsum('bij,bjk->bik', prev_frame, next_frame.permute(0,2,1)).reshape(B, -1, hw, hw)
    elif mat_type == 'backward':
        prev_frame = normalize_feats[:, :prev_len-1].reshape(-1, hw, c)
        next_frame = normalize_feats[:, 1:prev_len].reshape(-1, hw, c)
        similarity_matrix = torch.einsum('bij,bjk->bik', next_frame, prev_frame.permute(0,2,1)).reshape(B, prev_len-1, hw, hw)
    else:  # 'gt_graph'
        fut_len = config['fut_len']
        exist_frame = normalize_feats[:, :prev_len].reshape(B, prev_len, 1, hw, c).repeat((1, 1, fut_len, 1, 1)).reshape(-1, hw, c)
        gt_frame = normalize_feats[:, prev_len:].reshape(B, 1, fut_len, hw, c).repeat((1, prev_len, 1, 1, 1)).reshape(-1, hw, c)
        similarity_matrix = torch.einsum('bij,bjk->bik', exist_frame, gt_frame.permute(0,2,1)).reshape(B, prev_len * fut_len, hw, hw)

    similarity_matrix[similarity_matrix < 0] = 0.
    return similarity_matrix


def build_similarity_matrix_efficient(emb_feats,thre=-1,sigmoid=False,k=-1,cut_off=False):
    '''Cosine similarity between all positions of consecutive frames.

    :param emb_feats: a sequence of embeddings for every frame (B,T,c,h,w)
    :param thre, sigmoid, k: forwarded to ``cut_off_process`` when ``cut_off``
    :param cut_off: if True, post-process with ``cut_off_process`` (defined
        elsewhere; not visible here)
    :return: similarity matrix (B, T-1, h*w, h*w); entry [b, t, i, j] is the
        cosine similarity of position i in frame t and position j in frame
        t+1. Negative values are not clamped.
    '''
    B,T,c,h,w = emb_feats.shape
    hw = h * w
    emb_feats = emb_feats.permute(0,1,3,4,2) #  (B,T,h,w,c)
    normalize_feats = emb_feats / (torch.norm(emb_feats,dim=-1,keepdim=True)+1e-6) #  (B,T,h,w,c)
    # One batched matmul per frame pair. This avoids materialising h*w copies
    # of every next frame, which cost O(h*w) times more memory for the same
    # result.
    prev_frame = normalize_feats[:, :T-1].reshape(-1, hw, c)  # (B*(T-1), h*w, c)
    next_frame = normalize_feats[:, 1:].reshape(-1, hw, c)    # (B*(T-1), h*w, c)
    similarity_matrix = torch.bmm(prev_frame, next_frame.transpose(1, 2)).reshape(B, T-1, hw, hw)

    if cut_off:
        similarity_matrix = cut_off_process(similarity_matrix,thre,sigmoid,k)

    return similarity_matrix


def build_img_patch_diff(emb_feats):
    '''Pair every position of frame t with every position of frame t+1.

    Despite the name, no difference is computed: the two feature vectors are
    concatenated along the channel axis.

    :param emb_feats: a sequence of embeddings for every frame (B,T,c,h,w)
    :return: tensor (B, T-1, 2*c, h*w, h*w); ``[..., :c, i, j]`` holds frame t
        at position i and ``[..., c:, i, j]`` holds frame t+1 at position j
    '''
    B, T, c, h, w = emb_feats.shape
    n_pos = h * w
    # Broadcast source positions along columns and target positions along rows.
    src = emb_feats[:, :-1].reshape(-1, c, n_pos, 1).expand(-1, -1, -1, n_pos)
    dst = emb_feats[:, 1:].reshape(-1, c, 1, n_pos).expand(-1, -1, n_pos, -1)
    pairs = torch.cat((src, dst), dim=1)
    return pairs.reshape(B, T - 1, 2 * c, n_pos, n_pos)


def retrieve_diag(similar_matrix,config):
    """Extract the self-similarity diagonal of each (h*w, h*w) matrix.

    Args:
        similar_matrix: tensor of shape (B, T, h*w, h*w).
        config: dict whose ``config['mat_size'][-1]`` gives ``(h, w)``.

    Returns:
        New tensor (not a view) of shape (B, T, h, w) where ``[b, t, y, x]`` is
        ``similar_matrix[b, t, y*w + x, y*w + x]``.
    """
    batch, frames, _, _ = similar_matrix.shape
    h = config['mat_size'][-1][0]
    w = config['mat_size'][-1][1]
    diagonal = torch.diagonal(similar_matrix, dim1=-2, dim2=-1)
    return diagonal.reshape(batch, frames, h, w).clone()

def sim_matrix_softmax(similar_matrix):
    """Apply softmax jointly over all entries of each (batch, time) matrix.

    Args:
        similar_matrix: tensor of shape (B, T, hw, hw).

    Returns:
        Tensor of the same shape whose entries for every ``[b, t]`` sum to 1.
    """
    batch, frames, _, n = similar_matrix.shape
    flat = similar_matrix.flatten(start_dim=2)
    probs = F.softmax(flat, dim=-1)
    return probs.reshape(batch, frames, n, n)


def transform_graph_edge(config,pred_motion,pred_weight,gt=False):
    """Turn predicted motion targets into a flat edge list.

    Args:
        config: dict providing ``'mat_size'`` (grid size read from ``[-1]``)
            and ``'device'``.
        pred_motion: tensor whose first four dims are (B, T, HW, K). The first
            two entries of the last axis are target coordinates, scaled here
            by ``grid_size - 1`` (so presumably normalised to [0, 1] — TODO
            confirm).
        pred_weight: edge weights with B*T*HW*K elements.
        gt: unused.

    Returns:
        edge: long tensor (B*T*HW*K, 4) with columns
            ``(b, t, source_node, raveled_target_node)``.
        weight: ``pred_weight`` flattened to 1-D.
    """
    B, T, hw, k_num = pred_motion.shape[:4]

    # Scale a copy so the caller's ``pred_motion`` is not modified in place
    # (in-place edits can also break autograd on tensors needed for backward).
    scaled = pred_motion.clone()
    scaled[..., 0] = scaled[..., 0] * (config['mat_size'][-1][0] - 1)
    scaled[..., 1] = scaled[..., 1] * (config['mat_size'][-1][1] - 1)

    # +0.5 then truncation rounds non-negative coordinates to the nearest cell.
    pred_motion_ravel = ravel_index((scaled + 0.5).long(), config['mat_size'][-1])

    b, t, n, k = torch.meshgrid(torch.arange(B), torch.arange(T), torch.arange(hw),
                                torch.arange(k_num), indexing='ij')
    edge = torch.stack([b.reshape(-1), t.reshape(-1), n.reshape(-1), k.reshape(-1)],
                       dim=-1).to(config['device'])
    # The last column becomes the raveled target node of each edge.
    edge[:, -1] = pred_motion_ravel.reshape(-1)
    weight = pred_weight.reshape(-1)

    return edge, weight

def edge_softmax(edge,weight,batch,hw):
    """Softmax edge weights within groups of edges sharing a target node.

    Edges are grouped by columns 0, 1 and 3 of ``edge`` (batch index, time
    index, target node). ``torch_geometric.utils.softmax`` then normalises
    ``weight`` within each group.

    Args:
        edge: integer tensor (E, 4).
        weight: tensor (E,) of raw edge scores.
        batch: size of the batch grouping axis.
        hw: number of target nodes per frame.

    Returns:
        Tensor (E,) of softmax-normalised weights.
    """
    n_steps = int(torch.max(edge[:, 1] + 1))
    group_key = torch.stack([edge[:, 0], edge[:, 1], edge[:, 3]], dim=-1)
    group_id = torchist.ravel_multi_index(group_key, shape=(batch, n_steps, hw))
    return group_softmax(weight, (group_id + 0.5).long())

def edge_normalize(edge,weight,batch,hw):
    """Normalise edge weights to sum to ~1 within groups sharing a target node.

    Edges are grouped by columns 0, 1 and 3 of ``edge`` (batch index, time
    index, target node). Each weight is divided by its group's sum plus 1e-6.

    Args:
        edge: integer tensor (E, 4).
        weight: tensor (E,) of non-negative edge weights.
        batch: size of the batch grouping axis.
        hw: number of target nodes per frame.

    Returns:
        New tensor (E,) of normalised weights; ``weight`` is left unchanged.
    """
    n_steps = int(torch.max(edge[:, 1] + 1))
    group_key = torch.stack([edge[:, 0], edge[:, 1], edge[:, 3]], dim=-1)
    group_id = (torchist.ravel_multi_index(group_key, shape=(batch, n_steps, hw)) + 0.5).long()
    group_sum = geo_scatter(weight, group_id, reduce='sum')
    # Out-of-place division: ``weight /= ...`` mutated the caller's tensor.
    return weight / (group_sum[group_id] + 1e-6)


def build_graph_edge(similarity_matrix,config,gt=False,coord_cuda=None):
    """Turn a similarity matrix into a top-k edge list for a motion graph.

    For every source node, the ``max(config['edge_num'], 1)`` most similar
    target nodes are kept.

    Args:
        similarity_matrix: tensor (B, T, HW, HW).
        config: dict providing ``'edge_num'``, ``'mat_size'`` and ``'device'``.
        gt: if True, also return ``weight_map = 1 - min over T`` of the
            per-node self-similarity given by ``retrieve_diag``.
        coord_cuda: buffer forwarded to ``unravel_index``.

    Returns:
        edge: long tensor (B*T*HW*K, 4) with columns
            ``(b, t, source_node, target_node)``.
        weight: tensor (B*T*HW*K,) of the top-k similarity values.
        motion_gt: tensor (B, T, HW, K, 3) of ``(x, y, weight)`` per target,
            with coordinates from ``unravel_index`` along ``mat_size[-1]``.
        weight_map: tensor (B, h, w) when ``gt``, otherwise ``None``.
    """
    dev = config['device']
    select_num = max(config['edge_num'], 1)
    top_k, indices = torch.topk(similarity_matrix, select_num, dim=-1)  # (B, T, HW, K)

    weight_map = None
    if gt:
        diag_element = retrieve_diag(similarity_matrix, config)
        weight_map = torch.min(diag_element, dim=1)[0] * (-1.) + 1.

    B, T, HW, K = top_k.shape
    b, t, n, k = torch.meshgrid(
        torch.arange(B, device=dev), torch.arange(T, device=dev),
        torch.arange(HW, device=dev), torch.arange(K, device=dev), indexing='ij')
    edge = torch.stack([b.reshape(-1), t.reshape(-1), n.reshape(-1), k.reshape(-1)], dim=-1)
    # The last column is replaced by the flat index of each selected target.
    edge[:, -1] = indices.reshape(-1)

    unraveled_indices = unravel_index(indices.clone(), config['mat_size'][-1], coord_cuda)
    motion_gt = torch.cat([unraveled_indices.float().to(dev), top_k.unsqueeze(-1)], dim=-1)  # B,T,HW,K,3
    weight = top_k.reshape(-1)

    return edge, weight, motion_gt, weight_map

def build_graph_edge_masked(similarity_matrix,config,gt=False,coord_cuda=None,matrix_mask=None,indices_lookup = None):
    """Top-k edge construction restricted to per-node candidate targets.

    Args:
        similarity_matrix: tensor (B, T, HW, HW).
        config: dict providing ``'edge_num'``, ``'mat_size'`` and ``'device'``.
        gt: unused; the fourth return value is always ``None``.
        coord_cuda: buffer forwarded to ``unravel_index``.
        matrix_mask: boolean mask, sliced to the first B entries. It must
            select the same number of candidates for every (b, t, node), so
            the result can be reshaped to (B, T, HW, -1). Presumably shaped
            like ``similarity_matrix`` — TODO confirm. Required.
        indices_lookup: tensor indexed by ``(b, t, node, candidate_slot)``
            that maps a candidate slot back to its flat target index; sliced
            to the first B entries. Required.

    Raises:
        TypeError: if ``matrix_mask`` or ``indices_lookup`` is missing.

    Returns:
        edge: long tensor (B*T*HW*K, 4) with columns
            ``(b, t, source_node, target_node)``.
        weight: tensor (B*T*HW*K,) of the top-k similarity values.
        motion_gt: tensor (B, T, HW, K, 3) of ``(x, y, weight)`` per target.
        None
    """
    if matrix_mask is None or indices_lookup is None:
        raise TypeError('build_graph_edge_masked requires matrix_mask and indices_lookup')

    dev = config['device']
    select_num = max(config['edge_num'], 1)
    B, T, HW = similarity_matrix.shape[:3]
    matrix_mask = matrix_mask[:B]
    indices_lookup = indices_lookup[:B]

    masked_similarity = similarity_matrix[matrix_mask].reshape(B, T, HW, -1)
    top_k, slot_idx = torch.topk(masked_similarity, select_num, dim=-1)  # (B, T, HW, K)
    K = top_k.shape[3]

    b, t, n, k = torch.meshgrid(
        torch.arange(B, device=dev), torch.arange(T, device=dev),
        torch.arange(HW, device=dev), torch.arange(K, device=dev), indexing='ij')
    # Map each selected candidate slot back to its flat target-node index.
    indices = indices_lookup[b.reshape(-1), t.reshape(-1), n.reshape(-1), slot_idx.reshape(-1)].reshape(B, T, -1, select_num)

    edge = torch.stack([b.reshape(-1), t.reshape(-1), n.reshape(-1), k.reshape(-1)], dim=-1)
    edge[:, -1] = indices.reshape(-1)

    unraveled_indices = unravel_index(indices.clone(), config['mat_size'][-1], coord_cuda)
    motion_gt = torch.cat([unraveled_indices.float().to(dev), top_k.unsqueeze(-1)], dim=-1)  # B,T,HW,K,3
    weight = top_k.reshape(-1)

    return edge, weight, motion_gt, None

# def multi_warp(img, flow):
#     '''
#     img: B,T,C,H,W
#     flow: B,T,HW,K,3
#     '''
#     B,T,C,H,W = img.shape
#     K = flow.shape[-2]

    
#     img = img.unsqueeze(2).repeat(1,1,K,1,1,1) #B,T,K,C,H,W
#     flow = flow.reshape(B,T,H,W,K,3).permute(0,1,4,5,2,3)


#     flow = flow.reshape(B*T*K,3,H,W)

#     img = img.reshape(B*T*K,C,H,W)
#     weight = flow[:,-1:] # BTK,1,H,W

#     BTK = B*T*K

#     xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(BTK, -1, H, -1)
#     yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(BTK, -1, -1, W)
#     grid = torch.cat([xx, yy], 1).to(img)

#     flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)
    
#     grid_ = (grid + flow_).permute(0, 2, 3, 1)
    

#     output = F.grid_sample(input=img * weight, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) # BTK, C, H, W
#     weight_map = F.grid_sample(input=torch.ones_like(img[:,-1:]) * weight, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
    
#     output = output.reshape(-1,T*K,C,H,W)
#     weight_map = weight_map.reshape(-1,T*K,1,H,W)
#     weight_map = torch.sum(weight_map,dim=1)
#     output = torch.sum(output,dim=1)
#     output /= (weight_map + 1e-6)

#     return output.reshape(-1,C,H,W)

def multi_warp(img, flow,last_only=False):
    '''Forward-splat every frame along K weighted motions and blend the results.

    img: B,T,C,H,W
    flow: B,T,HW,K,3; per pixel and edge slot, two displacement channels
          followed by a weight
    last_only: if True, only the last frame and its flows are splatted

    Each (frame, edge) copy of the image is multiplied by its weight and
    summation-splatted with ``softsplat``. A ones channel is splatted
    alongside, and its accumulated weight normalises the summed result.

    :return: tensor (B, C, H, W)
    '''
    B, T, C, H, W = img.shape
    K = flow.shape[-2]
    flow = flow.reshape(B, T, H, W, K, 3).permute(0, 1, 4, 5, 2, 3)  # B,T,K,3,H,W

    if last_only:
        # Select before replicating so only one frame is copied K times.
        img = img[:, -1:]
        flow = flow[:, -1:]
        T = 1

    img = img.unsqueeze(2).repeat(1, 1, K, 1, 1, 1).reshape(B * T * K, C, H, W)
    flow = flow.reshape(B * T * K, 3, H, W)
    weight = flow[:, -1:]  # BTK,1,H,W
    # Swap the two displacement channels (presumably (row, col) -> (x, y) as
    # softsplat expects — TODO confirm).
    displacement = torch.cat([flow[:, 1:2], flow[:, :1]], dim=1)

    # Customized Normalization Splatting
    splat_source = torch.cat([img, torch.ones_like(img[:, -1:])], dim=1) * weight
    output = softsplat.softsplat(tenIn=splat_source, tenFlow=displacement, tenMetric=None, strMode='sum')
    weight_map = output[:, -1:]
    output = output[:, :-1]

    output = torch.sum(output.reshape(-1, T * K, C, H, W), dim=1)
    weight_map = torch.sum(weight_map.reshape(-1, T * K, 1, H, W), dim=1)
    output = output / (weight_map + 1e-6)

    return output.reshape(-1, C, H, W)

def check_folder(folder_path):
    """Ensure ``folder_path`` exists and return it unchanged.

    Missing parent directories are created as well. A directory created
    concurrently by another process is not an error. An existing path is
    returned as-is.
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path, exist_ok=True)
    return folder_path




class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    Args:
        name: label used when printing.
        fmt: format spec appended to the average, e.g. ``':.4f'``.
        is_val: if True, the printed string is prefixed with ``'val '``.
    """

    def __init__(self, name, fmt=':f', is_val=False):
        self.name = name
        self.fmt = fmt
        self.is_val = is_val
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        prefix = 'val ' if self.is_val else ''
        template = prefix + '{name} {avg' + self.fmt + '}'
        return template.format(**self.__dict__)


def metric_print(metrics,epoch_num,step_num,t,last_iter=False):
    '''Format one tab-separated progress line.

    :param metrics: mapping of metric name -> object rendered with ``str()``
    :param epoch_num: epoch number
    :param step_num: step number (omitted when ``last_iter`` is True)
    :param t: time duration
    :param last_iter: use the epoch-summary header without the step number
    :return: string
    '''
    header = ('Epoch {} \t'.format(epoch_num) if last_iter
              else 'Epoch {} iter {}\t'.format(epoch_num, step_num))
    columns = ''.join(str(metrics[key]) + '\t' for key in metrics.keys())
    return header + columns + 'takes {} s'.format(t)


def update_metrics(metrics,loss_dict):
    """Feed each entry of ``loss_dict`` into the matching meter in ``metrics``.

    Every key of ``loss_dict`` must exist in ``metrics``, and its value must
    provide ``update(value)`` (e.g. an ``AverageMeter``). ``metrics`` is
    updated in place and also returned.
    """
    for name, value in loss_dict.items():
        metrics[name].update(value)
    return metrics


def img_valid(img):
    """Add 0.5 to an image and clip the result to [0, 1].

    Presumably maps data centred on 0 (range [-0.5, 0.5]) into displayable
    [0, 1]. Works for NumPy arrays and torch tensors. A new object is
    returned and the input is not modified.
    """
    shifted = img + 0.5
    return shifted.clip(0., 1.)

def img_clamp(img):
    """Clip an image to [0, 255] and return it as a ``uint8`` NumPy array.

    Accepts a NumPy array or a torch tensor. Tensors are detached and moved
    to the CPU first. The clipping is applied to a copy, so the caller's
    array or tensor is left unchanged.
    """
    if torch.is_tensor(img):
        # detach() so tensors that require grad can be converted too.
        img = img.detach().cpu().numpy()
    # Copy: .numpy() on a CPU tensor shares memory with that tensor.
    img = np.array(img, copy=True)
    img[img < 0] = 0.
    img[img > 255] = 255
    return img.astype(np.uint8)

def torch_img_clamp_normalize(img):
    """Clip an image to [0, 255] and scale it to [0, 1].

    Returns a new tensor (or array) and leaves the input unchanged. Integer
    inputs are promoted to floating point by the division.
    """
    clipped = img.clip(0, 255)
    return clipped / 255.

def MAE(true,pred):
    """Return the sum of absolute errors between ``pred`` and ``true``.

    Note: despite the name, this is a sum over all elements, not a mean.
    """
    abs_err = np.abs(pred - true)
    return abs_err.sum()

def MSE(true,pred):
    """Return the sum of squared errors between ``pred`` and ``true``.

    Note: despite the name, this is a sum over all elements, not a mean.
    """
    squared = (pred - true) ** 2
    return squared.sum()



def visualization_check(save_path,epoch,image_list,valid=False,is_train=False):
    """Save a two-row grid: ground truth on top, reconstructions below.

    Args:
        save_path: directory prefix of the output PNG.
        epoch: value embedded in the file name.
        image_list: sequence of ``(gt, recon)`` image pairs, laid out left
            to right.
        valid: if True, images are converted with ``img_clamp``; otherwise
            with ``img_valid``.
        is_train: selects the ``visual_check_train_`` file-name prefix.
    """
    plt.clf()
    tag = '/visual_check_train_' if is_train else '/visual_check_'
    save_file = save_path + tag + str(epoch) + '.png'
    convert = img_clamp if valid else img_valid

    top_row, bottom_row = [], []
    for gt, recon in image_list:
        top_row.append(convert(gt))
        bottom_row.append(convert(recon))

    canvas = np.vstack([np.hstack(top_row), np.hstack(bottom_row)])
    plt.imshow(canvas,interpolation="nearest")
    plt.savefig(save_file, dpi=400, bbox_inches ="tight", pad_inches = 1)

def visualization_check_video(save_path,epoch,image_list,valid=False,is_train=False,matrix=False,config=None,long_term=False):
    """Save a comparison grid of ground-truth vs. predicted sequences as a PNG.

    The file is written to ``save_path + '/visual_check_'`` followed by an
    optional ``'train_'``, an optional ``'matrix_'``, then ``str(epoch) + '.png'``.

    Args:
        save_path: directory prefix of the output file.
        epoch: value embedded in the file name.
        image_list: sequence of ``(gt_seq, recon_seq)`` pairs; each pair
            becomes one row of the grid.
        valid: if True, images are converted with ``img_clamp``; otherwise
            with ``img_valid``.
        is_train: adds ``'train_'`` to the file name.
        matrix: if True, each pair is reshaped to ``(h, w, h, w)`` with
            ``(h, w) = config['mat_size'][0]`` and shown at five fixed grid
            positions using the 'hot' colormap.
        config: only read when ``matrix`` is True.
        long_term: force the long-sequence layout described below.

    Non-matrix layouts:
        * ``gt_seq.shape[0] > 20`` or ``long_term``: GT frames side by side on
          top, with predictions right-aligned beneath them (zero-filled on the
          left).
        * otherwise: ``gt_seq[-2:]``, the predictions, then three 50/50 blends:
          gt_seq[-2] + prediction, gt_seq[-1] + prediction, and
          gt_seq[-1] + gt_seq[-2].
    """
    plt.clf()
    if is_train:
        save_file = save_path + '/visual_check_train_'
    else:
        save_file = save_path + '/visual_check_'

    if matrix:
        save_file = save_file + 'matrix_'
    save_file = save_file + str(epoch) + '.png'
    sample_num = len(image_list)
    vis_image = []
    # NOTE(review): dtf_image, diff_prev_image and diff_gt_image are never
    # used by the executed code below.
    dtf_image = []
    diff_prev_image = [] # diff_prev
    diff_gt_image = [] # diff_gt
    for i in range(sample_num):
        gt_seq,recon_seq = image_list[i]
        
        if matrix:
            h= config['mat_size'][0][0]
            w = config['mat_size'][0][1]
            gt_seq = gt_seq.reshape(h,w,h,w)
            recon = recon_seq.reshape(h,w,h,w)
            # Five probe positions: the four quarter points and the centre.
            select_index = [[h//4,w//4],[h//4,w*3//4],[h//2,w//2],[h*3//4,w//4],[h*3//4,w*3//4]]
           # select_index = [[5,5],[5,25],[15,15],[25,5],[25,25]]
            vis_image_row = []
            #dtf_image_row = []
            for index in select_index:
                # log(gt) | log(recon) | log(gt - recon) for this probe's (h, w) map.
                # NOTE(review): np.log yields -inf for zeros and NaN for negative
                # entries (e.g. wherever recon > gt).
                vis_image_row.append(np.hstack([np.log(gt_seq[index[0],index[1]]),np.log(recon[index[0],index[1]]),np.log(gt_seq[index[0],index[1]]-recon[index[0],index[1]])]))
            vis_image_row = np.hstack(vis_image_row)
            #dtf_image_row = np.hstack(dtf_image_row)
            vis_image.append(vis_image_row)
            #dtf_image.append(dtf_image_row)


        else:
            
           
            if gt_seq.shape[0] > 20 or long_term:

                gt = np.hstack(gt_seq)
                recon = np.hstack(recon_seq)
                if valid:
                    gt = img_clamp(gt)
                    recon = img_clamp(recon)
                else:
                    gt = img_valid(gt)
                    recon = img_valid(recon)
                # Right-align the (shorter) prediction strip under the GT strip;
                # assumes 3-D (H, W, C) images after hstack.
                recon_range = np.zeros_like(gt)
                recon_range[:,-recon.shape[1]:,:] = recon
                recon = recon_range
                vis_image.append(np.vstack([gt,recon]))
                
            else:
                gt = np.hstack(gt_seq[-2:])
            
                recon = np.hstack(recon_seq)
                if valid:
                    gt = img_clamp(gt)
                    recon = img_clamp(recon)
                else:
                    gt = img_valid(gt)
                    recon = img_valid(recon)

            
                # Blends below combine `recon` with single GT frames; this assumes
                # recon_seq holds one frame so the shapes match — TODO confirm.
                overlay = recon.copy()
                gt_frame = img_clamp(gt_seq[-1]) if valid else img_valid(gt_seq[-1])
                last_frame = img_clamp(gt_seq[-2]) if valid else img_valid(gt_seq[-2])
                
                overlay_last = last_frame*.5 + overlay *.5
                overlay_gt = gt_frame*.5 + overlay *.5
                real_overlay = gt_frame *.5 + last_frame *.5 
                if valid:
                    overlay_last = overlay_last.astype(np.uint8)
                    overlay_gt = overlay_gt.astype(np.uint8)
                    real_overlay = real_overlay.astype(np.uint8)
                vis_image.append(np.hstack([gt,recon,overlay_last,overlay_gt,real_overlay]))
        #exit()
    
    whole_image = np.vstack(vis_image)
    if matrix:
        plt.imshow(whole_image,interpolation="nearest",cmap='hot')
    else:
        # Single-channel images are rendered in grayscale.
        if whole_image.shape[-1] == 1:
            plt.imshow(whole_image,interpolation="nearest",cmap='gray')
        else:
            plt.imshow(whole_image,interpolation="nearest")
    plt.axis('off')
    plt.savefig(save_file, dpi=400, bbox_inches ="tight", pad_inches = 0)
    # The triple-quoted string below is disabled code; it is never executed.
    '''
    if (not is_train) and (matrix):
        plt.clf()
        dtf_whole_image = np.vstack(dtf_image)
        plt.imshow(dtf_whole_image)
        plt.colorbar()
        dtf_save_file = save_file.replace('check_matrix','check_dtf')
        plt.savefig(dtf_save_file, dpi=400, bbox_inches ="tight", pad_inches = 0)
    '''

def visualization_check_video_testmode(save_path,image_list,valid=False,config=None,iter_id=0):
    """Save a per-iteration visualisation grid of test-time predictions.

    Each row (one per sample) contains, in order:
    1. all previous frames;
    2. a separator;
    3. the GT future frames and the predicted frames;
    4. a separator;
    5. three 50/50 blends: last previous frame + first prediction, first GT
       frame + first prediction, and first GT frame + last previous frame.

    Args:
        save_path: output directory; created via ``check_folder`` if missing.
        image_list: ``(prev_seq, gt_seq, recon_seq)``. Each is indexed per
            sample, and each sample is a sequence of frames.
        valid: if True, frames are converted with ``img_clamp``; otherwise
            with ``img_valid``.
        config: unused.
        iter_id: value embedded in the ``visual_iter_<iter_id>.png`` file name.
    """
    plt.clf()
    check_folder(save_path)
    save_file = save_path + '/visual_iter_' + str(iter_id) + '.png'
    prev_seq, gt_seq, recon_seq = image_list
    convert = img_clamp if valid else img_valid
    vis_image = []
    for i in range(len(prev_seq)):
        last_frame = convert(prev_seq[i][-1])
        gt_frame = convert(gt_seq[i][0])
        recon_frame = convert(recon_seq[i][0])
        cur_prev = convert(np.hstack(prev_seq[i]))
        cur_gt = convert(np.hstack(gt_seq[i]))
        cur_recon = convert(np.hstack(recon_seq[i]))

        overlay_last = last_frame * .5 + recon_frame * .5
        overlay_gt = gt_frame * .5 + recon_frame * .5
        real_overlay = gt_frame * .5 + last_frame * .5
        if valid:
            overlay_last = overlay_last.astype(np.uint8)
            overlay_gt = overlay_gt.astype(np.uint8)
            real_overlay = real_overlay.astype(np.uint8)
        # 20-pixel-wide black separator that reuses the frames' own trailing
        # (channel) dims, so single-channel frames stack as well as RGB ones
        # (a fixed 3-channel separator made np.hstack fail for grayscale).
        vline = np.zeros(cur_prev.shape[:1] + (20,) + cur_prev.shape[2:], dtype=np.uint8)
        vis_image.append(np.hstack([cur_prev, vline, cur_gt, cur_recon, vline,
                                    overlay_last, overlay_gt, real_overlay]))

    whole_image = np.vstack(vis_image)
    if whole_image.shape[-1] == 1:
        plt.imshow(whole_image,interpolation="nearest",cmap='gray')
    else:
        plt.imshow(whole_image,interpolation="nearest")
    plt.axis('off')
    plt.savefig(save_file, dpi=400, bbox_inches ="tight", pad_inches = 0)




def image_evaluation(image_list,gt_image_list,eval_metrics,valid=False,full_test=False):
    """Update running image-quality metrics for a batch of predictions.

    Args:
        image_list: predicted images. If it has more than 4 dims, the first
            two are merged, e.g. (B, T, H, W, C) -> (B*T, H, W, C). Individual
            images are later permuted with ``.permute(2, 0, 1)``, i.e. they are
            expected channel-last.
        gt_image_list: ground-truth images with the same layout.
        eval_metrics: dict mapping a metric name to an accumulator exposing
            ``update(value)``. Recognised keys: 'psnr', 'psnr_y', 'ssim',
            'mae', 'mse', and (only when ``full_test`` is True) 'lpips' and
            'ms_ssim'. Unrecognised keys are ignored.
        valid: if True images are passed through ``img_clamp`` and divided by
            255; otherwise they are passed through ``img_valid``.
        full_test: enables the 'lpips' and 'ms_ssim' metrics.

    Returns:
        ``eval_metrics``, updated in place.
    """
    size = image_list.shape
    if len(size) > 4:
        image_list = image_list.reshape(size[0]*size[1],size[2],size[3],size[4])
    size = gt_image_list.shape
    if len(size) > 4:
        gt_image_list = gt_image_list.reshape(size[0]*size[1],size[2],size[3],size[4])

    for i in range(image_list.shape[0]):
        if valid:
            image = img_clamp(image_list[i]) / 255.
            gt_image = img_clamp(gt_image_list[i]) / 255.
        else:
            image = img_valid(image_list[i])
            gt_image = img_valid(gt_image_list[i])

        # Keep tensor copies for the torch-based metrics before moving to numpy.
        gt_image_gpu = gt_image.clone()
        images_gpu = image.clone()

        gt_image = gt_image.cpu().numpy()
        image = image.cpu().numpy()

        for key in eval_metrics:
            if key == 'psnr':
                eval_metrics[key].update(psnr(gt_image.copy(), image.copy()))
            elif key == 'psnr_y':
                # PSNR on the luma (Y) channel only. Previously the `[:, :, :1]`
                # slice was applied to psnr's scalar result and the full YUV
                # prediction was compared with the 1-channel GT, so this
                # branch always raised. Y of an RGB image in [0, 1] lies in [0, 1].
                gt_y = rgb2yuv(gt_image.copy())[:, :, :1]
                pred_y = rgb2yuv(image.copy())[:, :, :1]
                eval_metrics[key].update(psnr(gt_y, pred_y, data_range=1.0))
            elif key == 'ssim':
                # NOTE(review): `multichannel=` is deprecated in newer
                # scikit-image in favour of `channel_axis=-1`.
                eval_metrics[key].update(ssim(gt_image.copy(), image.copy(), multichannel=True))
            elif key == 'mae':
                eval_metrics[key].update(MAE(gt_image.copy(), image.copy()))
            elif key == 'mse':
                eval_metrics[key].update(MSE(gt_image.copy(), image.copy()))
            elif key == 'lpips' and full_test:
                # NOTE(review): LPIPS expects inputs in [-1, 1] unless called with
                # normalize=True; if these images are in [0, 1] (as the /255. in
                # the `valid` branch suggests) the score is not on the standard
                # scale. Confirm before comparing with published numbers.
                eval_metrics[key].update(lpips(images_gpu.permute(2,0,1).unsqueeze(0).cuda(),
                                               gt_image_gpu.permute(2,0,1).unsqueeze(0).cuda()).item())
            elif key == 'ms_ssim' and full_test:
                eval_metrics[key].update(ms_ssim(X=images_gpu.permute(2,0,1).unsqueeze(0),
                                                 Y=gt_image_gpu.permute(2,0,1).unsqueeze(0),
                                                 data_range=1.0, size_average=True))
    return eval_metrics

class VGG_feature(nn.Module):
    """Frozen VGG16 feature extractor, applied frame by frame to a clip.

    Uses torchvision VGG16 ``features[:22]``, i.e. the layers up to and
    including conv4_3. That prefix contains three max-pool layers, so spatial
    size is divided by 8.
    """

    def __init__(self,device):
        super(VGG_feature, self).__init__()
        backbone = torchvision.models.vgg16(pretrained=True).to(device)
        conv_stack = list(backbone.children())[0]
        self.vgg16_conv_4_3 = torch.nn.Sequential(*conv_stack[:22])
        # Frozen: the extractor is never trained.
        self.vgg16_conv_4_3.requires_grad_(False)

    def forward(self, input):
        '''
        input: B,T,c,H,W
        output: B,T,c',H//8,W//8 (features computed without gradient)
        '''
        batch, frames, channels, height, width = input.shape
        frames_as_batch = input.reshape(-1, channels, height, width)
        with torch.no_grad():
            feats = self.vgg16_conv_4_3(frames_as_batch.clone())
        _, feat_c, feat_h, feat_w = feats.shape
        return feats.reshape(batch, frames, feat_c, feat_h, feat_w)


class VGG_loss(nn.Module):
    """Perceptual loss: MSE between frozen VGG16 conv4_3 features of output and target.

    Args:
        device: device the VGG16 weights are moved to.
        reduction: 'sum' gives the summed squared feature error divided by the
            batch size; any other value gives the mean squared feature error.
    """

    def __init__(self,device,reduction='mean'):
        super(VGG_loss, self).__init__()
        vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
        self.reduction = reduction
        self.vgg16_conv_4_3 = torch.nn.Sequential(*list(vgg16.children())[0][:22])
        for param in self.vgg16_conv_4_3.parameters():
            param.requires_grad = False

    def forward(self, output, gt,norm=False):
        """Compute the loss.

        Args:
            output: prediction. Tensors with more than 4 dims are flattened to
                (-1, C, H, W) using their last three dims. If dim 1 is not 3,
                the tensor is assumed channel-last and permuted to channel-first.
            gt: target, same conventions as ``output``.
            norm: if False both inputs go through ``torch_img_clamp_normalize``;
                if True they are shifted by +0.5 instead (presumably they are
                centred at 0 -- TODO confirm with callers).

        The caller's tensors are never modified.
        """
        if output.dim() > 4:
            output = output.reshape(-1, *output.shape[-3:])
        if gt.dim() > 4:
            # Use gt's own trailing dims (previously output's were used).
            gt = gt.reshape(-1, *gt.shape[-3:])
        if output.shape[1] != 3:
            output = output.permute(0,3,1,2)
            gt = gt.permute(0,3,1,2)
        if not norm:
            output = torch_img_clamp_normalize(output)
            gt = torch_img_clamp_normalize(gt)
        else:
            # Out-of-place on purpose: reshape/permute return views, so `+=`
            # wrote the shift back into the caller's tensors.
            gt = gt + 0.5
            output = output + 0.5

        vgg_output = self.vgg16_conv_4_3(output.clone())
        with torch.no_grad():
            vgg_gt = self.vgg16_conv_4_3(gt.detach())

        if self.reduction == 'sum':
            loss = F.mse_loss(vgg_output, vgg_gt, reduction='sum') / output.shape[0]
        else:
            loss = F.mse_loss(vgg_output, vgg_gt, reduction='mean')
        return loss

class VGGPerceptualLoss(torch.nn.Module):
    """Weighted L1 distance between frozen VGG19 activations of two images.

    Both inputs are normalized with ImageNet mean/std (via ``MeanShift``), run
    through the first 30 layers of VGG19 ``features``, and compared after
    layers 2, 7, 12, 21 and 30 (1-based). Each comparison is the mean absolute
    difference, scaled by a per-tap weight and by 0.1.

    ``rank`` is accepted but unused.
    """

    def __init__(self, rank=0):
        super(VGGPerceptualLoss, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
        for weight_param in self.parameters():
            weight_param.requires_grad = False

    def forward(self, X, Y, indices=None):
        # NOTE: `indices` is ignored; the tap points below are fixed.
        taps = (2, 7, 12, 21, 30)
        tap_weight = dict(zip(taps, (1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5)))

        feat_x = self.normalize(X)
        feat_y = self.normalize(Y)
        total = 0
        for layer_idx in range(taps[-1]):
            layer = self.vgg_pretrained_features[layer_idx]
            feat_x = layer(feat_x)
            feat_y = layer(feat_y)
            w = tap_weight.get(layer_idx + 1)
            if w is not None:
                # The target branch is detached so gradients only reach X.
                total += w * (feat_x - feat_y.detach()).abs().mean() * 0.1
        return total




class CharbonnierLoss(nn.Module):
    """Charbonnier loss: sqrt((x - y)^2 + eps), a smooth approximation of L1.

    Args:
        eps: constant added under the square root.
        reduction: 'sum' (sum over all elements divided by the batch size
            ``x.shape[0]``), 'mean' (mean over all elements) or 'none'
            (element-wise map).
    """

    def __init__(self, eps=1e-6,reduction='sum'):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y,weight=None,reduction='sum'):
        """Return the loss between ``x`` and ``y``.

        ``weight`` (optional) multiplies the element-wise error before
        reduction. The ``reduction`` argument is kept for backward
        compatibility but ignored; ``self.reduction`` decides.

        Raises:
            ValueError: if ``self.reduction`` is not 'sum', 'mean' or 'none'
                (previously this surfaced as an UnboundLocalError).
        """
        diff = x - y
        err = torch.sqrt(diff * diff + self.eps)
        if weight is not None:
            err = err * weight

        if self.reduction == 'sum':
            return torch.sum(err) / x.shape[0]  # batch mean
        if self.reduction == 'mean':
            return torch.mean(err)
        if self.reduction == 'none':
            return err
        raise ValueError(f"Unsupported reduction mode: {self.reduction}. "
                         f"Supported ones are: ['sum', 'mean', 'none']")

class JSDLoss(nn.Module):
    """Jensen-Shannon divergence between per-pixel channel distributions.

    Each spatial location of an (N, C, H, W) map is turned into a distribution
    over C channels with a softmax. The loss is
    ``weight * mean_over_pixels(JSD) * N``, i.e. ``weight * sum(JSD) / (H*W)``.
    The mixture's log is taken after clamping to [1e-7, 1].
    """

    def __init__(self, weight=1.):
        super(JSDLoss, self).__init__()
        self.weight = weight

    def forward(self, feat_1,feat_2):
        n_maps, n_ch = feat_1.shape[0], feat_1.shape[1]

        def as_distribution(feat):
            # (N, C, H, W) -> (N*H*W, C), softmax over channels.
            return F.softmax(feat.permute(0, 2, 3, 1).reshape(-1, n_ch), dim=-1)

        p = as_distribution(feat_1)
        q = as_distribution(feat_2)
        log_mid = torch.clamp((p + q) / 2., 1e-7, 1).log()
        kl_p = F.kl_div(log_mid, p, reduction='batchmean')
        kl_q = F.kl_div(log_mid, q, reduction='batchmean')
        return self.weight * (kl_p + kl_q) / 2. * n_maps


class CosineAnnealingLR_Restart(_LRScheduler):
    """Cosine-annealing learning-rate schedule with warm restarts (recursive form).

    Args:
        optimizer: wrapped optimizer.
        T_period: cosine period lengths. ``T_period[0]`` is used until the first
            restart; after ``restarts[i]`` the period becomes ``T_period[i + 1]``,
            so it needs at least ``len(restarts) + 1`` entries.
        restarts: scheduler steps at which the lr jumps back up. ``None`` (the
            default) means no restarts; it used to crash in ``list(None)``.
        weights: accepted for compatibility but ignored. The lr after the
            i-th restart is ``initial_lr * ratio ** (i + 1)``.
        eta_min: minimum learning rate.
        last_epoch: forwarded to ``_LRScheduler``.
        ratio: decay factor applied at each successive restart.
    """

    def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=1e-5, last_epoch=-1,ratio=0.5):
        self.T_period = list(T_period)
        self.T_max = self.T_period[0]  # current T period
        self.eta_min = eta_min
        self.restarts = list(restarts) if restarts is not None else []
        self.restart_weights = [ratio ** (i+1) for i in range(len(self.restarts))]
        self.last_restart = 0
        print('restart ratio: ',ratio,' T_period: ',T_period,' minimum lr: ',eta_min)
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lrs
        elif self.last_epoch in self.restarts:
            # Warm restart: switch to the next period and jump to a decayed initial lr.
            self.last_restart = self.last_epoch
            self.T_max = self.T_period[list(self.restarts).index(self.last_epoch) + 1]
            weight = self.restart_weights[list(self.restarts).index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
            # Special case of the recursive formula where its denominator would vanish.
            return [
                group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Recursive cosine update, relative to the step of the last restart.
        return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
                (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]


def get_local_weights(residual, ksize):
    """Return the unbiased variance of ``residual`` over each ksize x ksize window.

    ``residual`` is reflect-padded by ``(ksize - 1) // 2`` on both sides of
    dims 2 and 3, so for odd ``ksize`` the result has the same shape as the
    input (N, C, H, W).
    """
    half = (ksize - 1) // 2
    padded = F.pad(residual, pad=[half] * 4, mode='reflect')
    # (N, C, H', W', ksize, ksize): one window per output location.
    windows = padded.unfold(2, ksize, 1).unfold(3, ksize, 1)
    return torch.var(windows, dim=(-2, -1), unbiased=True)

def get_refined_artifact_map(img_gt, img_output, ksize):
    """Return a per-pixel artifact weight map for ``img_output`` vs ``img_gt``.

    The residual is the absolute error summed over dim 1 (kept as a singleton
    dim). The map is the product of a per-sample term (variance of the residual
    over its last three dims, raised to 1/5) and a per-pixel term (local
    variance from ``get_local_weights`` with window ``ksize``).
    """
    residual = (img_gt - img_output).abs().sum(dim=1, keepdim=True)
    global_term = torch.var(residual, dim=(-3, -2, -1), keepdim=True) ** (1/5)
    local_term = get_local_weights(residual, ksize)
    return global_term * local_term

class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: on an unsupported ``reduction`` (previously this path
            raised NameError because the supported-modes list was undefined).
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        supported = ['none', 'mean', 'sum']
        if reduction not in supported:
            raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {supported}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights, multiplied into |pred - target| before reduction
                ('mean' averages over all elements, as CharbonnierLoss does).
                Default: None.
        """
        if weight is None:
            return self.loss_weight * F.l1_loss(pred, target, reduction=self.reduction)
        # F.l1_loss's third positional parameter is the legacy `size_average`,
        # so passing `weight` there never applied it; weight explicitly instead.
        loss = torch.abs(pred - target) * weight
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return self.loss_weight * loss

class TVLoss(nn.Module):
    """Squared total-variation loss for 4-D (N, C, H, W) tensors.

    Returns ``TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / N``. Here
    h_tv / w_tv are sums of (|neighbour difference| + 1e-8) ** 2 along H / W,
    and count_h / count_w are the per-sample element counts of those
    difference maps.
    """

    def __init__(self,TVLoss_weight=1):
        super(TVLoss,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        if torch.isnan(x).any():
            print('Caught NaN!!!!')
            # Abort with a non-zero status; the previous bare `exit()` raised
            # SystemExit(None), which reports success (status 0) to the shell.
            raise SystemExit(1)
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:,:,1:,:])
        count_w = self._tensor_size(x[:,:,:,1:])
        h_tv = torch.pow(((x[:,:,1:,:]-x[:,:,:h_x-1,:]).abs()+1e-8),2).sum()
        w_tv = torch.pow(((x[:,:,:,1:]-x[:,:,:,:w_x-1]).abs()+1e-8),2).sum()
        return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size

    def _tensor_size(self,t):
        # Number of elements per sample (C * H * W).
        return t.size()[1]*t.size()[2]*t.size()[3]


class LapLoss(torch.nn.Module):
    """Sum over pyramid levels of the L1 distance between Laplacian pyramids.

    Args:
        max_levels: number of band-pass levels compared.
        channels: channel count the Gaussian kernel is built for.
    """

    def __init__(self, max_levels=5, channels=3):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        # Plain tensor attribute (not a registered buffer); built by gauss_kernel.
        self.gauss_kernel = gauss_kernel(channels=channels)

    def forward(self, input, target):
        bands_in = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        bands_gt = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        total = 0
        for band_in, band_gt in zip(bands_in, bands_gt):
            total = total + torch.nn.functional.l1_loss(band_in, band_gt)
        return total

def gauss_kernel(size=5, channels=3):
    """Return the 5x5 binomial smoothing kernel, shape (channels, 1, 5, 5), on ``device``.

    The kernel is the outer product of [1, 4, 6, 4, 1] with itself divided by
    256, so it sums to 1. ``size`` is accepted but ignored; the kernel is
    always 5x5 (``conv_gauss`` pads by 2 accordingly).
    """
    taps = torch.tensor([1., 4., 6., 4., 1.])
    kernel = torch.outer(taps, taps) / 256.
    return kernel.repeat(channels, 1, 1, 1).to(device)

def laplacian_pyramid(img, kernel, max_levels=3):
    """Return the ``max_levels`` band-pass levels of a Laplacian pyramid of ``img``.

    Each level is ``current - upsample(downsample(blur(current)))``. The
    blurred, downsampled image becomes the next ``current``. The final
    low-pass residual is not included in the returned list.
    """
    bands = []
    current = img
    for _ in range(max_levels):
        low = downsample(conv_gauss(current, kernel))
        bands.append(current - upsample(low))
        current = low
    return bands

def conv_gauss(img, kernel):
    """Depthwise-convolve a 4-D ``img`` with ``kernel`` after 2-pixel reflect padding.

    ``groups`` equals the channel count, so ``kernel`` is expected to be
    (C, 1, kh, kw). With a 5x5 kernel the spatial size is preserved.
    """
    padded = F.pad(img, (2, 2, 2, 2), mode='reflect')
    return F.conv2d(padded, kernel, groups=padded.shape[1])

def downsample(x):
    """Keep every other element along dims 2 and 3 (no anti-aliasing)."""
    every_other = slice(None, None, 2)
    return x[:, :, every_other, every_other]

def upsample(x):
    """2x zero-insertion upsampling of an (N, C, H, W) tensor followed by smoothing.

    Produces ``x_up`` of shape (N, C, 2H, 2W) with ``x_up[..., 2i, 2j] = x[..., i, j]``
    and zeros elsewhere. The result is then convolved with
    ``4 * gauss_kernel(channels=C)``; the kernel sums to 1, and the factor 4
    compensates for the three zeros inserted per input pixel.
    """
    n, c, h, w = x.shape
    # Allocate the zero buffer directly on x's device. The previous
    # torch.zeros(...).to(device) built two CPU tensors and copied them to the
    # global device on every call. The dtype mirrors the old cat() promotion
    # of x with a default-dtype zeros tensor.
    x_up = torch.zeros((n, c, 2 * h, 2 * w),
                       dtype=torch.promote_types(x.dtype, torch.get_default_dtype()),
                       device=x.device)
    x_up[:, :, ::2, ::2] = x
    return conv_gauss(x_up, 4*gauss_kernel(channels=c))
